fork download
  1. #include <algorithm>
  2. #include <assert.h>
  3. #include <cmath>
  4. #include <cstdlib>
  5. #include <iostream>
  6. #include <iomanip>
  7. #include <string>
  8. #include <vector>
  9.  
  10. // -------------------------------------------------------------------------------------------------
  11.  
  12. class Matrix
  13. {
  14. private:
  15. static const size_t MAX_DIM = 6;
  16. std::vector<double> m_data; // data container
  17. size_t m_ndim=0; // actual number of dimensions
  18. size_t m_size=0; // total number of entries == data.size() == prod(shape)
  19. size_t m_shape[MAX_DIM]; // number of entries in each dimensions
  20. size_t m_strides[MAX_DIM]; // stride length for each index
  21.  
  22. public:
  23.  
  24. // constructor
  25. Matrix(){}; // empty
  26. Matrix(const std::vector<size_t> &shape); // allocate, don't initialize
  27.  
  28. // index operators: access plain storage
  29. double& operator[](size_t i);
  30. const double& operator[](size_t i) const;
  31.  
  32. // index operators: access using Matrix indices
  33. double& operator()(size_t a);
  34. const double& operator()(size_t a) const;
  35. double& operator()(size_t a, size_t b);
  36. const double& operator()(size_t a, size_t b) const;
  37. double& operator()(size_t a, size_t b, size_t c);
  38. const double& operator()(size_t a, size_t b, size_t c) const;
  39. double& operator()(size_t a, size_t b, size_t c, size_t d);
  40. const double& operator()(size_t a, size_t b, size_t c, size_t d) const;
  41. double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e);
  42. const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const;
  43. double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f);
  44. const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const;
  45.  
  46. // pointer to data
  47. double* data();
  48. const double* data() const;
  49.  
  50. // iterator to first and last entry
  51. auto begin();
  52. auto begin() const;
  53. auto end();
  54. auto end() const;
  55.  
  56. // get dimensions
  57. size_t size() const;
  58. size_t shape(size_t i) const;
  59. std::vector<size_t> shape() const;
  60.  
  61. // initialization
  62. void setZero();
  63.  
  64. // basic algebra
  65. Matrix sum(size_t axis) const;
  66. };
  67.  
  68. // -------------------------------------------------------------------------------------------------
  69.  
  70. inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx);
  71.  
  72. // -------------------------------------------------------------------------------------------------
  73.  
  74. inline Matrix::Matrix(const std::vector<size_t> &shape)
  75. {
  76. assert( shape.size() > 0 );
  77. assert( shape.size() <= MAX_DIM );
  78.  
  79. for ( size_t i = 0 ; i < MAX_DIM ; ++i )
  80. {
  81. m_shape [i] = 1;
  82. m_strides[i] = 1;
  83. }
  84.  
  85. m_ndim = shape.size();
  86. m_size = 1;
  87.  
  88. for ( size_t i = 0 ; i < m_ndim ; ++i )
  89. {
  90. m_shape[i] = shape[i];
  91. m_size *= shape[i];
  92. }
  93.  
  94. for ( size_t i = 0 ; i < m_ndim ; ++i )
  95. for ( size_t j = i+1 ; j < m_ndim ; ++j )
  96. m_strides[i] *= m_shape[j];
  97.  
  98. m_data.resize(m_size);
  99. }
  100.  
  101. // -------------------------------------------------------------------------------------------------
  102.  
  103. inline double& Matrix::operator[](size_t i)
  104. {
  105. return m_data[i];
  106. }
  107.  
  108. // -------------------------------------------------------------------------------------------------
  109.  
  110. inline const double& Matrix::operator[](size_t i) const
  111. {
  112. return m_data[i];
  113. }
  114.  
  115. // -------------------------------------------------------------------------------------------------
  116.  
  117. inline double& Matrix::operator()(size_t a)
  118. {
  119. return m_data[a*m_strides[0]];
  120. }
  121.  
  122. // -------------------------------------------------------------------------------------------------
  123.  
  124. inline const double& Matrix::operator()(size_t a) const
  125. {
  126. return m_data[a*m_strides[0]];
  127. }
  128.  
  129. // -------------------------------------------------------------------------------------------------
  130.  
  131. inline double& Matrix::operator()(size_t a, size_t b)
  132. {
  133. return m_data[a*m_strides[0]+b*m_strides[1]];
  134. }
  135.  
  136. // -------------------------------------------------------------------------------------------------
  137.  
  138. inline const double& Matrix::operator()(size_t a, size_t b) const
  139. {
  140. return m_data[a*m_strides[0]+b*m_strides[1]];
  141. }
  142.  
  143. // -------------------------------------------------------------------------------------------------
  144.  
  145. inline double& Matrix::operator()(size_t a, size_t b, size_t c)
  146. {
  147. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]];
  148. }
  149.  
  150. // -------------------------------------------------------------------------------------------------
  151.  
  152. inline const double& Matrix::operator()(size_t a, size_t b, size_t c) const
  153. {
  154. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]];
  155. }
  156.  
  157. // -------------------------------------------------------------------------------------------------
  158.  
  159. inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d)
  160. {
  161. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]];
  162. }
  163.  
  164. // -------------------------------------------------------------------------------------------------
  165.  
  166. inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d) const
  167. {
  168. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]];
  169. }
  170.  
  171. // -------------------------------------------------------------------------------------------------
  172.  
  173. inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e)
  174. {
  175. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]];
  176. }
  177.  
  178. // -------------------------------------------------------------------------------------------------
  179.  
  180. inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const
  181. {
  182. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]];
  183. }
  184.  
  185. // -------------------------------------------------------------------------------------------------
  186.  
  187. inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f)
  188. {
  189. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]+f*m_strides[5]];
  190. }
  191.  
  192. // -------------------------------------------------------------------------------------------------
  193.  
  194. inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const
  195. {
  196. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]+f*m_strides[5]];
  197. }
  198.  
  199. // -------------------------------------------------------------------------------------------------
  200.  
  201. inline double* Matrix::data()
  202. {
  203. return m_data.data();
  204. }
  205.  
  206. // -------------------------------------------------------------------------------------------------
  207.  
  208. inline const double* Matrix::data() const
  209. {
  210. return m_data.data();
  211. }
  212.  
  213. // -------------------------------------------------------------------------------------------------
  214.  
  215. inline auto Matrix::begin()
  216. {
  217. return m_data.begin();
  218. }
  219.  
  220. // -------------------------------------------------------------------------------------------------
  221.  
  222. inline auto Matrix::begin() const
  223. {
  224. return m_data.begin();
  225. }
  226.  
  227. // -------------------------------------------------------------------------------------------------
  228.  
  229. inline auto Matrix::end()
  230. {
  231. return m_data.end();
  232. }
  233.  
  234. // -------------------------------------------------------------------------------------------------
  235.  
  236. inline auto Matrix::end() const
  237. {
  238. return m_data.end();
  239. }
  240.  
  241. // -------------------------------------------------------------------------------------------------
  242.  
  243. inline size_t Matrix::shape(size_t i) const
  244. {
  245. return m_shape[i];
  246. }
  247.  
  248. // -------------------------------------------------------------------------------------------------
  249.  
  250. inline std::vector<size_t> Matrix::shape() const
  251. {
  252. std::vector<size_t> out(m_ndim);
  253.  
  254. for ( size_t i = 0 ; i < m_ndim ; ++i ) out[i] = m_shape[i];
  255.  
  256. return out;
  257. }
  258.  
  259. // -------------------------------------------------------------------------------------------------
  260.  
  261. inline size_t Matrix::size() const
  262. {
  263. return m_size;
  264. }
  265.  
  266. // -------------------------------------------------------------------------------------------------
  267.  
  268. inline void Matrix::setZero()
  269. {
  270. for ( size_t i = 0 ; i < m_size ; ++i ) m_data[i] = static_cast<double>(0);
  271. }
  272.  
  273. // -------------------------------------------------------------------------------------------------
  274.  
  275. inline Matrix Matrix::sum(size_t axis) const
  276. {
  277. Matrix out(del(this->shape(),axis));
  278.  
  279. out.setZero();
  280.  
  281. // How many elements to advance after each reduction
  282. size_t step_axis = m_strides[m_ndim-1];
  283.  
  284. if ( axis == m_ndim-1)
  285. step_axis = m_strides[m_ndim-2];
  286.  
  287. // Position of the first element of the current reduction
  288. size_t offset_base = 0;
  289. size_t offset = 0;
  290.  
  291. size_t s = 0;
  292. for ( auto &v : out )
  293. {
  294. // Current reduced element
  295. size_t offset_i = offset;
  296.  
  297. for ( size_t i = 0 ; i < m_shape[axis] ; ++i )
  298. {
  299. // - reduce
  300. v += *(m_data.data() + offset_i);
  301. // - advance to next element
  302. offset_i += m_strides[axis];
  303. }
  304.  
  305. s = (s + 1) % m_strides[axis];
  306. if (s == 0)
  307. {
  308. offset_base += m_strides[axis - 1];
  309. offset = offset_base;
  310. }
  311. else
  312. {
  313. offset += step_axis;
  314. }
  315. }
  316.  
  317. return out;
  318. }
  319.  
  320. // -------------------------------------------------------------------------------------------------
  321.  
  322. inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx)
  323. {
  324. assert( idx < A.size() );
  325.  
  326. std::vector<size_t> B = A;
  327.  
  328. B.erase(B.begin()+idx, B.begin()+idx+1);
  329.  
  330. return B;
  331. }
  332.  
  333. // -------------------------------------------------------------------------------------------------
  334.  
  335. int main()
  336. {
  337. Matrix A({6, 11, 16, 3});
  338.  
  339. for ( size_t i = 0 ; i < A.size() ; ++i )
  340. A[i] = static_cast<double>(i) + 0.1;
  341.  
  342. // Select reduction axis
  343. size_t axis = 0;
  344.  
  345. Matrix B = A.sum(axis);
  346.  
  347. std::vector<size_t> b_shape;
  348. for (size_t i = 0; i < A.shape().size(); i++)
  349. {
  350. if (i == axis) continue;
  351. b_shape.push_back(A.shape(i));
  352. }
  353. Matrix b(b_shape);
  354.  
  355. b.setZero();
  356.  
  357. for ( size_t i = 0 ; i < A.shape(0) ; i++ )
  358. for ( size_t j = 0 ; j < A.shape(1) ; j++ )
  359. for ( size_t k = 0 ; k < A.shape(2) ; k++ )
  360. for ( size_t l = 0 ; l < A.shape(3) ; l++ )
  361. switch (axis)
  362. {
  363. case 0:
  364. b(j, k, l) += A(i, j, k, l);
  365. break;
  366. case 1:
  367. b(i, k, l) += A(i, j, k, l);
  368. break;
  369. case 2:
  370. b(i, j, l) += A(i, j, k, l);
  371. break;
  372. case 3:
  373. b(i, j, k) += A(i, j, k, l);
  374. break;
  375. }
  376.  
  377. bool errors = false;
  378. for ( size_t i = 0 ; i < b.size() ; ++i )
  379. {
  380. if ( std::abs( b[i] - B[i] ) > 1.e-12 )
  381. {
  382. std::cout << "Error in [" << i << "] - Correct: " << b[i] << ", result:" << B[i] << std::endl;
  383. errors = true;
  384. }
  385. }
  386. if (!errors)
  387. {
  388. std::cout << "Computation was correct." << std::endl;
  389. }
  390.  
  391. return 0;
  392. }
  393.  
Success #stdin #stdout 0s 4168KB
stdin
Standard input is empty
stdout
Computation was correct.