fork download
  1. import java.util.*;
  2. import java.lang.*;
  3. import java.io.*;
  4.  
  /**
   * Tiny helper for argument validation: builds a formatted message and raises it as an
   * {@link IllegalArgumentException}.
   */
  class Throw {
      /**
       * Always throws; never returns normally. It is declared {@code void} so it can be used as
       * the statement of one-line guards such as {@code if (bad) Throw.badArg(...);}.
       *
       * @param fmt  {@link String#format} pattern for the exception message
       * @param args values substituted into {@code fmt}
       * @throws IllegalArgumentException always, with message {@code String.format(fmt, args)}
       */
      public static void badArg(String fmt, Object... args) {
          final String message = String.format(fmt, args);
          throw new IllegalArgumentException(message);
      }
  }
  10.  
  11. class Matrix {
  12. public final int w, h;
  13. public final int[] raw;
  14.  
  15. public Matrix(int w, int h) {
  16. validateSize(w, h);
  17. this.w = w;
  18. this.h = h;
  19. raw = new int[w * h];
  20. }
  21.  
  22. public static Matrix explicit(int w, int h, int... cells) {
  23. if (cells.length != w * h) Throw.badArg("Expected %d values of %d×%d matrix, got %d.", w*h, w, h, cells.length);
  24. Matrix r = new Matrix(w, h);
  25. System.arraycopy(cells, 0, r.raw, 0, cells.length);
  26. return r;
  27. }
  28.  
  29. @FunctionalInterface
  30. public interface ValueGetter {
  31. int getValue(int x, int y);
  32. }
  33.  
  34. public Matrix(int w, int h, ValueGetter getValue) {
  35. this(w, h);
  36. fill(getValue);
  37. }
  38.  
  39. public String toString() {
  40. String[] reprs = new String[raw.length];
  41. int maxLen = 0;
  42. for (int i = 0; i < reprs.length; i++) {
  43. reprs[i] = Integer.toString(raw[i]);
  44. maxLen = Math.max(maxLen, reprs[i].length());
  45. }
  46. String fmt = "%" + maxLen + "s";
  47.  
  48. StringBuilder rsb = new StringBuilder();
  49. for (int y = 0, ai = 0; y < h; y++)
  50. for (int x = 0; x < w; x++, ai++)
  51. rsb.append(String.format(fmt, reprs[ai])).append(x + 1 < w ? " " : y + 1 < h ? "\n" : "");
  52. return rsb.toString();
  53. }
  54.  
  55. public void fill(ValueGetter getValue)
  56. {
  57. for (int y = 0, ai = 0; y < h; y++)
  58. for (int x = 0; x < w; x++, ai++)
  59. raw[ai] = getValue.getValue(x, y);
  60. }
  61.  
  62. private static void moveSubmatrix(Matrix a, int ax, int ay, Matrix b, int bx, int by, int w, int h) {
  63. for (int y = 0, aAi = a.absIdx(ax, ay), bAi = b.absIdx(bx, by); y < h; y++, aAi += a.w, bAi += b.w)
  64. System.arraycopy(a.raw, aAi, b.raw, bAi, w);
  65. }
  66.  
  67. public Matrix submatrix(int sx, int sy, int sw, int sh) {
  68. validateSubmatrix(this.w, this.h, sx, sy, sw, sh);
  69. Matrix r = new Matrix(sw, sh);
  70. moveSubmatrix(this, sx, sy, r, 0, 0, sw, sh);
  71. return r;
  72. }
  73.  
  74. public Matrix expand(int filler, int left, int right, int top, int bottom) {
  75. if (left < 0 || right < 0 || top < 0 || bottom < 0)
  76. Throw.badArg("Bad expand(): (%d, %d, %d, %d).", left, right, top, bottom);
  77.  
  78. Matrix r = new Matrix(w + left + right, h + top + bottom);
  79. for (int y = 0, rAi = 0, thisAi = 0; y < r.h; y++)
  80. for (int x = 0; x < r.w; x++, rAi++)
  81. r.raw[rAi] = x < left || x >= left + w || y < top || y >= top + h ? filler : raw[thisAi++];
  82. return r;
  83. }
  84.  
  85. int absIdx(int x, int y) {
  86. return y * w + x;
  87. }
  88.  
  89. public static void validateSize(int w, int h) {
  90. if (w <= 0 || h <= 0) Throw.badArg("Bad matrix size: %d×%d.", w, h);
  91. }
  92.  
  93. public static void validateSubmatrix(int thisw, int thish, int x, int y, int w, int h) {
  94. validateSize(w, h);
  95. if (x < 0 || y < 0 || x + w > thisw || y + h > thish)
  96. Throw.badArg("Bad submatrix of %d×%d: (%d, %d) + %d×%d.", thisw, thish, x, y, w, h);
  97. }
  98.  
  99. public static void validateAddLike(int aw, int ah, int bw, int bh, int rw, int rh, String op) {
  100. if (aw != bw || ah != bh)
  101. Throw.badArg("Bad matrices for %s: %d×%d and %d×%d, size must be equal.", op, aw, ah, bw, bh);
  102. if (rw != aw || rh != ah)
  103. Throw.badArg("Bad %s result size: %d×%d, expected %d×%d.", op, rw, rh, aw, ah);
  104. }
  105.  
  106. public static void add(
  107. Matrix a, int ax, int ay, int aw, int ah,
  108. Matrix b, int bx, int by, int bw, int bh,
  109. Matrix r, int rx, int ry, int rw, int rh) {
  110.  
  111. validateSubmatrix(a.w, a.h, ax, ay, aw, ah);
  112. validateSubmatrix(b.w, b.h, bx, by, bw, bh);
  113. validateSubmatrix(r.w, r.h, rx, ry, rw, rh);
  114. validateAddLike(aw, ah, bw, bh, rw, rh, "addition");
  115. trustedAdd(a, ax, ay, b, bx, by, r, rx, ry, rw, rh);
  116. }
  117.  
  118. static void trustedAdd(
  119. Matrix a, int ax, int ay,
  120. Matrix b, int bx, int by,
  121. Matrix r, int rx, int ry, int w, int h) {
  122.  
  123. int aAi = a.absIdx(ax, ay), bAi = b.absIdx(bx, by), rAi = r.absIdx(rx, ry);
  124. for (int y = 0; y < h; y++, aAi += a.w - w, bAi += b.w - w, rAi += r.w - w)
  125. for (int x = 0; x < w; x++, aAi++, bAi++, rAi++)
  126. r.raw[rAi] = a.raw[aAi] + b.raw[bAi];
  127. }
  128.  
  129. public Matrix plus(Matrix b) {
  130. Matrix r = new Matrix(w, h);
  131. add(this, 0, 0, w, h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h);
  132. return r;
  133. }
  134.  
  135. // Mutatis mutandis of add()
  136. public static void subtract(
  137. Matrix a, int ax, int ay, int aw, int ah,
  138. Matrix b, int bx, int by, int bw, int bh,
  139. Matrix r, int rx, int ry, int rw, int rh) {
  140.  
  141. validateSubmatrix(a.w, a.h, ax, ay, aw, ah);
  142. validateSubmatrix(b.w, b.h, bx, by, bw, bh);
  143. validateSubmatrix(r.w, r.h, rx, ry, rw, rh);
  144. validateAddLike(aw, ah, bw, bh, rw, rh, "subtraction");
  145. trustedSubtract(a, ax, ay, b, bx, by, r, rx, ry, rw, rh);
  146. }
  147.  
  148. static void trustedSubtract(
  149. Matrix a, int ax, int ay,
  150. Matrix b, int bx, int by,
  151. Matrix r, int rx, int ry, int w, int h) {
  152.  
  153. int aAi = a.absIdx(ax, ay), bAi = b.absIdx(bx, by), rAi = r.absIdx(rx, ry);
  154. for (int y = 0; y < h; y++, aAi += a.w - w, bAi += b.w - w, rAi += r.w - w)
  155. for (int x = 0; x < w; x++, aAi++, bAi++, rAi++)
  156. r.raw[rAi] = a.raw[aAi] - b.raw[bAi];
  157. }
  158.  
  159. public Matrix minus(Matrix b) {
  160. Matrix r = new Matrix(w, h);
  161. subtract(this, 0, 0, w, h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h);
  162. return r;
  163. }
  164.  
  165. public static void validateMultiply(int aw, int ah, int bw, int bh, int rw, int rh) {
  166. if (aw != bh)
  167. Throw.badArg("Matrices of %d×%d and %d×%d cannot be multiplied: number of columns of the first must match the number of rows of the second.", aw, ah, bw, bh);
  168.  
  169. if (rw != bw || rh != ah)
  170. Throw.badArg("Bad resulting matrix size: expected %d×%d, got %d×%d.", bw, ah, rw, rh);
  171. }
  172.  
  173. public static void multiplyPlain(
  174. Matrix a, int ax, int ay, int aw, int ah,
  175. Matrix b, int bx, int by, int bw, int bh,
  176. Matrix r, int rx, int ry, int rw, int rh) {
  177.  
  178. validateSubmatrix(a.w, a.h, ax, ay, aw, ah);
  179. validateSubmatrix(b.w, b.h, bx, by, bw, bh);
  180. validateSubmatrix(r.w, r.h, rx, ry, rw, rh);
  181. validateMultiply(aw, ah, bw, bh, rw, rh);
  182.  
  183. int aRowAi = a.absIdx(ax, ay), bColAi = b.absIdx(bx, by), rAi = r.absIdx(rx, ry);
  184. for (int y = 0; y < rh; y++, aRowAi += a.w, bColAi -= rw, rAi += r.w - rw)
  185. for (int x = 0; x < rw; x++, rAi++, bColAi++) {
  186. int dot = 0;
  187. for (int i = 0, aAi = aRowAi, bAi = bColAi; i < aw; i++, aAi++, bAi += b.w)
  188. dot += a.raw[aAi] * b.raw[bAi];
  189. r.raw[rAi] = dot;
  190. }
  191. }
  192.  
  193. public Matrix multiplyPlain(Matrix b) {
  194. Matrix r = new Matrix(b.w, h);
  195. multiplyPlain(this, 0, 0, w, h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h);
  196. return r;
  197. }
  198.  
  199. public int get(int x, int y) {
  200. return raw[y * w + x];
  201. }
  202.  
  203. public void set(int x, int y, int value) {
  204. raw[y * w + x] = value;
  205. }
  206.  
  207. private Matrix strassenPaddedBlock(int sx, int sy, int rw, int rh, int takew, int takeh) {
  208. Matrix r = new Matrix(rw, rh);
  209. moveSubmatrix(this, sx, sy, r, 0, 0, takew, takeh);
  210. return r;
  211. }
  212.  
  213. private static boolean shouldResortToPlainMultiply(int bw, int ah, int aw) {
  214. final int THRESHOLD = 8;
  215. return bw <= THRESHOLD || ah <= THRESHOLD || aw <= THRESHOLD;
  216. }
  217.  
  218. public static void multiplyStrassen_StraightVer(
  219. Matrix a, int ax, int ay, int aw, int ah,
  220. Matrix b, int bx, int by, int bw, int bh,
  221. Matrix r, int rx, int ry, int rw, int rh) {
  222.  
  223. if (shouldResortToPlainMultiply(bw, ah, aw)) {
  224. multiplyPlain(a, ax, ay, aw, ah, b, bx, by, bw, bh, r, rx, ry, rw, rh);
  225. return;
  226. }
  227.  
  228. validateSubmatrix(a.w, a.h, ax, ay, aw, ah);
  229. validateSubmatrix(b.w, b.h, bx, by, bw, bh);
  230. validateSubmatrix(r.w, r.h, rx, ry, rw, rh);
  231. validateMultiply(aw, ah, bw, bh, rw, rh);
  232.  
  233. int halfAw = (aw + 1) / 2, halfAh = (ah + 1) / 2,
  234. halfBw = (bw + 1) / 2, halfBh = (bh + 1) / 2,
  235. halfRw = (rw + 1) / 2, halfRh = (rh + 1) / 2;
  236.  
  237. Matrix
  238. a11 = a.submatrix(ax, ay, halfAw, halfAh),
  239. a21 = a.strassenPaddedBlock(ax, ay + halfAh, halfAw, halfAh, halfAw, ah - halfAh),
  240. a12 = a.strassenPaddedBlock(ax + halfAw, ay, halfAw, halfAh, aw - halfAw, halfAh),
  241. a22 = a.strassenPaddedBlock(ax + halfAw, ay + halfAh, halfAw, halfAh, aw - halfAw, ah - halfAh),
  242.  
  243. b11 = b.submatrix(bx, by, halfBw, halfBh),
  244. b21 = b.strassenPaddedBlock(bx, by + halfBh, halfBw, halfBh, halfBw, bh - halfBh),
  245. b12 = b.strassenPaddedBlock(bx + halfBw, by, halfBw, halfBh, bw - halfBw, halfBh),
  246. b22 = b.strassenPaddedBlock(bx + halfBw, by + halfBh, halfBw, halfBh, bw - halfBw, bh - halfBh),
  247.  
  248. m1 = a11.plus(a22).multiplyStrassen_StraightVer(b11.plus(b22)),
  249. m2 = a21.plus(a22).multiplyStrassen_StraightVer(b11),
  250. m3 = a11.multiplyStrassen_StraightVer(b12.minus(b22)),
  251. m4 = a22.multiplyStrassen_StraightVer(b21.minus(b11)),
  252. m5 = a11.plus(a12).multiplyStrassen_StraightVer(b22),
  253. m6 = a21.minus(a11).multiplyStrassen_StraightVer(b11.plus(b12)),
  254. m7 = a12.minus(a22).multiplyStrassen_StraightVer(b21.plus(b22)),
  255.  
  256. c11 = m1.plus(m4).minus(m5).plus(m7),
  257. c12 = m3.plus(m5),
  258. c21 = m2.plus(m4),
  259. c22 = m1.minus(m2).plus(m3).plus(m6);
  260.  
  261. moveSubmatrix(c11, 0, 0, r, rx, ry, halfRw, halfRh);
  262. moveSubmatrix(c21, 0, 0, r, rx, ry + halfRh, halfRw, rh - halfRh);
  263. moveSubmatrix(c12, 0, 0, r, rx + halfRw, ry, rw - halfRw, halfRh);
  264. moveSubmatrix(c22, 0, 0, r, rx + halfRw, ry + halfRh, rw - halfRw, rh - halfRh);
  265. }
  266.  
  267. public Matrix multiplyStrassen_StraightVer(Matrix b) {
  268. Matrix r = new Matrix(b.w, h);
  269. multiplyStrassen_StraightVer(this, 0, 0, w, h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h);
  270. return r;
  271. }
  272.  
  273. private static class MatrixView {
  274. public Matrix m;
  275. public int x, y, w, h;
  276.  
  277. public MatrixView(int w, int h) {
  278. this(new Matrix(w, h), 0, 0, w, h);
  279. }
  280.  
  281. public MatrixView(Matrix m) {
  282. this(m, 0, 0, m.w, m.h);
  283. }
  284.  
  285. public MatrixView(Matrix m, int x, int y, int w, int h) {
  286. this.m = m;
  287. this.x = x;
  288. this.y = y;
  289. this.w = w;
  290. this.h = h;
  291. }
  292.  
  293. public void add(MatrixView b, MatrixView r) {
  294. Matrix.add(m, x, y, w, h, b.m, b.x, b.y, b.w, b.h, r.m, r.x, r.y, r.w, r.h);
  295. }
  296.  
  297. public MatrixView plus(MatrixView b) {
  298. MatrixView r = new MatrixView(w, h);
  299. add(b, r);
  300. return r;
  301. }
  302.  
  303. public void subtract(MatrixView b, MatrixView r) {
  304. Matrix.subtract(m, x, y, w, h, b.m, b.x, b.y, b.w, b.h, r.m, r.x, r.y, r.w, r.h);
  305. }
  306.  
  307. public MatrixView minus(MatrixView b) {
  308. MatrixView r = new MatrixView(w, h);
  309. subtract(b, r);
  310. return r;
  311. }
  312.  
  313. public void multiplyStrassen_ViewVer(MatrixView b, MatrixView r) {
  314. Matrix.multiplyStrassen_ViewVer(m, x, y, w, h, b.m, b.x, b.y, b.w, b.h, r.m, r.x, r.y, r.w, r.h);
  315. }
  316.  
  317. public MatrixView multiplyStrassen_ViewVer(MatrixView b) {
  318. MatrixView r = new MatrixView(b.w, h);
  319. multiplyStrassen_ViewVer(b, r);
  320. return r;
  321. }
  322.  
  323. public void multiplyStrassen_Pow2ViewVer(MatrixView b, MatrixView r) {
  324. Matrix.multiplyStrassen_Pow2ViewVer(m, x, y, w, h, b.m, b.x, b.y, b.w, b.h, r.m, r.x, r.y, r.w, r.h);
  325. }
  326.  
  327. public MatrixView multiplyStrassen_Pow2ViewVer(MatrixView b) {
  328. MatrixView r = new MatrixView(b.w, h);
  329. multiplyStrassen_Pow2ViewVer(b, r);
  330. return r;
  331. }
  332. }
  333.  
  334. private MatrixView strassenPaddedBlockView(int x, int y, int w, int h) {
  335. return x + w <= this.w && y + h <= this.h ?
  336. new MatrixView(this, x, y, w, h) :
  337. new MatrixView(this.strassenPaddedBlock(
  338. x, y, w, h, Math.min(w, this.w - x), Math.min(h, this.h - y)));
  339. }
  340.  
  341. private MatrixView strassenTargetBlockView(int x, int y, int w, int h) {
  342. return x + w <= this.w && y + h <= this.h ?
  343. new MatrixView(this, x, y, w, h) :
  344. new MatrixView(w, h);
  345. }
  346.  
  347. public static void multiplyStrassen_ViewVer(
  348. Matrix a, int ax, int ay, int aw, int ah,
  349. Matrix b, int bx, int by, int bw, int bh,
  350. Matrix r, int rx, int ry, int rw, int rh) {
  351.  
  352. if (shouldResortToPlainMultiply(bw, ah, aw)) {
  353. multiplyPlain(a, ax, ay, aw, ah, b, bx, by, bw, bh, r, rx, ry, rw, rh);
  354. return;
  355. }
  356.  
  357. validateSubmatrix(a.w, a.h, ax, ay, aw, ah);
  358. validateSubmatrix(b.w, b.h, bx, by, bw, bh);
  359. validateSubmatrix(r.w, r.h, rx, ry, rw, rh);
  360. validateMultiply(aw, ah, bw, bh, rw, rh);
  361.  
  362. int halfAw = (aw + 1) / 2, halfAh = (ah + 1) / 2,
  363. halfBw = (bw + 1) / 2, halfBh = (bh + 1) / 2,
  364. halfRw = (rw + 1) / 2, halfRh = (rh + 1) / 2;
  365.  
  366. MatrixView
  367. a11 = a.strassenPaddedBlockView(ax, ay, halfAw, halfAh),
  368. a21 = a.strassenPaddedBlockView(ax, ay + halfAh, halfAw, halfAh),
  369. a12 = a.strassenPaddedBlockView(ax + halfAw, ay, halfAw, halfAh),
  370. a22 = a.strassenPaddedBlockView(ax + halfAw, ay + halfAh, halfAw, halfAh),
  371.  
  372. b11 = b.strassenPaddedBlockView(bx, by, halfBw, halfBh),
  373. b21 = b.strassenPaddedBlockView(bx, by + halfBh, halfBw, halfBh),
  374. b12 = b.strassenPaddedBlockView(bx + halfBw, by, halfBw, halfBh),
  375. b22 = b.strassenPaddedBlockView(bx + halfBw, by + halfBh, halfBw, halfBh),
  376.  
  377. m1 = a11.plus(a22).multiplyStrassen_ViewVer(b11.plus(b22)),
  378. m2 = a21.plus(a22).multiplyStrassen_ViewVer(b11),
  379. m3 = a11.multiplyStrassen_ViewVer(b12.minus(b22)),
  380. m4 = a22.multiplyStrassen_ViewVer(b21.minus(b11)),
  381. m5 = a11.plus(a12).multiplyStrassen_ViewVer(b22),
  382. m6 = a21.minus(a11).multiplyStrassen_ViewVer(b11.plus(b12)),
  383. m7 = a12.minus(a22).multiplyStrassen_ViewVer(b21.plus(b22)),
  384.  
  385. c11 = new MatrixView(r, rx, ry, halfRw, halfRh),
  386. c21 = r.strassenTargetBlockView(rx, ry + halfRh, halfRw, halfRh),
  387. c12 = r.strassenTargetBlockView(rx + halfRw, ry, halfRw, halfRh),
  388. c22 = r.strassenTargetBlockView(rx + halfRw, ry + halfRh, halfRw, halfRh);
  389.  
  390. m1.add(m4, c11); c11.subtract(m5, c11); c11.add(m7, c11);
  391. m3.add(m5, c12);
  392. m2.add(m4, c21);
  393. m1.subtract(m2, c22); c22.add(m3, c22); c22.add(m6, c22);
  394.  
  395. if (c21.m != r) moveSubmatrix(c21.m, 0, 0, r, rx, ry + halfRh, halfRw, rh - halfRh);
  396. if (c12.m != r) moveSubmatrix(c12.m, 0, 0, r, rx + halfRw, ry, rw - halfRw, halfRh);
  397. if (c22.m != r) moveSubmatrix(c22.m, 0, 0, r, rx + halfRw, ry + halfRh, rw - halfRw, rh - halfRh);
  398. }
  399.  
  400. public Matrix multiplyStrassen_ViewVer(Matrix b) {
  401. Matrix r = new Matrix(b.w, h);
  402. multiplyStrassen_ViewVer(this, 0, 0, w, h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h);
  403. return r;
  404. }
  405.  
  406. static void trustedMultiplyStrassen_Pow2UnrolledVer(
  407. Matrix a, int ax, int ay, int aw, int ah,
  408. Matrix b, int bx, int by, int bw, int bh,
  409. Matrix r, int rx, int ry, int rw, int rh) {
  410.  
  411. if (shouldResortToPlainMultiply(bw, ah, aw)) {
  412. multiplyPlain(a, ax, ay, aw, ah, b, bx, by, bw, bh, r, rx, ry, rw, rh);
  413. return;
  414. }
  415.  
  416. int halfAw = aw / 2, halfAh = ah / 2,
  417. halfBw = bw / 2, halfBh = bh / 2,
  418. halfRw = rw / 2, halfRh = rh / 2;
  419.  
  420. Matrix
  421. tsuma = new Matrix(halfAw, halfAh),
  422. tsumb = new Matrix(halfBw, halfBh),
  423. m1 = new Matrix(halfRw, halfRh),
  424. m2 = new Matrix(halfRw, halfRh),
  425. m3 = new Matrix(halfRw, halfRh),
  426. m4 = new Matrix(halfRw, halfRh),
  427. m5 = new Matrix(halfRw, halfRh),
  428. m6 = new Matrix(halfRw, halfRh),
  429. m7 = new Matrix(halfRw, halfRh);
  430.  
  431. // tsuma = a11 + a22, tsumb = b11 + b22, m1 = (a11 + a22) * (b11 + b22)
  432. trustedAdd(
  433. a, ax, ay,
  434. a, ax + halfAw, ay + halfAh,
  435. tsuma, 0, 0, halfAw, halfAh);
  436. trustedAdd(
  437. b, bx, by,
  438. b, bx + halfBw, by + halfBh,
  439. tsumb, 0, 0, halfBw, halfBh);
  440. trustedMultiplyStrassen_Pow2UnrolledVer(
  441. tsuma, 0, 0, halfAw, halfAh,
  442. tsumb, 0, 0, halfBw, halfBh,
  443. m1, 0, 0, halfRw, halfRh);
  444.  
  445. // tsuma = a21 + a22, m2 = (a21 + a22) * b11
  446. trustedAdd(
  447. a, ax, ay + halfAh,
  448. a, ax + halfAw, ay + halfAh,
  449. tsuma, 0, 0, halfAw, halfAh);
  450. trustedMultiplyStrassen_Pow2UnrolledVer(
  451. tsuma, 0, 0, halfAw, halfAh,
  452. b, bx, by, halfBw, halfBh,
  453. m2, 0, 0, halfRw, halfRh);
  454.  
  455. // tsumb = b12 - b22, m3 = a11 * (b12 - b22)
  456. trustedSubtract(
  457. b, bx + halfBw, by,
  458. b, bx + halfBw, by + halfBh,
  459. tsumb, 0, 0, halfBw, halfBh);
  460. trustedMultiplyStrassen_Pow2UnrolledVer(
  461. a, ax, ay, halfAw, halfAh,
  462. tsumb, 0, 0, halfBw, halfBh,
  463. m3, 0, 0, halfRw, halfRh);
  464.  
  465. // tsumb = b21 - b11, m4 = a22 * (b21 - b11)
  466. trustedSubtract(
  467. b, bx, by + halfBh,
  468. b, bx, by,
  469. tsumb, 0, 0, halfBw, halfBh);
  470. trustedMultiplyStrassen_Pow2UnrolledVer(
  471. a, ax + halfAw, ay + halfAh, halfAw, halfAh,
  472. tsumb, 0, 0, halfBw, halfBh,
  473. m4, 0, 0, halfRw, halfRh);
  474.  
  475. // tsuma = a11 + a12, m5 = (a11 + a12) * b22
  476. trustedAdd(
  477. a, ax, ay,
  478. a, ax + halfAw, ay,
  479. tsuma, 0, 0, halfAw, halfAh);
  480. trustedMultiplyStrassen_Pow2UnrolledVer(
  481. tsuma, 0, 0, halfAw, halfAh,
  482. b, bx + halfBw, by + halfBh, halfBw, halfBh,
  483. m5, 0, 0, halfRw, halfRh);
  484.  
  485. // tsuma = a21 - a11, tsumb = b11 + b12, m6 = (a21 - a11) * (b11 + b12)
  486. trustedSubtract(
  487. a, ax, ay + halfAh,
  488. a, ax, ay,
  489. tsuma, 0, 0, halfAw, halfAh);
  490. trustedAdd(
  491. b, bx, by,
  492. b, bx + halfBw, by,
  493. tsumb, 0, 0, halfBw, halfBh);
  494. trustedMultiplyStrassen_Pow2UnrolledVer(
  495. tsuma, 0, 0, halfAw, halfAh,
  496. tsumb, 0, 0, halfBw, halfBh,
  497. m6, 0, 0, halfRw, halfRh);
  498.  
  499. // tsuma = a12 - a22, tsumb = b21 + b22, m7 = (a12 - a22) * (b21 + b22)
  500. trustedSubtract(
  501. a, ax + halfAw, ay,
  502. a, ax + halfAw, ay + halfAh,
  503. tsuma, 0, 0, halfAw, halfAh);
  504. trustedAdd(
  505. b, bx, by + halfBh,
  506. b, bx + halfBw, by + halfBh,
  507. tsumb, 0, 0, halfBw, halfBh);
  508. trustedMultiplyStrassen_Pow2UnrolledVer(
  509. tsuma, 0, 0, halfAw, halfAh,
  510. tsumb, 0, 0, halfBw, halfBh,
  511. m7, 0, 0, halfRw, halfRh);
  512.  
  513. // c11 = m1 + m4 - m5 + m7
  514. trustedAdd(
  515. m1, 0, 0,
  516. m4, 0, 0,
  517. r, rx, ry, halfRw, halfRh);
  518. trustedSubtract(
  519. r, rx, ry,
  520. m5, 0, 0,
  521. r, rx, ry, halfRw, halfRh);
  522. trustedAdd(
  523. r, rx, ry,
  524. m7, 0, 0,
  525. r, rx, ry, halfRw, halfRh);
  526.  
  527. // c12 = m3 + m5
  528. trustedAdd(
  529. m3, 0, 0,
  530. m5, 0, 0,
  531. r, rx + halfRw, ry, halfRw, halfRh);
  532.  
  533. // c21 = m2 + m4
  534. trustedAdd(
  535. m2, 0, 0,
  536. m4, 0, 0,
  537. r, rx, ry + halfRh, halfRw, halfRh);
  538.  
  539. // c22 = m1 - m2 + m3 + m6
  540. trustedSubtract(
  541. m1, 0, 0,
  542. m2, 0, 0,
  543. r, rx + halfRw, ry + halfRh, halfRw, halfRh);
  544. trustedAdd(
  545. r, rx + halfRw, ry + halfRh,
  546. m3, 0, 0,
  547. r, rx + halfRw, ry + halfRh, halfRw, halfRh);
  548. trustedAdd(
  549. r, rx + halfRw, ry + halfRh,
  550. m6, 0, 0,
  551. r, rx + halfRw, ry + halfRh, halfRw, halfRh);
  552. }
  553.  
  554. @FunctionalInterface
  555. interface GenericMultiply {
  556. void multiply(
  557. Matrix a, int ax, int ay, int aw, int ah,
  558. Matrix b, int bx, int by, int bw, int bh,
  559. Matrix r, int rx, int ry, int rw, int rh);
  560. }
  561.  
  562. static boolean isPow2(int x) {
  563. if (x <= 0) Throw.badArg("isPow2 requires positive value, got %d.", x);
  564. return (x & (x - 1)) == 0;
  565. }
  566.  
  567. static int ceilPow2(int x) {
  568. if (x <= 0) Throw.badArg("ceilPow2 requires positive value, got %d.", x);
  569. x -= 1;
  570. for (int shlp = 0; shlp <= 4; shlp++)
  571. x |= x >>> (1 << shlp);
  572. return x + 1;
  573. }
  574.  
  575. public static void adaptTrustedMultiplyPow2(
  576. GenericMultiply multiplyPow2,
  577. Matrix a, int ax, int ay, int aw, int ah,
  578. Matrix b, int bx, int by, int bw, int bh,
  579. Matrix r, int rx, int ry, int rw, int rh) {
  580.  
  581. validateSubmatrix(a.w, a.h, ax, ay, aw, ah);
  582. validateSubmatrix(b.w, b.h, bx, by, bw, bh);
  583. validateSubmatrix(r.w, r.h, rx, ry, rw, rh);
  584. validateMultiply(aw, ah, bw, bh, rw, rh);
  585.  
  586. boolean altered = false;
  587.  
  588. if (!isPow2(aw) || !isPow2(ah)) {
  589. Matrix na = new Matrix(ceilPow2(aw), ceilPow2(ah));
  590. moveSubmatrix(a, ax, ay, na, 0, 0, aw, ah);
  591. a = na; ax = 0; ay = 0; aw = na.w; ah = na.h; altered = true;
  592. }
  593.  
  594. if (!isPow2(bw) || !isPow2(bh)) {
  595. Matrix nb = new Matrix(ceilPow2(bw), ceilPow2(bh));
  596. moveSubmatrix(b, bx, by, nb, 0, 0, bw, bh);
  597. b = nb; bx = 0; by = 0; bw = nb.w; bh = nb.h; altered = true;
  598. }
  599.  
  600. int origRx = rx, origRy = ry, origRw = rw, origRh = rh;
  601. Matrix origR = r;
  602. if (altered) {
  603. r = new Matrix(bw, ah);
  604. rx = 0; ry = 0; rw = r.w; rh = r.h;
  605. }
  606.  
  607. multiplyPow2.multiply(a, ax, ay, aw, ah, b, bx, by, bw, bh, r, rx, ry, rw, rh);
  608. if (altered) {
  609. moveSubmatrix(r, 0, 0, origR, origRx, origRy, origRw, origRh);
  610. }
  611. }
  612.  
  613. public static void multiplyStrassen_Pow2UnrolledVer(
  614. Matrix a, int ax, int ay, int aw, int ah,
  615. Matrix b, int bx, int by, int bw, int bh,
  616. Matrix r, int rx, int ry, int rw, int rh) {
  617.  
  618. adaptTrustedMultiplyPow2(Matrix::trustedMultiplyStrassen_Pow2UnrolledVer,
  619. a, ax, ay, aw, ah, b, bx, by, bw, bh, r, rx, ry, rw, rh);
  620. }
  621.  
  622. public Matrix multiplyStrassen_Pow2UnrolledVer(Matrix b) {
  623. Matrix r = new Matrix(b.w, h);
  624. multiplyStrassen_Pow2UnrolledVer(this, 0, 0, w, h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h);
  625. return r;
  626. }
  627.  
  628. static void trustedMultiplyStrassen_Pow2ViewVer(
  629. Matrix a, int ax, int ay, int aw, int ah,
  630. Matrix b, int bx, int by, int bw, int bh,
  631. Matrix r, int rx, int ry, int rw, int rh) {
  632.  
  633. if (shouldResortToPlainMultiply(bw, ah, aw)) {
  634. multiplyPlain(a, ax, ay, aw, ah, b, bx, by, bw, bh, r, rx, ry, rw, rh);
  635. return;
  636. }
  637.  
  638. int halfAw = aw / 2, halfAh = ah / 2,
  639. halfBw = bw / 2, halfBh = bh / 2,
  640. halfRw = rw / 2, halfRh = rh / 2;
  641.  
  642. MatrixView
  643. a11 = new MatrixView(a, ax, ay, halfAw, halfAh),
  644. a21 = new MatrixView(a, ax, ay + halfAh, halfAw, halfAh),
  645. a12 = new MatrixView(a, ax + halfAw, ay, halfAw, halfAh),
  646. a22 = new MatrixView(a, ax + halfAw, ay + halfAh, halfAw, halfAh),
  647.  
  648. b11 = new MatrixView(b, bx, by, halfBw, halfBh),
  649. b21 = new MatrixView(b, bx, by + halfBh, halfBw, halfBh),
  650. b12 = new MatrixView(b, bx + halfBw, by, halfBw, halfBh),
  651. b22 = new MatrixView(b, bx + halfBw, by + halfBh, halfBw, halfBh),
  652.  
  653. m1 = a11.plus(a22).multiplyStrassen_Pow2ViewVer(b11.plus(b22)),
  654. m2 = a21.plus(a22).multiplyStrassen_Pow2ViewVer(b11),
  655. m3 = a11.multiplyStrassen_Pow2ViewVer(b12.minus(b22)),
  656. m4 = a22.multiplyStrassen_Pow2ViewVer(b21.minus(b11)),
  657. m5 = a11.plus(a12).multiplyStrassen_Pow2ViewVer(b22),
  658. m6 = a21.minus(a11).multiplyStrassen_Pow2ViewVer(b11.plus(b12)),
  659. m7 = a12.minus(a22).multiplyStrassen_Pow2ViewVer(b21.plus(b22)),
  660.  
  661. c11 = new MatrixView(r, rx, ry, halfRw, halfRh),
  662. c21 = new MatrixView(r, rx, ry + halfRh, halfRw, halfRh),
  663. c12 = new MatrixView(r, rx + halfRw, ry, halfRw, halfRh),
  664. c22 = new MatrixView(r, rx + halfRw, ry + halfRh, halfRw, halfRh);
  665.  
  666. m1.add(m4, c11); c11.subtract(m5, c11); c11.add(m7, c11);
  667. m3.add(m5, c12);
  668. m2.add(m4, c21);
  669. m1.subtract(m2, c22); c22.add(m3, c22); c22.add(m6, c22);
  670. }
  671.  
  672. public static void multiplyStrassen_Pow2ViewVer(
  673. Matrix a, int ax, int ay, int aw, int ah,
  674. Matrix b, int bx, int by, int bw, int bh,
  675. Matrix r, int rx, int ry, int rw, int rh) {
  676.  
  677. adaptTrustedMultiplyPow2(Matrix::trustedMultiplyStrassen_Pow2ViewVer,
  678. a, ax, ay, aw, ah, b, bx, by, bw, bh, r, rx, ry, rw, rh);
  679. }
  680.  
  681. public Matrix multiplyStrassen_Pow2ViewVer(Matrix b) {
  682. Matrix r = new Matrix(b.w, h);
  683. multiplyStrassen_Pow2ViewVer(this, 0, 0, w, h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h);
  684. return r;
  685. }
  686.  
  687. @Override
  688. public boolean equals(Object other) {
  689. if (this == other) return true;
  690. if (this.getClass() != other.getClass()) return false;
  691. Matrix om = (Matrix) other;
  692. return w == om.w && h == om.h && Arrays.equals(raw, om.raw);
  693. }
  694. }
  695.  
  696. class Ideone {
  697. static final int FILLER = 9;
  698.  
  699. static void testAddSubtract() {
  700. Matrix a = Matrix.explicit(4, 2,
  701. 1, 2, 3, 4,
  702. 5, 6, 7, 8);
  703.  
  704. Matrix b = Matrix.explicit(4, 2,
  705. 9, 11, 13, 15,
  706. 17, 19, 21, 23);
  707.  
  708. Matrix expectedSum = Matrix.explicit(4, 2,
  709. 10, 13, 16, 19,
  710. 22, 25, 28, 31);
  711.  
  712. Matrix expectedDif = Matrix.explicit(4, 2,
  713. -8, -9, -10, -11,
  714. -12, -13, -14, -15);
  715.  
  716. final int F = FILLER;
  717. expectEqual("Matrix.expand", a.expand(FILLER, 1, 4, 2, 3),
  718. Matrix.explicit(9, 7,
  719. F, F, F, F, F, F, F, F, F,
  720. F, F, F, F, F, F, F, F, F,
  721. F, 1, 2, 3, 4, F, F, F, F,
  722. F, 5, 6, 7, 8, F, F, F, F,
  723. F, F, F, F, F, F, F, F, F,
  724. F, F, F, F, F, F, F, F, F,
  725. F, F, F, F, F, F, F, F, F));
  726.  
  727. Matrix result = new Matrix(expectedSum.w + 5, expectedSum.h + 5, (x, y) -> FILLER);
  728. Matrix.add(
  729. a.expand(FILLER, 1, 2, 3, 4), 1, 3, a.w, a.h,
  730. b.expand(FILLER, 5, 6, 7, 8), 5, 7, b.w, b.h,
  731. result, 1, 2, expectedSum.w, expectedSum.h);
  732. expectEqual("Matrix.add", result, expectedSum.expand(FILLER, 1, 4, 2, 3));
  733.  
  734. result.fill((x, y) -> FILLER);
  735. Matrix.subtract(
  736. a.expand(FILLER, 1, 2, 3, 4), 1, 3, a.w, a.h,
  737. b.expand(FILLER, 5, 6, 7, 8), 5, 7, b.w, b.h,
  738. result, 1, 2, expectedDif.w, expectedDif.h);
  739. expectEqual("Matrix.subtract", result, expectedDif.expand(FILLER, 1, 4, 2, 3));
  740.  
  741. expectEqual("ceilPow2", Matrix.ceilPow2(1), 1);
  742. expectEqual("ceilPow2", Matrix.ceilPow2(6), 8);
  743. }
  744.  
  745. static void testPlainMultiply() {
  746. // www.mathwarehouse.com/algebra/matrix/images/product-matrix2.png
  747. Matrix a = Matrix.explicit(4, 2,
  748. 1, 4, 6, 10,
  749. 2, 7, 5, 3);
  750.  
  751. Matrix b = Matrix.explicit(3, 4,
  752. 1, 4, 6,
  753. 2, 7, 5,
  754. 9, 0, 11,
  755. 3, 1, 0);
  756.  
  757. Matrix expected = Matrix.explicit(3, 2,
  758. 93, 42, 92,
  759. 70, 60, 102);
  760.  
  761. Matrix result = new Matrix(expected.w + 5, expected.h + 5, (x, y) -> FILLER);
  762. Matrix.multiplyPlain(
  763. a.expand(FILLER, 1, 2, 3, 4), 1, 3, a.w, a.h,
  764. b.expand(FILLER, 5, 6, 7, 8), 5, 7, b.w, b.h,
  765. result, 1, 2, expected.w, expected.h);
  766. expectEqual("Matrix.multiplyPlain", result, expected.expand(FILLER, 1, 4, 2, 3));
  767. }
  768.  
  769. static void testStrassenMultiply() {
  770. Random rng = new Random(1);
  771. Matrix a = new Matrix(97, 71, (x, y) -> -10 + rng.nextInt(21));
  772. Matrix b = new Matrix(91, 97, (x, y) -> -10 + rng.nextInt(21));
  773. Matrix abProdRef = a.multiplyPlain(b);
  774. expectEqual("Matrix.multiplyStrassen_StraightVer", a.multiplyStrassen_StraightVer(b), abProdRef);
  775. expectEqual("Matrix.multiplyStrassen_ViewVer", a.multiplyStrassen_ViewVer(b), abProdRef);
  776. expectEqual("Matrix.multiplyStrassen_Pow2UnrolledVer", a.multiplyStrassen_Pow2UnrolledVer(b), abProdRef);
  777. expectEqual("Matrix.multiplyStrassen_Pow2ViewVer", a.multiplyStrassen_Pow2ViewVer(b), abProdRef);
  778. }
  779.  
  780. static void expectEqual(String what, Object a, Object b) {
  781. if (!a.equals(b))
  782. throw new RuntimeException(String.format("bad %s: got\n%s\nexpected\n%s.", what, a, b));
  783. }
  784.  
  785. static void benchmark() {
  786. final int TRIALS = 2;
  787. for (int trial = 0; trial < TRIALS; trial++) {
  788. System.out.println(String.format("%s--- TRIAL %d/%d%s ---",
  789. trial > 0? "\n" : "", 1 + trial, TRIALS, trial == 0 ? " — HEATING UP" : ""));
  790. final int AW = 850, AH = 650, BW = 750;
  791. // final int AW = 512, AH = 1024, BW = 768;
  792. Matrix a = new Matrix(AW, AH);
  793. Matrix b = new Matrix(BW, AW);
  794. Matrix r = new Matrix(BW, AH);
  795.  
  796. double refTime = benchmark("Straightforward O(N^3) multiply",
  797. () -> { Matrix.multiplyPlain(a, 0, 0, a.w, a.h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h); return r; });
  798.  
  799. benchmark("Strassen multiply with many intermediate matrix allocations",
  800. () -> { Matrix.multiplyStrassen_StraightVer(a, 0, 0, a.w, a.h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h); return r; }, refTime);
  801.  
  802. benchmark("Strassen multiply with submatrix views",
  803. () -> { Matrix.multiplyStrassen_ViewVer(a, 0, 0, a.w, a.h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h); return r; }, refTime);
  804.  
  805. benchmark("Strassen multiply with forceful 2^n size and unrolled calculations",
  806. () -> { Matrix.multiplyStrassen_Pow2UnrolledVer(a, 0, 0, a.w, a.h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h); return r; }, refTime);
  807.  
  808. benchmark("Strassen multiply with forceful 2^n size and submatrix views",
  809. () -> { Matrix.multiplyStrassen_Pow2ViewVer(a, 0, 0, a.w, a.h, b, 0, 0, b.w, b.h, r, 0, 0, r.w, r.h); return r; }, refTime);
  810. }
  811. }
  812.  
  813. @FunctionalInterface
  814. interface BenchmarkPayload {
  815. abstract Object run();
  816. }
  817.  
  818. static double benchmark(String title, BenchmarkPayload payload) {
  819. return benchmark(title, payload, -1);
  820. }
  821.  
  822. static double benchmark(String title, BenchmarkPayload payload, double refTime) {
  823. long start = System.nanoTime();
  824. Object keepAlive = payload.run();
  825. double time = (System.nanoTime() - start) * 1e-9;
  826. System.out.println(String.format(
  827. "%s: %.2f s%s",
  828. title, time, refTime >= 0 ? String.format(" (%s)", judgement(time, refTime)) : "", keepAlive));
  829. return time;
  830. }
  831.  
  832. static String judgement(double time, double refTime) {
  833. final double absThreshold = 0.1, relThreshold = .05;
  834.  
  835. return time <= absThreshold || refTime <= absThreshold ?
  836. String.format("can't judge, min. required time is %.1f s", absThreshold) :
  837. Math.max(time, refTime) > (1 + relThreshold) * Math.min(time, refTime) ?
  838. String.format("%.0f%% %s", Math.abs(time / refTime - 1) * 100, time < refTime ? "faster" : "slower") :
  839. String.format("no observable difference, min. %.0f%% required", relThreshold * 100);
  840. }
  841.  
  842. public static void main (String[] args) {
  843. try {
  844. testAddSubtract();
  845. testPlainMultiply();
  846. testStrassenMultiply();
  847. benchmark();
  848. } catch (Exception e) {
  849. e.printStackTrace(new PrintWriter(sw));
  850. System.out.println(sw.toString());
  851. }
  852. }
  853. }
Success #stdin #stdout 14.49s 2184192KB
stdin
Standard input is empty
stdout
--- TRIAL 1/2 — HEATING UP ---
Straightforward O(N^3) multiply: 2.93 s
Strassen multiply with many intermediate matrix allocations: 0.95 s (68% faster)
Strassen multiply with submatrix views: 2.23 s (24% faster)
Strassen multiply with forceful 2^n size and unrolled calculations: 1.01 s (65% faster)
Strassen multiply with forceful 2^n size and submatrix views: 1.26 s (57% faster)

--- TRIAL 2/2 ---
Straightforward O(N^3) multiply: 2.67 s
Strassen multiply with many intermediate matrix allocations: 0.75 s (72% faster)
Strassen multiply with submatrix views: 0.63 s (76% faster)
Strassen multiply with forceful 2^n size and unrolled calculations: 0.85 s (68% faster)
Strassen multiply with forceful 2^n size and submatrix views: 0.98 s (63% faster)