fork download
  1. #include <iostream>
  2. #include <vector>
  3.  
  4. int main()
  5. {
  6. // shape, stride & data of the matrix
  7.  
  8. size_t shape [] = { 2, 3, 4, 5};
  9. size_t strides[] = {60,20, 5, 1};
  10.  
  11. std::vector<double> data(2*3*4*5);
  12.  
  13. for ( size_t i = 0 ; i < data.size() ; ++i ) data[i] = 1.;
  14.  
  15. // shape, stride & data (zero-initialized) of the reduced matrix
  16.  
  17. size_t rshape [] = { 2, 4, 5};
  18. size_t rstrides[] = {20, 5, 1};
  19.  
  20. std::vector<double> rdata(2*4*5, 0.0);
  21.  
  22. // compute reduction
  23.  
  24. // for ( size_t a = 0 ; a < shape[0] ; ++a )
  25. // for ( size_t c = 0 ; c < shape[2] ; ++c )
  26. // for ( size_t d = 0 ; d < shape[3] ; ++d )
  27. // for ( size_t b = 0 ; b < shape[1] ; ++b )
  28. // rdata[ a*rstrides[0] + c*rstrides[1] + d*rstrides[2] ] += \
  29. // data [ a*strides [0] + b*strides [1] + c*strides [2] + d*strides [3] ];
  30.  
  31. size_t cmp_axis = 1, axis_count = sizeof shape/ sizeof *shape;
  32. std::vector<size_t> adjusted_strides;
  33. //adjusted strides is basically same as strides
  34. //only difference being that the first element is the
  35. //total number of elements in the n dim array.
  36.  
  37. //The only reason to introduce this array was
  38. //so that I don't have to write any if-elses
  39. adjusted_strides.push_back(shape[0]*strides[0]);
  40. adjusted_strides.insert(adjusted_strides.end(), strides, strides + axis_count);
  41. for(size_t i = 0; i < data.size(); ++i) {
  42. size_t ni = i/adjusted_strides[cmp_axis]*adjusted_strides[cmp_axis+1] + i%adjusted_strides[cmp_axis+1];
  43. rdata[ni] += data[i];
  44. }
  45. // print resulting reduced matrix
  46.  
  47. for ( size_t a = 0 ; a < rshape[0] ; ++a )
  48. for ( size_t b = 0 ; b < rshape[1] ; ++b )
  49. for ( size_t c = 0 ; c < rshape[2] ; ++c )
  50. std::cout << "(" << a << "," << b << "," << c << ") " << \
  51. rdata[ a*rstrides[0] + b*rstrides[1] + c*rstrides[2] ] << std::endl;
  52.  
  53. return 0;
  54. }
Success #stdin #stdout 0s 4504KB
stdin
Standard input is empty
stdout
(0,0,0) 3
(0,0,1) 3
(0,0,2) 3
(0,0,3) 3
(0,0,4) 3
(0,1,0) 3
(0,1,1) 3
(0,1,2) 3
(0,1,3) 3
(0,1,4) 3
(0,2,0) 3
(0,2,1) 3
(0,2,2) 3
(0,2,3) 3
(0,2,4) 3
(0,3,0) 3
(0,3,1) 3
(0,3,2) 3
(0,3,3) 3
(0,3,4) 3
(1,0,0) 3
(1,0,1) 3
(1,0,2) 3
(1,0,3) 3
(1,0,4) 3
(1,1,0) 3
(1,1,1) 3
(1,1,2) 3
(1,1,3) 3
(1,1,4) 3
(1,2,0) 3
(1,2,1) 3
(1,2,2) 3
(1,2,3) 3
(1,2,4) 3
(1,3,0) 3
(1,3,1) 3
(1,3,2) 3
(1,3,3) 3
(1,3,4) 3