fork(1) download
  1. import java.util.*;
  2. import java.lang.*;
  3. import java.io.*;
  4.  
  5. class Matrix
  6. {
  7. public final int w, h;
  8. public final int[] raw;
  9.  
  10. public Matrix(int w, int h) {
  11. validateSize(w, h);
  12. this.w = w;
  13. this.h = h;
  14. raw = new int[w * h];
  15. }
  16.  
  17. public static Matrix explicit(int w, int h, int... cells) {
  18. if (cells.length != w * h)
  19. throw new IllegalArgumentException(String.format("Expected %d values of %d×%d matrix, got %d.", w*h, w, h, cells.length));
  20. Matrix r = new Matrix(w, h);
  21. System.arraycopy(cells, 0, r.raw, 0, cells.length);
  22. return r;
  23. }
  24.  
  25. public String toString() {
  26. String[] reprs = new String[raw.length];
  27. int maxLen = 0;
  28. for (int i = 0; i < reprs.length; i++) {
  29. reprs[i] = Integer.toString(raw[i]);
  30. maxLen = Math.max(maxLen, reprs[i].length());
  31. }
  32. String fmt = "%" + maxLen + "s";
  33.  
  34. StringBuilder rsb = new StringBuilder();
  35. for (int y = 0, absIdx = 0; y < h; y++) {
  36. for (int x = 0; x < w; x++, absIdx++) {
  37. rsb.append(String.format(fmt, reprs[absIdx])).append(x + 1 < w ? " " : y + 1 < h ? "\n" : "");
  38. }
  39. }
  40. return rsb.toString();
  41. }
  42.  
  43. public static void validateSize(int w, int h) {
  44. if (w <= 0 || h <= 0)
  45. throw new IllegalArgumentException(String.format("Bad matrix size: %d×%d.", w, h));
  46. }
  47.  
  48. public void validateSubmatrix(int x, int y, int w, int h) {
  49. validateSize(w, h);
  50. if (x < 0 || y < 0 || x + w > this.w || y + h > this.h)
  51. throw new IllegalArgumentException(String.format("Bad submatrix of %d×%d: (%d, %d) + %d×%d.", this.w, this.h, x, y, w, h));
  52. }
  53.  
  54. public int absoluteIndex(int x, int y) {
  55. if (x < 0 || y < 0 || x >= w || y >= h)
  56. throw new IllegalArgumentException(String.format("Bad cell of %d×%d: (%d, %d).", w, h, x, y));
  57. return y * w + x;
  58. }
  59.  
  60. static Matrix prepareMultiply(
  61. Matrix a, int ax, int ay, int aw, int ah,
  62. Matrix b, int bx, int by, int bw, int bh,
  63. Matrix r) {
  64.  
  65. a.validateSubmatrix(ax, ay, aw, ah);
  66. b.validateSubmatrix(bx, by, bw, bh);
  67. if (aw != bh)
  68. throw new IllegalArgumentException(String.format("Matrices of %d×%d and %d×%d cannot be multiplied: number of columns of the first must match the number of rows of the second.", aw, ah, bw, bh));
  69.  
  70. if (r == null) {
  71. r = new Matrix(bw, ah);
  72. } else if (r.w != bw || r.h != ah) {
  73. throw new IllegalArgumentException(String.format("Bad resulting matrix size: expected %d×%d, got %d×%d.", bw, ah, r.w, r.h));
  74. }
  75. return r;
  76. }
  77.  
  78. public static Matrix multiply_AiVer(
  79. Matrix a, int ax, int ay, int aw, int ah,
  80. Matrix b, int bx, int by, int bw, int bh,
  81. Matrix r) {
  82.  
  83. r = prepareMultiply(a, ax, ay, aw, ah, b, bx, by, bw, bh, r);
  84.  
  85. int aRowStartAbsIdx = a.absoluteIndex(ax, ay); // points to the beginning of the current row of A, i. e. A[0, y].
  86. int bColStartAbsIdx = b.absoluteIndex(bx, by); // points to the beginning of the current column of B, i. e. B[x, 0].
  87.  
  88. // result cell (x, y) (pointed by rAbsIdx) gets the dot product of current row of A and current column of B.
  89. for (int y = 0, rAbsIdx = 0; y < r.h; y++, aRowStartAbsIdx += a.w, bColStartAbsIdx -= r.w) {
  90. for (int x = 0; x < r.w; x++, rAbsIdx++, bColStartAbsIdx++) {
  91. int dot = 0;
  92. for (int i = 0, aAbsIdx = aRowStartAbsIdx, bAbsIdx = bColStartAbsIdx; i < aw; i++, aAbsIdx++, bAbsIdx += b.w) {
  93. dot += a.raw[aAbsIdx] * b.raw[bAbsIdx];
  94. }
  95. r.raw[rAbsIdx] = dot;
  96. }
  97. }
  98. return r;
  99. }
  100.  
  101. public Matrix multiply_AiVer(Matrix b) {
  102. return multiply_AiVer(this, 0, 0, w, h, b, 0, 0, b.w, b.h, null);
  103. }
  104.  
  105. int get(int x, int y) {
  106. return raw[y * w + x];
  107. }
  108.  
  109. void set(int x, int y, int value) {
  110. raw[y * w + x] = value;
  111. }
  112.  
  113. public static Matrix multiply_GsVer(
  114. Matrix a, int ax, int ay, int aw, int ah,
  115. Matrix b, int bx, int by, int bw, int bh,
  116. Matrix r) {
  117.  
  118. r = prepareMultiply(a, ax, ay, aw, ah, b, bx, by, bw, bh, r);
  119.  
  120. for (int y = 0; y < r.h; y++) {
  121. for (int x = 0; x < r.w; x++) {
  122. int dot = 0;
  123. for (int i = 0; i < aw; i++) {
  124. dot += a.get(i, y) * b.get(x, i);
  125. }
  126. r.set(x, y, dot);
  127. }
  128. }
  129. return r;
  130. }
  131.  
  132. public Matrix multiply_GsVer(Matrix b) {
  133. return multiply_GsVer(this, 0, 0, w, h, b, 0, 0, b.w, b.h, null);
  134. }
  135.  
  136. @Override
  137. public boolean equals(Object other) {
  138. if (this == other) return true;
  139. if (this.getClass() != other.getClass()) return false;
  140. Matrix om = (Matrix) other;
  141. return w == om.w && h == om.h && Arrays.equals(raw, om.raw);
  142. }
  143.  
  144. // just for resulting matrices to not be subjected to dead store elimination.
  145. public int elementsSum() {
  146. int sum = 0;
  147. for (int i = 0; i < raw.length; i++) {
  148. sum += raw[i];
  149. }
  150. return sum;
  151. }
  152. }
  153.  
  154. class Ideone
  155. {
  156. static void test() {
  157. // www.mathwarehouse.com/algebra/matrix/images/product-matrix2.png
  158. Matrix
  159. a = Matrix.explicit(4, 2,
  160. 1, 4, 6, 10,
  161. 2, 7, 5, 3),
  162.  
  163. b = Matrix.explicit(3, 4,
  164. 1, 4, 6,
  165. 2, 7, 5,
  166. 9, 0, 11,
  167. 3, 1, 0),
  168.  
  169. expected = Matrix.explicit(3, 2,
  170. 93, 42, 92,
  171. 70, 60, 102),
  172.  
  173. abProd_AiVer = a.multiply_AiVer(b),
  174. abProd_GsVer = a.multiply_GsVer(b);
  175.  
  176. if (!abProd_AiVer.equals(expected)) throw new RuntimeException(String.format("bad Matrix.multiply_AiVer: got\n%s\nexpected\n%s.", abProd_AiVer, expected));
  177. if (!abProd_GsVer.equals(expected)) throw new RuntimeException(String.format("bad Matrix.multiply_GsVer: got\n%s\nexpected\n%s.", abProd_GsVer, expected));
  178. }
  179.  
  180. static void benchmark() {
  181. Matrix a = new Matrix(1000, 500);
  182. Matrix b = new Matrix(600, 1000);
  183. Matrix r = new Matrix(600, 500); // so allocation won't happen during time measurement
  184.  
  185. long start_AiVer = System.nanoTime();
  186. Matrix.multiply_AiVer(a, 0, 0, a.w, a.h, b, 0, 0, b.w, b.h, r);
  187. double time_AiVer = (System.nanoTime() - start_AiVer) * 1e-9;
  188. System.out.println(String.format("Multiplication using absolute indices: %.2f s", time_AiVer, r.elementsSum()));
  189.  
  190. long start_GsVer = System.nanoTime();
  191. Matrix.multiply_GsVer(a, 0, 0, a.w, a.h, b, 0, 0, b.w, b.h, r);
  192. double time_GsVer = (System.nanoTime() - start_GsVer) * 1e-9;
  193. System.out.println(String.format("Multiplication using .get()/.set(): %.2f s (%s)", time_GsVer, judgement(time_GsVer, time_AiVer), r.elementsSum()));
  194. }
  195.  
  196. static String judgement(double time, double refTime) {
  197. final double absThreshold = 0.1, relThreshold = .05;
  198.  
  199. return time <= absThreshold || refTime <= absThreshold ?
  200. String.format("can't judge, min. required time is %.1f s", absThreshold) :
  201. Math.max(time, refTime) > (1 + relThreshold) * Math.min(time, refTime) ?
  202. String.format("%.0f%% %s", Math.abs(time / refTime - 1) * 100, time < refTime ? "faster" : "slower") :
  203. String.format("no observable difference, min. %.0f%% required", relThreshold * 100);
  204. }
  205.  
  206. public static void main (String[] args)
  207. {
  208. try {
  209. test();
  210. benchmark();
  211. } catch (Exception e) {
  212. System.out.println(e);
  213. }
  214. }
  215. }
Success #stdin #stdout 1.83s 2184192KB
stdin
Standard input is empty
stdout
Multiplication using absolute indices: 0.88 s
Multiplication using .get()/.set(): 0.90 s (no observable difference, min. 5% required)