#include <algorithm>
#include <assert.h>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <iomanip>
#include <string>
#include <vector>

// -------------------------------------------------------------------------------------------------

// Dense array of doubles with 1 to MAX_DIM dimensions, stored contiguously in a std::vector.
// The shape-taking constructor lays the entries out in row-major order (last index fastest).
class Matrix
{
private:

  static const size_t MAX_DIM = 6;

  std::vector<double> m_data;               // data container
  size_t              m_ndim = 0;           // actual number of dimensions
  size_t              m_size = 0;           // total number of entries == data.size() == prod(shape)
  size_t              m_shape[MAX_DIM]{};   // number of entries in each dimension
  size_t              m_strides[MAX_DIM]{}; // stride length for each index
  // m_shape/m_strides are brace-initialized so that a default-constructed (empty) Matrix holds
  // zeros rather than indeterminate values; the shape constructor overwrites every entry.

public:

  // constructors
  Matrix() = default;                       // empty: no dimensions, no entries
  Matrix(const std::vector<size_t> &shape); // allocate; std::vector::resize zero-initializes entries

  // index operators: access plain (flat) storage
  double&       operator[](size_t i);
  const double& operator[](size_t i) const;

  // index operators: access using Matrix indices (one index per dimension)
  double&       operator()(size_t a);
  const double& operator()(size_t a) const;
  double&       operator()(size_t a, size_t b);
  const double& operator()(size_t a, size_t b) const;
  double&       operator()(size_t a, size_t b, size_t c);
  const double& operator()(size_t a, size_t b, size_t c) const;
  double&       operator()(size_t a, size_t b, size_t c, size_t d);
  const double& operator()(size_t a, size_t b, size_t c, size_t d) const;
  double&       operator()(size_t a, size_t b, size_t c, size_t d, size_t e);
  const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const;
  double&       operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f);
  const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const;

  // pointer to data
  double*       data();
  const double* data() const;

  // iterators to the first and the past-the-end entry (types deduced from std::vector<double>)
  auto begin();
  auto begin() const;
  auto end();
  auto end() const;

  // get dimensions
  size_t              size() const;          // total number of entries
  size_t              shape(size_t i) const; // extent of dimension "i"
  std::vector<size_t> shape() const;         // extents of all m_ndim dimensions

  // initialization
  void setZero();

  // basic algebra
  Matrix sum(size_t axis) const;             // sum over "axis", removing that dimension
};
------------------------------------------------------------------------------------------------- inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx); // ------------------------------------------------------------------------------------------------- inline Matrix::Matrix(const std::vector<size_t> &shape) { assert( shape.size() > 0 ); assert( shape.size() <= MAX_DIM ); for ( size_t i = 0 ; i < MAX_DIM ; ++i ) { m_shape [i] = 1; m_strides[i] = 1; } m_ndim = shape.size(); m_size = 1; for ( size_t i = 0 ; i < m_ndim ; ++i ) { m_shape[i] = shape[i]; m_size *= shape[i]; } for ( size_t i = 0 ; i < m_ndim ; ++i ) for ( size_t j = i+1 ; j < m_ndim ; ++j ) m_strides[i] *= m_shape[j]; m_data.resize(m_size); } // ------------------------------------------------------------------------------------------------- inline double& Matrix::operator[](size_t i) { return m_data[i]; } // ------------------------------------------------------------------------------------------------- inline const double& Matrix::operator[](size_t i) const { return m_data[i]; } // ------------------------------------------------------------------------------------------------- inline double& Matrix::operator()(size_t a) { return m_data[a*m_strides[0]]; } // ------------------------------------------------------------------------------------------------- inline const double& Matrix::operator()(size_t a) const { return m_data[a*m_strides[0]]; } // ------------------------------------------------------------------------------------------------- inline double& Matrix::operator()(size_t a, size_t b) { return m_data[a*m_strides[0]+b*m_strides[1]]; } // ------------------------------------------------------------------------------------------------- inline const double& Matrix::operator()(size_t a, size_t b) const { return m_data[a*m_strides[0]+b*m_strides[1]]; } // ------------------------------------------------------------------------------------------------- inline double& 
Matrix::operator()(size_t a, size_t b, size_t c) { return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]]; } // ------------------------------------------------------------------------------------------------- inline const double& Matrix::operator()(size_t a, size_t b, size_t c) const { return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]]; } // ------------------------------------------------------------------------------------------------- inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d) { return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]]; } // ------------------------------------------------------------------------------------------------- inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d) const { return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]]; } // ------------------------------------------------------------------------------------------------- inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e) { return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]]; } // ------------------------------------------------------------------------------------------------- inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const { return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]]; } // ------------------------------------------------------------------------------------------------- inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) { return m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]+f*m_strides[5]]; } // ------------------------------------------------------------------------------------------------- inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const { return 
m_data[a*m_strides[0]+b*m_strides[1]+c*m_strides[2]+d*m_strides[3]+e*m_strides[4]+f*m_strides[5]]; } // ------------------------------------------------------------------------------------------------- inline double* Matrix::data() { return m_data.data(); } // ------------------------------------------------------------------------------------------------- inline const double* Matrix::data() const { return m_data.data(); } // ------------------------------------------------------------------------------------------------- inline auto Matrix::begin() { return m_data.begin(); } // ------------------------------------------------------------------------------------------------- inline auto Matrix::begin() const { return m_data.begin(); } // ------------------------------------------------------------------------------------------------- inline auto Matrix::end() { return m_data.end(); } // ------------------------------------------------------------------------------------------------- inline auto Matrix::end() const { return m_data.end(); } // ------------------------------------------------------------------------------------------------- inline size_t Matrix::shape(size_t i) const { return m_shape[i]; } // ------------------------------------------------------------------------------------------------- inline std::vector<size_t> Matrix::shape() const { std::vector<size_t> out(m_ndim); for ( size_t i = 0 ; i < m_ndim ; ++i ) out[i] = m_shape[i]; return out; } // ------------------------------------------------------------------------------------------------- inline size_t Matrix::size() const { return m_size; } // ------------------------------------------------------------------------------------------------- inline void Matrix::setZero() { for ( size_t i = 0 ; i < m_size ; ++i ) m_data[i] = static_cast<double>(0); } // ------------------------------------------------------------------------------------------------- inline Matrix Matrix::sum(size_t axis) const { 
Matrix out(del(this->shape(),axis)); out.setZero(); // How many elements to advance after each reduction size_t step_axis = m_strides[MAX_DIM-1]; if ( axis == MAX_DIM-1) step_axis = m_strides[MAX_DIM-2]; // Position of the first element of the current reduction size_t offset = 0; for ( auto &v : out ) { // Current reduced element size_t offset_i = offset; for ( size_t i = 0 ; i < m_shape[axis] ; ++i ) { // - reduce v += *(m_data.data() + offset_i); // - advance to next element offset_i += m_strides[axis]; } offset += step_axis; } return out; } // ------------------------------------------------------------------------------------------------- inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx) { assert( idx < A.size() ); std::vector<size_t> B = A; B.erase(B.begin()+idx, B.begin()+idx+1); return B; } // ------------------------------------------------------------------------------------------------- int main() { Matrix A({6,11,16,3}); for ( size_t i = 0 ; i < A.size() ; ++i ) A[i] = static_cast<double>(i) + 0.1; Matrix B = A.sum(1); Matrix b({6,16,3}); b.setZero(); for ( size_t i = 0 ; i < A.shape(0) ; i++ ) for ( size_t j = 0 ; j < A.shape(1) ; j++ ) for ( size_t k = 0 ; k < A.shape(2) ; k++ ) for ( size_t l = 0 ; l < A.shape(3) ; l++ ) b(i,k,l) += A(i,j,k,l); for ( size_t i = 0 ; i < b.size() ; ++i ) if ( std::abs( b[i] - B[i] ) > 1.e-12 ) std::cout << b[i] << ", " << B[i] << std::endl; return 0; }
Standard input is empty
8449.1, 3169.1 8460.1, 3180.1 8471.1, 3191.1 8482.1, 3202.1 8493.1, 3213.1 8504.1, 3224.1 8515.1, 3235.1 8526.1, 3246.1 8537.1, 3257.1 8548.1, 3268.1 8559.1, 3279.1 8570.1, 3290.1 8581.1, 3301.1 8592.1, 3312.1 8603.1, 3323.1 8614.1, 3334.1 8625.1, 3345.1 8636.1, 3356.1 8647.1, 3367.1 8658.1, 3378.1 8669.1, 3389.1 8680.1, 3400.1 8691.1, 3411.1 8702.1, 3422.1 8713.1, 3433.1 8724.1, 3444.1 8735.1, 3455.1 8746.1, 3466.1 8757.1, 3477.1 8768.1, 3488.1 8779.1, 3499.1 8790.1, 3510.1 8801.1, 3521.1 8812.1, 3532.1 8823.1, 3543.1 8834.1, 3554.1 8845.1, 3565.1 8856.1, 3576.1 8867.1, 3587.1 8878.1, 3598.1 8889.1, 3609.1 8900.1, 3620.1 8911.1, 3631.1 8922.1, 3642.1 8933.1, 3653.1 8944.1, 3664.1 8955.1, 3675.1 8966.1, 3686.1 14257.1, 3697.1 14268.1, 3708.1 14279.1, 3719.1 14290.1, 3730.1 14301.1, 3741.1 14312.1, 3752.1 14323.1, 3763.1 14334.1, 3774.1 14345.1, 3785.1 14356.1, 3796.1 14367.1, 3807.1 14378.1, 3818.1 14389.1, 3829.1 14400.1, 3840.1 14411.1, 3851.1 14422.1, 3862.1 14433.1, 3873.1 14444.1, 3884.1 14455.1, 3895.1 14466.1, 3906.1 14477.1, 3917.1 14488.1, 3928.1 14499.1, 3939.1 14510.1, 3950.1 14521.1, 3961.1 14532.1, 3972.1 14543.1, 3983.1 14554.1, 3994.1 14565.1, 4005.1 14576.1, 4016.1 14587.1, 4027.1 14598.1, 4038.1 14609.1, 4049.1 14620.1, 4060.1 14631.1, 4071.1 14642.1, 4082.1 14653.1, 4093.1 14664.1, 4104.1 14675.1, 4115.1 14686.1, 4126.1 14697.1, 4137.1 14708.1, 4148.1 14719.1, 4159.1 14730.1, 4170.1 14741.1, 4181.1 14752.1, 4192.1 14763.1, 4203.1 14774.1, 4214.1 20065.1, 4225.1 20076.1, 4236.1 20087.1, 4247.1 20098.1, 4258.1 20109.1, 4269.1 20120.1, 4280.1 20131.1, 4291.1 20142.1, 4302.1 20153.1, 4313.1 20164.1, 4324.1 20175.1, 4335.1 20186.1, 4346.1 20197.1, 4357.1 20208.1, 4368.1 20219.1, 4379.1 20230.1, 4390.1 20241.1, 4401.1 20252.1, 4412.1 20263.1, 4423.1 20274.1, 4434.1 20285.1, 4445.1 20296.1, 4456.1 20307.1, 4467.1 20318.1, 4478.1 20329.1, 4489.1 20340.1, 4500.1 20351.1, 4511.1 20362.1, 4522.1 20373.1, 4533.1 20384.1, 4544.1 20395.1, 4555.1 20406.1, 4566.1 
20417.1, 4577.1 20428.1, 4588.1 20439.1, 4599.1 20450.1, 4610.1 20461.1, 4621.1 20472.1, 4632.1 20483.1, 4643.1 20494.1, 4654.1 20505.1, 4665.1 20516.1, 4676.1 20527.1, 4687.1 20538.1, 4698.1 20549.1, 4709.1 20560.1, 4720.1 20571.1, 4731.1 20582.1, 4742.1 25873.1, 4753.1 25884.1, 4764.1 25895.1, 4775.1 25906.1, 4786.1 25917.1, 4797.1 25928.1, 4808.1 25939.1, 4819.1 25950.1, 4830.1 25961.1, 4841.1 25972.1, 4852.1 25983.1, 4863.1 25994.1, 4874.1 26005.1, 4885.1 26016.1, 4896.1 26027.1, 4907.1 26038.1, 4918.1 26049.1, 4929.1 26060.1, 4940.1 26071.1, 4951.1 26082.1, 4962.1 26093.1, 4973.1 26104.1, 4984.1 26115.1, 4995.1 26126.1, 5006.1 26137.1, 5017.1 26148.1, 5028.1 26159.1, 5039.1 26170.1, 5050.1 26181.1, 5061.1 26192.1, 5072.1 26203.1, 5083.1 26214.1, 5094.1 26225.1, 5105.1 26236.1, 5116.1 26247.1, 5127.1 26258.1, 5138.1 26269.1, 5149.1 26280.1, 5160.1 26291.1, 5171.1 26302.1, 5182.1 26313.1, 5193.1 26324.1, 5204.1 26335.1, 5215.1 26346.1, 5226.1 26357.1, 5237.1 26368.1, 5248.1 26379.1, 5259.1 26390.1, 5270.1 31681.1, 5281.1 31692.1, 5292.1 31703.1, 5303.1 31714.1, 5314.1 31725.1, 5325.1 31736.1, 5336.1 31747.1, 5347.1 31758.1, 5358.1 31769.1, 5369.1 31780.1, 5380.1 31791.1, 5391.1 31802.1, 5402.1 31813.1, 5413.1 31824.1, 5424.1 31835.1, 5435.1 31846.1, 5446.1 31857.1, 5457.1 31868.1, 5468.1 31879.1, 5479.1 31890.1, 5490.1 31901.1, 5501.1 31912.1, 5512.1 31923.1, 5523.1 31934.1, 5534.1 31945.1, 5545.1 31956.1, 5556.1 31967.1, 5567.1 31978.1, 5578.1 31989.1, 5589.1 32000.1, 5600.1 32011.1, 5611.1 32022.1, 5622.1 32033.1, 5633.1 32044.1, 5644.1 32055.1, 5655.1 32066.1, 5666.1 32077.1, 5677.1 32088.1, 5688.1 32099.1, 5699.1 32110.1, 5710.1 32121.1, 5721.1 32132.1, 5732.1 32143.1, 5743.1 32154.1, 5754.1 32165.1, 5765.1 32176.1, 5776.1 32187.1, 5787.1 32198.1, 5798.1