fork download
  1. #include <algorithm>
  2. #include <assert.h>
  3. #include <cmath>
  4. #include <cstdlib>
  5. #include <iostream>
  6. #include <iomanip>
  7. #include <string>
  8. #include <vector>
  9.  
  10. // -------------------------------------------------------------------------------------------------
  11.  
  12. class Matrix
  13. {
  14. private:
  15. static const size_t MAX_DIM = 6;
  16. std::vector<double> m_data; // data container
  17. size_t m_ndim=0; // actual number of dimensions
  18. size_t m_size=0; // total number of entries == data.size() == prod(shape)
  19. size_t m_shape[MAX_DIM]; // number of entries in each dimensions
  20. size_t m_strides[MAX_DIM]; // stride length for each index
  21.  
  22. public:
  23.  
  24. // constructor
  25. Matrix(){}; // empty
  26. Matrix(const std::vector<size_t> &shape); // allocate, don't initialize
  27.  
  28. // index operators: access plain storage
  29. double& operator[](size_t i);
  30. const double& operator[](size_t i) const;
  31.  
  32. // index operators: access using Matrix indices
  33. double& operator()(size_t a);
  34. const double& operator()(size_t a) const;
  35. double& operator()(size_t a, size_t b);
  36. const double& operator()(size_t a, size_t b) const;
  37. double& operator()(size_t a, size_t b, size_t c);
  38. const double& operator()(size_t a, size_t b, size_t c) const;
  39. double& operator()(size_t a, size_t b, size_t c, size_t d);
  40. const double& operator()(size_t a, size_t b, size_t c, size_t d) const;
  41. double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e);
  42. const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const;
  43. double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f);
  44. const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const;
  45.  
  46. // pointer to data
  47. double* data();
  48. const double* data() const;
  49.  
  50. // iterator to first and last entry
  51. auto begin();
  52. auto begin() const;
  53. auto end();
  54. auto end() const;
  55.  
  56. // get dimensions
  57. size_t size() const;
  58. size_t shape(size_t i) const;
  59. std::vector<size_t> shape() const;
  60.  
  61. // initialization
  62. void setZero();
  63.  
  64. // basic algebra
  65. Matrix sum(size_t axis) const;
  66. };
  67.  
  68. // -------------------------------------------------------------------------------------------------
  69.  
  70. inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx);
  71.  
  72. // -------------------------------------------------------------------------------------------------
  73.  
  74. inline Matrix::Matrix(const std::vector<size_t> &shape)
  75. {
  76. assert( shape.size() > 0 );
  77. assert( shape.size() <= MAX_DIM );
  78.  
  79. for ( size_t i = 0 ; i < MAX_DIM ; ++i )
  80. {
  81. m_shape [i] = 1;
  82. m_strides[i] = 1;
  83. }
  84.  
  85. m_ndim = shape.size();
  86. m_size = 1;
  87.  
  88. for ( size_t i = 0 ; i < m_ndim ; ++i )
  89. {
  90. m_shape[i] = shape[i];
  91. m_size *= shape[i];
  92. }
  93.  
  94. for ( size_t i = 0 ; i < m_ndim ; ++i )
  95. for ( size_t j = i+1 ; j < m_ndim ; ++j )
  96. m_strides[i] *= m_shape[j];
  97.  
  98. m_data.resize(m_size);
  99. }
  100.  
  101. // -------------------------------------------------------------------------------------------------
  102.  
  103. inline double& Matrix::operator[](size_t i)
  104. {
  105. return m_data[i];
  106. }
  107.  
  108. // -------------------------------------------------------------------------------------------------
  109.  
  110. inline const double& Matrix::operator[](size_t i) const
  111. {
  112. return m_data[i];
  113. }
  114.  
  115. // -------------------------------------------------------------------------------------------------
  116.  
  117. inline double& Matrix::operator()(size_t a)
  118. {
  119. return m_data[a*m_strides[0]];
  120. }
  121.  
  122. // -------------------------------------------------------------------------------------------------
  123.  
  124. inline const double& Matrix::operator()(size_t a) const
  125. {
  126. return m_data[a*m_strides[0]];
  127. }
  128.  
  129. // -------------------------------------------------------------------------------------------------
  130.  
  131. inline double& Matrix::operator()(size_t a, size_t b)
  132. {
  133. return m_data[a*m_strides[0]+b*m_strides[1]];
  134. }
  135.  
  136. // -------------------------------------------------------------------------------------------------
  137.  
  138. inline const double& Matrix::operator()(size_t a, size_t b) const
  139. {
  140. return m_data[a*m_strides[0]+b*m_strides[1]];
  141. }
  142.  
  143. // -------------------------------------------------------------------------------------------------
  144.  
  145. inline double& Matrix::operator()(size_t a, size_t b, size_t c)
  146. {
  147. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]];
  148. }
  149.  
  150. // -------------------------------------------------------------------------------------------------
  151.  
  152. inline const double& Matrix::operator()(size_t a, size_t b, size_t c) const
  153. {
  154. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]];
  155. }
  156.  
  157. // -------------------------------------------------------------------------------------------------
  158.  
  159. inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d)
  160. {
  161. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]];
  162. }
  163.  
  164. // -------------------------------------------------------------------------------------------------
  165.  
  166. inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d) const
  167. {
  168. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]];
  169. }
  170.  
  171. // -------------------------------------------------------------------------------------------------
  172.  
  173. inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e)
  174. {
  175. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]];
  176. }
  177.  
  178. // -------------------------------------------------------------------------------------------------
  179.  
  180. inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const
  181. {
  182. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]];
  183. }
  184.  
  185. // -------------------------------------------------------------------------------------------------
  186.  
  187. inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f)
  188. {
  189. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]+f*m_strides[5]];
  190. }
  191.  
  192. // -------------------------------------------------------------------------------------------------
  193.  
  194. inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const
  195. {
  196. return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]+f*m_strides[5]];
  197. }
  198.  
  199. // -------------------------------------------------------------------------------------------------
  200.  
  201. inline double* Matrix::data()
  202. {
  203. return m_data.data();
  204. }
  205.  
  206. // -------------------------------------------------------------------------------------------------
  207.  
  208. inline const double* Matrix::data() const
  209. {
  210. return m_data.data();
  211. }
  212.  
  213. // -------------------------------------------------------------------------------------------------
  214.  
  215. inline auto Matrix::begin()
  216. {
  217. return m_data.begin();
  218. }
  219.  
  220. // -------------------------------------------------------------------------------------------------
  221.  
  222. inline auto Matrix::begin() const
  223. {
  224. return m_data.begin();
  225. }
  226.  
  227. // -------------------------------------------------------------------------------------------------
  228.  
  229. inline auto Matrix::end()
  230. {
  231. return m_data.end();
  232. }
  233.  
  234. // -------------------------------------------------------------------------------------------------
  235.  
  236. inline auto Matrix::end() const
  237. {
  238. return m_data.end();
  239. }
  240.  
  241. // -------------------------------------------------------------------------------------------------
  242.  
  243. inline size_t Matrix::shape(size_t i) const
  244. {
  245. return m_shape[i];
  246. }
  247.  
  248. // -------------------------------------------------------------------------------------------------
  249.  
  250. inline std::vector<size_t> Matrix::shape() const
  251. {
  252. std::vector<size_t> out(m_ndim);
  253.  
  254. for ( size_t i = 0 ; i < m_ndim ; ++i ) out[i] = m_shape[i];
  255.  
  256. return out;
  257. }
  258.  
  259. // -------------------------------------------------------------------------------------------------
  260.  
  261. inline size_t Matrix::size() const
  262. {
  263. return m_size;
  264. }
  265.  
  266. // -------------------------------------------------------------------------------------------------
  267.  
  268. inline void Matrix::setZero()
  269. {
  270. for ( size_t i = 0 ; i < m_size ; ++i ) m_data[i] = static_cast<double>(0);
  271. }
  272.  
  273. // -------------------------------------------------------------------------------------------------
  274.  
  275. inline Matrix Matrix::sum(size_t axis) const
  276. {
  277. Matrix out(del(this->shape(),axis));
  278.  
  279. out.setZero();
  280.  
  281. // How many elements to advance after each reduction
  282. size_t step_axis = m_strides[MAX_DIM-1];
  283.  
  284. if ( axis == MAX_DIM-1)
  285. step_axis = m_strides[MAX_DIM-2];
  286.  
  287. // Position of the first element of the current reduction
  288. size_t offset = 0;
  289.  
  290. for ( auto &v : out )
  291. {
  292. // Current reduced element
  293. size_t offset_i = offset;
  294.  
  295. for ( size_t i = 0 ; i < m_shape[axis] ; ++i )
  296. {
  297. // - reduce
  298. v += *(m_data.data() + offset_i);
  299. // - advance to next element
  300. offset_i += m_strides[axis];
  301. }
  302.  
  303. offset += step_axis;
  304. }
  305.  
  306. return out;
  307. }
  308.  
  309. // -------------------------------------------------------------------------------------------------
  310.  
  311. inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx)
  312. {
  313. assert( idx < A.size() );
  314.  
  315. std::vector<size_t> B = A;
  316.  
  317. B.erase(B.begin()+idx, B.begin()+idx+1);
  318.  
  319. return B;
  320. }
  321.  
  322. // -------------------------------------------------------------------------------------------------
  323.  
  324. int main()
  325. {
  326. Matrix A({6,11,16,3});
  327.  
  328. for ( size_t i = 0 ; i < A.size() ; ++i )
  329. A[i] = static_cast<double>(i) + 0.1;
  330.  
  331. Matrix B = A.sum(1);
  332.  
  333. Matrix b({6,16,3});
  334.  
  335. b.setZero();
  336.  
  337. for ( size_t i = 0 ; i < A.shape(0) ; i++ )
  338. for ( size_t j = 0 ; j < A.shape(1) ; j++ )
  339. for ( size_t k = 0 ; k < A.shape(2) ; k++ )
  340. for ( size_t l = 0 ; l < A.shape(3) ; l++ )
  341. b(i,k,l) += A(i,j,k,l);
  342.  
  343. for ( size_t i = 0 ; i < b.size() ; ++i )
  344. if ( std::abs( b[i] - B[i] ) > 1.e-12 )
  345. std::cout << b[i] << ", " << B[i] << std::endl;
  346.  
  347. return 0;
  348. }
  349.  
  350.  
Success #stdin #stdout 0s 4388KB
stdin
Standard input is empty
stdout
8449.1, 3169.1
8460.1, 3180.1
8471.1, 3191.1
8482.1, 3202.1
8493.1, 3213.1
8504.1, 3224.1
8515.1, 3235.1
8526.1, 3246.1
8537.1, 3257.1
8548.1, 3268.1
8559.1, 3279.1
8570.1, 3290.1
8581.1, 3301.1
8592.1, 3312.1
8603.1, 3323.1
8614.1, 3334.1
8625.1, 3345.1
8636.1, 3356.1
8647.1, 3367.1
8658.1, 3378.1
8669.1, 3389.1
8680.1, 3400.1
8691.1, 3411.1
8702.1, 3422.1
8713.1, 3433.1
8724.1, 3444.1
8735.1, 3455.1
8746.1, 3466.1
8757.1, 3477.1
8768.1, 3488.1
8779.1, 3499.1
8790.1, 3510.1
8801.1, 3521.1
8812.1, 3532.1
8823.1, 3543.1
8834.1, 3554.1
8845.1, 3565.1
8856.1, 3576.1
8867.1, 3587.1
8878.1, 3598.1
8889.1, 3609.1
8900.1, 3620.1
8911.1, 3631.1
8922.1, 3642.1
8933.1, 3653.1
8944.1, 3664.1
8955.1, 3675.1
8966.1, 3686.1
14257.1, 3697.1
14268.1, 3708.1
14279.1, 3719.1
14290.1, 3730.1
14301.1, 3741.1
14312.1, 3752.1
14323.1, 3763.1
14334.1, 3774.1
14345.1, 3785.1
14356.1, 3796.1
14367.1, 3807.1
14378.1, 3818.1
14389.1, 3829.1
14400.1, 3840.1
14411.1, 3851.1
14422.1, 3862.1
14433.1, 3873.1
14444.1, 3884.1
14455.1, 3895.1
14466.1, 3906.1
14477.1, 3917.1
14488.1, 3928.1
14499.1, 3939.1
14510.1, 3950.1
14521.1, 3961.1
14532.1, 3972.1
14543.1, 3983.1
14554.1, 3994.1
14565.1, 4005.1
14576.1, 4016.1
14587.1, 4027.1
14598.1, 4038.1
14609.1, 4049.1
14620.1, 4060.1
14631.1, 4071.1
14642.1, 4082.1
14653.1, 4093.1
14664.1, 4104.1
14675.1, 4115.1
14686.1, 4126.1
14697.1, 4137.1
14708.1, 4148.1
14719.1, 4159.1
14730.1, 4170.1
14741.1, 4181.1
14752.1, 4192.1
14763.1, 4203.1
14774.1, 4214.1
20065.1, 4225.1
20076.1, 4236.1
20087.1, 4247.1
20098.1, 4258.1
20109.1, 4269.1
20120.1, 4280.1
20131.1, 4291.1
20142.1, 4302.1
20153.1, 4313.1
20164.1, 4324.1
20175.1, 4335.1
20186.1, 4346.1
20197.1, 4357.1
20208.1, 4368.1
20219.1, 4379.1
20230.1, 4390.1
20241.1, 4401.1
20252.1, 4412.1
20263.1, 4423.1
20274.1, 4434.1
20285.1, 4445.1
20296.1, 4456.1
20307.1, 4467.1
20318.1, 4478.1
20329.1, 4489.1
20340.1, 4500.1
20351.1, 4511.1
20362.1, 4522.1
20373.1, 4533.1
20384.1, 4544.1
20395.1, 4555.1
20406.1, 4566.1
20417.1, 4577.1
20428.1, 4588.1
20439.1, 4599.1
20450.1, 4610.1
20461.1, 4621.1
20472.1, 4632.1
20483.1, 4643.1
20494.1, 4654.1
20505.1, 4665.1
20516.1, 4676.1
20527.1, 4687.1
20538.1, 4698.1
20549.1, 4709.1
20560.1, 4720.1
20571.1, 4731.1
20582.1, 4742.1
25873.1, 4753.1
25884.1, 4764.1
25895.1, 4775.1
25906.1, 4786.1
25917.1, 4797.1
25928.1, 4808.1
25939.1, 4819.1
25950.1, 4830.1
25961.1, 4841.1
25972.1, 4852.1
25983.1, 4863.1
25994.1, 4874.1
26005.1, 4885.1
26016.1, 4896.1
26027.1, 4907.1
26038.1, 4918.1
26049.1, 4929.1
26060.1, 4940.1
26071.1, 4951.1
26082.1, 4962.1
26093.1, 4973.1
26104.1, 4984.1
26115.1, 4995.1
26126.1, 5006.1
26137.1, 5017.1
26148.1, 5028.1
26159.1, 5039.1
26170.1, 5050.1
26181.1, 5061.1
26192.1, 5072.1
26203.1, 5083.1
26214.1, 5094.1
26225.1, 5105.1
26236.1, 5116.1
26247.1, 5127.1
26258.1, 5138.1
26269.1, 5149.1
26280.1, 5160.1
26291.1, 5171.1
26302.1, 5182.1
26313.1, 5193.1
26324.1, 5204.1
26335.1, 5215.1
26346.1, 5226.1
26357.1, 5237.1
26368.1, 5248.1
26379.1, 5259.1
26390.1, 5270.1
31681.1, 5281.1
31692.1, 5292.1
31703.1, 5303.1
31714.1, 5314.1
31725.1, 5325.1
31736.1, 5336.1
31747.1, 5347.1
31758.1, 5358.1
31769.1, 5369.1
31780.1, 5380.1
31791.1, 5391.1
31802.1, 5402.1
31813.1, 5413.1
31824.1, 5424.1
31835.1, 5435.1
31846.1, 5446.1
31857.1, 5457.1
31868.1, 5468.1
31879.1, 5479.1
31890.1, 5490.1
31901.1, 5501.1
31912.1, 5512.1
31923.1, 5523.1
31934.1, 5534.1
31945.1, 5545.1
31956.1, 5556.1
31967.1, 5567.1
31978.1, 5578.1
31989.1, 5589.1
32000.1, 5600.1
32011.1, 5611.1
32022.1, 5622.1
32033.1, 5633.1
32044.1, 5644.1
32055.1, 5655.1
32066.1, 5666.1
32077.1, 5677.1
32088.1, 5688.1
32099.1, 5699.1
32110.1, 5710.1
32121.1, 5721.1
32132.1, 5732.1
32143.1, 5743.1
32154.1, 5754.1
32165.1, 5765.1
32176.1, 5776.1
32187.1, 5787.1
32198.1, 5798.1