#include <stdexcept>
#include <utility>
#include <valarray>

#include "matrix.h"

using namespace std;

// Strassen product of two n-by-n matrices stored row-major in flat valarrays.
//
// Overwrites `a` with a * b and returns a reference to `a`.
// Preconditions (not checked here; matmul() validates them for its callers):
//   * n is zero or a power of two -- for other n the quadrant slices below do
//     not tile the matrix and the result is wrong;
//   * a.size() == b.size() == n * n.
// Recursion goes all the way down to 1x1 blocks, so the order of additions
// differs from the schoolbook product; floating-point results may therefore
// differ from it in the last bits (integer results are exact).
template<class T>
  std::valarray<T>& matmul_v(std::valarray<T>& a, const std::valarray<T>& b, std::size_t n) {
    // `n <= 1` rather than `n == 1`: for n == 0 the recursion would otherwise
    // never terminate, since 0 / 2 == 0.
    if (n <= 1)
      return a *= b;

    const std::size_t half = n / 2;

    // Quadrant selectors: each picks a half-by-half block (row stride n,
    // column stride 1) and yields it as a contiguous row-major block.
    const std::gslice
      s11(0, {half, half}, {n, 1}),
      s12(half, {half, half}, {n, 1}),
      s21(half * n, {half, half}, {n, 1}),
      s22(half * n + half, {half, half}, {n, 1});

    // matmul_v(x, y, half) overwrites x in place, so its (self-referencing)
    // return value is deliberately not assigned back to x.
    std::valarray<T> rhs = b[s11]; rhs += b[s22];
    std::valarray<T> m1 = a[s11]; m1 += a[s22];
    matmul_v(m1, rhs, half);                      // M1 = (A11 + A22)(B11 + B22)

    rhs = b[s11];
    std::valarray<T> m2 = a[s21]; m2 += a[s22];
    matmul_v(m2, rhs, half);                      // M2 = (A21 + A22) B11

    rhs = b[s12]; rhs -= b[s22];
    std::valarray<T> m3 = a[s11];
    matmul_v(m3, rhs, half);                      // M3 = A11 (B12 - B22)

    rhs = b[s21]; rhs -= b[s11];
    std::valarray<T> m4 = a[s22];
    matmul_v(m4, rhs, half);                      // M4 = A22 (B21 - B11)

    rhs = b[s22];
    std::valarray<T> m5 = a[s11]; m5 += a[s12];
    matmul_v(m5, rhs, half);                      // M5 = (A11 + A12) B22

    rhs = b[s11]; rhs += b[s12];
    std::valarray<T> m6 = a[s21]; m6 -= a[s11];
    matmul_v(m6, rhs, half);                      // M6 = (A21 - A11)(B11 + B12)

    rhs = b[s21]; rhs += b[s22];
    std::valarray<T> m7 = a[s12]; m7 -= a[s22];
    matmul_v(m7, rhs, half);                      // M7 = (A12 - A22)(B21 + B22)

    // Only now overwrite a's quadrants: every M_i above read the original a,
    // which is what lets `a` double as the output buffer.
    a[s11] = m1 + m4 - m5 + m7;
    a[s12] = m3 + m5;
    a[s21] = m2 + m4;
    a[s22] = m1 - m2 + m3 + m6;
    return a;
}

// Returns true exactly when i is a positive integral power of two
// (1, 2, 4, 8, ...). Such a value has a single set bit, so clearing its
// lowest set bit with i & (i - 1) leaves zero; zero itself is rejected first.
bool power_of_two(size_t i) {
  if (i == 0)
    return false;
  const size_t without_lowest_bit = i & (i - 1);
  return without_lowest_bit == 0;
}

// Multiply two square matrices of equal, power-of-two dimension using
// Strassen's algorithm (implemented by matmul_v).
//
// Throws std::invalid_argument if either matrix is not square, if their
// dimensions differ, or if the dimension is not a power of two (this also
// rejects 0x0 matrices, because power_of_two(0) is false).
template<class T>
  matrix<T> matmul(matrix<T> a, matrix<T> b) {
    if (a.rows() != a.columns() || a.rows() != b.rows() || a.rows() != b.columns())
      throw invalid_argument("matrices not square or not of same dimension");
    if (!power_of_two(a.columns()))
      throw invalid_argument("matrix dimension not power of two");

    // Flat copy of a's elements. matmul_v treats it as row-major
    // (NOTE(review): assumes matrix<T> converts to valarray<T> in row-major
    // order -- confirm in matrix.h). matmul_v overwrites av in place with
    // a * b (b is converted to valarray<T> at the call), so its returned
    // reference to av is not assigned back to av.
    valarray<T> av = a;
    matmul_v<T>(av, b, a.rows());
    // Hand the buffer to the result instead of copying all n*n elements again.
    return matrix<T>{std::move(av), a.rows(), a.columns()};
}

// Explicit instantiation definition. The matmul template is defined in this
// .cpp file, so other translation units can only link against the element
// types instantiated here (presumably matrix.h only declares matmul --
// confirm). Add a line like this one for each further element type needed.
template matrix<double> matmul<double>(matrix<double>, matrix<double>);
