#include <algorithm>
#include <assert.h>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <iomanip>
#include <string>
#include <vector>
 
// -------------------------------------------------------------------------------------------------
 
// N-dimensional array of doubles with at most MAX_DIM dimensions. The entries live in one contiguous
// std::vector; the strides are computed such that the last index varies fastest (row-major).
class Matrix
{
private:
  static const size_t MAX_DIM = 6;           // maximum number of dimensions
  std::vector<double> m_data;                // data container
  size_t              m_ndim=0;              // actual number of dimensions
  size_t              m_size=0;              // total number of entries == data.size() == prod(shape)
  // Both arrays are zero-initialized so that an empty (default constructed) Matrix never exposes
  // indeterminate values, e.g. through shape(i).
  size_t              m_shape[MAX_DIM]={};   // number of entries in each dimension
  size_t              m_strides[MAX_DIM]={}; // stride length for each index

public:

  // constructors
  Matrix() = default;                             // empty: no dimensions, no entries
  Matrix(const std::vector<size_t> &shape);       // allocate (entries are zero-filled by the resize)

  // index operators: access plain storage (no bounds checking)
  double&       operator[](size_t i);
  const double& operator[](size_t i) const;

  // index operators: access using Matrix indices (no bounds checking);
  // the storage offset is the sum of each index multiplied by the stride of its dimension
  double&       operator()(size_t a);
  const double& operator()(size_t a) const;
  double&       operator()(size_t a, size_t b);
  const double& operator()(size_t a, size_t b) const;
  double&       operator()(size_t a, size_t b, size_t c);
  const double& operator()(size_t a, size_t b, size_t c) const;
  double&       operator()(size_t a, size_t b, size_t c, size_t d);
  const double& operator()(size_t a, size_t b, size_t c, size_t d) const;
  double&       operator()(size_t a, size_t b, size_t c, size_t d, size_t e);
  const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const;
  double&       operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f);
  const double& operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const;

  // pointer to data
  double*       data();
  const double* data() const;

  // iterator to the first entry and to one past the last entry
  auto begin();
  auto begin() const;
  auto end();
  auto end() const;

  // get dimensions
  size_t size() const;                  // total number of entries
  size_t shape(size_t i) const;         // number of entries along dimension "i"
  std::vector<size_t> shape() const;    // number of entries along each of the actual dimensions

  // initialization
  void setZero();                       // set all entries to zero

  // basic algebra
  Matrix sum(size_t axis) const;        // sum along "axis"; the result has that axis removed
};
 
// -------------------------------------------------------------------------------------------------
 
// Forward declaration: return a copy of "A" with the entry at position "idx" removed
// (defined further below, declared here because Matrix::sum uses it).
inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx);
 
// -------------------------------------------------------------------------------------------------
 
// Construct a Matrix of the given shape: "shape[i]" is the number of entries along dimension "i".
// The number of dimensions must be between 1 and MAX_DIM (checked by assertion only).
// The storage is sized to hold prod(shape) entries.
inline Matrix::Matrix(const std::vector<size_t> &shape)
{
  assert( shape.size()  > 0       );
  assert( shape.size() <= MAX_DIM );

  m_ndim = shape.size();
  m_size = 1;

  // Walk the dimensions from last to first while keeping a running product of the lengths seen so
  // far: that product is exactly the stride of the current dimension, and ends up as the total size.
  // Dimensions beyond "m_ndim" are unused: they get length one and stride one.
  for ( size_t dim = MAX_DIM ; dim-- > 0 ; )
  {
    const bool   used   = dim < m_ndim;
    const size_t length = used ? shape[dim] : 1;

    m_shape  [dim] = length;
    m_strides[dim] = used ? m_size : 1;
    m_size        *= length;
  }

  m_data.resize(m_size);
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable access to entry "i" of the plain (flat) storage; no bounds checking.
inline double& Matrix::operator[](size_t i)
{
  double &entry = m_data[i];

  return entry;
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only access to entry "i" of the plain (flat) storage; no bounds checking.
inline const double& Matrix::operator[](size_t i) const
{
  const double &entry = m_data[i];

  return entry;
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable access using one index: the entry at "a" times the stride of the first dimension.
inline double& Matrix::operator()(size_t a)
{
  const size_t offset = a*m_strides[0];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only access using one index: the entry at "a" times the stride of the first dimension.
inline const double& Matrix::operator()(size_t a) const
{
  const size_t offset = a*m_strides[0];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable access using two indices: each index is weighted by the stride of its dimension.
inline double& Matrix::operator()(size_t a, size_t b)
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only access using two indices: each index is weighted by the stride of its dimension.
inline const double& Matrix::operator()(size_t a, size_t b) const
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable access using three indices: each index is weighted by the stride of its dimension.
inline double& Matrix::operator()(size_t a, size_t b, size_t c)
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only access using three indices: each index is weighted by the stride of its dimension.
inline const double& Matrix::operator()(size_t a, size_t b, size_t c) const
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable access using four indices: each index is weighted by the stride of its dimension.
inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d)
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];
  offset       += d*m_strides[3];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only access using four indices: each index is weighted by the stride of its dimension.
inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d) const
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];
  offset       += d*m_strides[3];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable access using five indices: each index is weighted by the stride of its dimension.
inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e)
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];
  offset       += d*m_strides[3];
  offset       += e*m_strides[4];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only access using five indices: each index is weighted by the stride of its dimension.
inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e) const
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];
  offset       += d*m_strides[3];
  offset       += e*m_strides[4];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable access using six indices: each index is weighted by the stride of its dimension.
inline double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f)
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];
  offset       += d*m_strides[3];
  offset       += e*m_strides[4];
  offset       += f*m_strides[5];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only access using six indices: each index is weighted by the stride of its dimension.
inline const double& Matrix::operator()(size_t a, size_t b, size_t c, size_t d, size_t e, size_t f) const
{
  size_t offset = a*m_strides[0];
  offset       += b*m_strides[1];
  offset       += c*m_strides[2];
  offset       += d*m_strides[3];
  offset       += e*m_strides[4];
  offset       += f*m_strides[5];

  return m_data[offset];
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable pointer to the first entry of the underlying storage.
inline double* Matrix::data()
{
  double *first = m_data.data();

  return first;
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only pointer to the first entry of the underlying storage.
inline const double* Matrix::data() const
{
  const double *first = m_data.data();

  return first;
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable iterator to the first entry of the underlying storage.
inline auto Matrix::begin()
{
  return std::begin(m_data);
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only iterator to the first entry of the underlying storage.
inline auto Matrix::begin() const
{
  return m_data.cbegin();
}
 
// -------------------------------------------------------------------------------------------------
 
// Mutable iterator to one past the last entry of the underlying storage.
inline auto Matrix::end()
{
  return std::end(m_data);
}
 
// -------------------------------------------------------------------------------------------------
 
// Read-only iterator to one past the last entry of the underlying storage.
inline auto Matrix::end() const
{
  return m_data.cend();
}
 
// -------------------------------------------------------------------------------------------------
 
// Number of entries along dimension "i"; "i" is not checked against the number of dimensions.
inline size_t Matrix::shape(size_t i) const
{
  const size_t length = *(m_shape + i);

  return length;
}
 
// -------------------------------------------------------------------------------------------------
 
// Number of entries along each of the actual dimensions (the returned vector has "m_ndim" items).
inline std::vector<size_t> Matrix::shape() const
{
  // copy the leading "m_ndim" items of the fixed-size shape array
  return std::vector<size_t>(m_shape, m_shape + m_ndim);
}
 
// -------------------------------------------------------------------------------------------------
 
// Total number of entries, as stored at construction.
inline size_t Matrix::size() const
{
  const size_t count = m_size;

  return count;
}
 
// -------------------------------------------------------------------------------------------------
 
// Overwrite the first "m_size" entries of the storage (i.e. all entries) with zero.
inline void Matrix::setZero()
{
  std::fill_n(m_data.begin(), m_size, static_cast<double>(0));
}
 
// -------------------------------------------------------------------------------------------------
 
// Sum the entries along "axis".
//
// The result has the shape of this Matrix with dimension "axis" removed; entry (..., ...) of the
// result is the sum over "i" of entry (..., i, ...) of this Matrix, accumulated for increasing "i".
// Reducing a one-dimensional Matrix yields a Matrix of shape {1} holding the total, because a
// Matrix cannot have zero dimensions.
//
// "axis" must be smaller than the number of dimensions (checked by assertion only).
inline Matrix Matrix::sum(size_t axis) const
{
  assert( m_ndim > 0    );
  assert( axis < m_ndim );

  // Shape of the result: the shape of this Matrix without the reduced axis
  std::vector<size_t> out_shape = del(this->shape(),axis);

  if ( out_shape.empty() )
    out_shape.push_back(1);

  Matrix out(out_shape);

  out.setZero();

  // View the storage as a three-dimensional array [n_outer, n_axis, n_inner]:
  // - n_outer: number of index combinations of all dimensions in front of "axis"
  // - n_axis : number of entries along "axis"
  // - n_inner: number of index combinations of all dimensions behind "axis" (== stride of "axis")
  // The result is then the array [n_outer, n_inner]. This needs neither "m_strides[axis-1]" nor
  // "m_strides[m_ndim-2]", which do not exist for "axis == 0" and for one-dimensional input.
  const size_t n_axis  = m_shape  [axis];
  const size_t n_inner = m_strides[axis];
  size_t       n_outer = 1;

  for ( size_t i = 0 ; i < axis ; ++i )
    n_outer *= m_shape[i];

  const double *src = m_data.data(); // current slice [o, i, :] of the input
  double       *dst = out.data();    // current slice [o, :]    of the result

  for ( size_t o = 0 ; o < n_outer ; ++o )
  {
    for ( size_t i = 0 ; i < n_axis ; ++i )
    {
      // - reduce: add the contiguous input slice to the result slice
      for ( size_t k = 0 ; k < n_inner ; ++k )
        dst[k] += src[k];

      // - advance to the next input slice
      src += n_inner;
    }

    // - advance to the next result slice
    dst += n_inner;
  }

  return out;
}
 
// -------------------------------------------------------------------------------------------------
 
// Return a copy of "A" from which the item at position "idx" has been removed.
// "idx" must be a valid position in "A" (checked by assertion only).
inline std::vector<size_t> del(const std::vector<size_t> &A, size_t idx)
{
  assert( idx < A.size() );

  std::vector<size_t> kept;

  kept.reserve(A.size()-1);

  // everything in front of "idx", followed by everything behind it
  kept.insert(kept.end(), A.begin()      , A.begin()+idx);
  kept.insert(kept.end(), A.begin()+idx+1, A.end()      );

  return kept;
}
 
// -------------------------------------------------------------------------------------------------
 
// Self-check of Matrix::sum: reduce a four-dimensional Matrix along one axis and compare the
// outcome, entry by entry, to a reference obtained with explicit nested loops.
int main()
{
  // Input: every entry gets a distinct value
  Matrix A({6, 11, 16, 3});

  for ( size_t n = 0 ; n < A.size() ; ++n )
    A[n] = static_cast<double>(n) + 0.1;

  // Select reduction axis
  size_t axis = 0;

  // Result under test
  Matrix B = A.sum(axis);

  // Reference: same shape as the input, without the reduced axis
  const std::vector<size_t> a_shape = A.shape();
  std::vector<size_t> b_shape;

  for ( size_t dim = 0 ; dim < a_shape.size() ; ++dim )
    if ( dim != axis )
      b_shape.push_back(a_shape[dim]);

  Matrix b(b_shape);

  b.setZero();

  // Reference: accumulate each input entry at the position given by its three remaining indices
  for ( size_t i = 0 ; i < A.shape(0) ; i++ )
  {
    for ( size_t j = 0 ; j < A.shape(1) ; j++ )
    {
      for ( size_t k = 0 ; k < A.shape(2) ; k++ )
      {
        for ( size_t l = 0 ; l < A.shape(3) ; l++ )
        {
          const double value = A(i, j, k, l);

          if      ( axis == 0 ) b(j, k, l) += value;
          else if ( axis == 1 ) b(i, k, l) += value;
          else if ( axis == 2 ) b(i, j, l) += value;
          else if ( axis == 3 ) b(i, j, k) += value;
        }
      }
    }
  }

  // Compare: report every entry that deviates by more than the tolerance
  bool errors = false;

  for ( size_t n = 0 ; n < b.size() ; ++n )
  {
    const double deviation = std::abs( b[n] - B[n] );

    if ( deviation > 1.e-12 )
    {
      std::cout << "Error in [" << n << "] - Correct: " << b[n] << ", result:" << B[n] << std::endl;
      errors = true;
    }
  }

  if ( !errors )
    std::cout << "Computation was correct." << std::endl;

  return 0;
}
