#include <iostream>
#include <chrono>
#include <ratio>
#include <random>
#include <algorithm>
#include <numeric>
#include <functional>
#include <stdlib.h>
#include <x86intrin.h>
using namespace std;

// Clock and duration types used to time each benchmark. Durations are kept as
// fractional nanoseconds.
using time_clk = chrono::steady_clock;
using timepoint = chrono::time_point<time_clk>;
using duration = chrono::duration<double, nano>;

constexpr size_t kMegabyte = 1024 * 1024;
// Number of input bytes multiplied together by every benchmark (4 MiB).
constexpr size_t size = 4 * kMegabyte;
// Input buffer. main() also reads it through __m128i, uint64_t and uint16_t
// pointers, hence the 64-byte alignment.
// `::size` is written qualified on purpose: with `using namespace std;` in
// effect, a bare `size` also finds C++17's `std::size` and is ambiguous.
alignas(64) uint8_t bytes[::size];
// Product lookup table with one entry per 16-bit key; filled by fillMtable().
alignas(64) unsigned mtable[0x10000];

// Fills mtable so that mtable[k] is the product of eight odd bytes, each in
// {1, 3, 5, 7}. Byte j is determined by the 2-bit field at bits 2j..2j+1 of k:
// that field becomes bits 1-2 of the byte, and bit 0 is always set.
//
// Each byte value equals what `_pdep_u64(k, 0x0606060606060606) |
// 0x0101010101010101` places in byte j. Consequently, for a qword q made of
// such bytes, mtable[_pext_u64(q, 0x0606060606060606)] is the product of q's
// bytes.
//
// The bytes are rebuilt with plain shifts instead of being read back through
// a union. Reading an inactive union member is undefined behaviour in
// standard C++, and plain shifts also avoid needing BMI2 to build the table.
// The largest entry is 7^8 = 5,764,801, so nothing overflows.
void fillMtable()
{
    for (unsigned key = 0; key < 0x10000; ++key)
    {
        unsigned product = 1;
        for (unsigned j = 0; j < 8; ++j)
        {
            const unsigned field = (key >> (2 * j)) & 3u; // bits 2j..2j+1 of key
            product *= (field << 1) | 1u;                 // odd byte: 1, 3, 5 or 7
        }
        mtable[key] = product;
    }
}

/// Raises the compile-time base `x` to the power `pow` by square-and-multiply.
///
/// Arithmetic is done in `unsigned`, so the result wraps modulo 2^N, where N is
/// the bit width of `unsigned`.
///
/// @tparam x   the base
/// @param  pow the exponent (0 yields 1)
/// @return x^pow with unsigned wrap-around
template<unsigned x>
unsigned bexp(unsigned pow)
{
    unsigned result = 1;
    unsigned square = x; // x^(2^k) for the exponent bit k being examined
    while (pow != 0)
    {
        if ((pow & 1u) != 0)
            result *= square;
        square *= square;
        pow >>= 1;
    }
    return result;
}

/// Benchmarks six ways of computing the product (with unsigned wrap-around) of
/// all bytes in `bytes`. Each method prints its result and its average time
/// per byte; all methods should print the same `res`.
///
/// Methods 3-6 use AVX2/BMI2 intrinsics. NOTE(review): the build must enable
/// those instruction sets (e.g. -mavx2 -mbmi2) — confirm against the build setup.
int main()
{
    // Input bytes are odd values 2*u + 1 with u in [0, 2^(bits-1) - 1],
    // i.e. {1, 3, 5, 7} for bits == 3.
    constexpr size_t bits = 3;
    // Methods 4-6 pack several bytes into one machine word and are only
    // correct while every byte fits in three bits.
    static_assert(bits >= 1 && bits <= 3, "packed methods assume byte values < 8");
    mt19937 rnd;
    auto ud = uniform_int_distribution<size_t>(0, (1 << (bits - 1)) - 1);
    // `::size` (not `size`): a bare `size` is ambiguous with C++17's std::size
    // because of `using namespace std;`.
    generate_n(bytes, ::size, [&]{ return ud(rnd) * 2 + 1; });

    // 1) std::accumulate: a single serial chain of multiplications.
    timepoint start1 = time_clk::now();
    const auto prod1 = accumulate(begin(bytes), end(bytes), 1u, multiplies<unsigned>{});
    timepoint end1 = time_clk::now();
    duration d1 = end1 - start1;
    cout << "accumulate: res = " << prod1 << ", time = " << d1.count() / ::size << " ns\n";

    // 2) Four independent accumulators so consecutive multiplies can overlap.
    timepoint start2 = time_clk::now();
    unsigned p1 = 1u, p2 = 1u, p3 = 1u, p4 = 1u;
    for (size_t i = 0; i < ::size; i += 4)
    {
        p1 *= bytes[i + 0];
        p2 *= bytes[i + 1];
        p3 *= bytes[i + 2];
        p4 *= bytes[i + 3];
    }
    const auto prod2 = p1 * p2 * p3 * p4;
    timepoint end2 = time_clk::now();
    duration d2 = end2 - start2;
    cout << " accum * 4: res = " << prod2 << ", time = " << d2.count() / ::size << " ns\n";

    // 3) AVX2: widen bytes to 16-bit lanes and multiply pairs of vectors.
    //    Each 32-bit lane of the result is then split into its two 16-bit
    //    products, which are folded into eight independent 32-bit accumulators.
    auto m128 = reinterpret_cast<__m128i*>(bytes);
    timepoint start3 = time_clk::now();
    __m256i pv1 = _mm256_set1_epi32(1);
    __m256i pv2 = _mm256_set1_epi32(1);
    __m256i pv3 = _mm256_set1_epi32(1);
    __m256i pv4 = _mm256_set1_epi32(1);
    __m256i pv5 = _mm256_set1_epi32(1);
    __m256i pv6 = _mm256_set1_epi32(1);
    __m256i pv7 = _mm256_set1_epi32(1);
    __m256i pv8 = _mm256_set1_epi32(1);
    for (size_t i = 0; i < ::size / 16; i += 8)
    {
        const __m256i t1 = _mm256_mullo_epi16(_mm256_cvtepu8_epi16(m128[i + 0]),
                                              _mm256_cvtepu8_epi16(m128[i + 1]));
        pv1 = _mm256_mullo_epi32(pv1, _mm256_and_si256(t1, _mm256_set1_epi32(0xFFFF)));
        pv2 = _mm256_mullo_epi32(pv2, _mm256_srli_epi32(t1, 16));
        const __m256i t2 = _mm256_mullo_epi16(_mm256_cvtepu8_epi16(m128[i + 2]),
                                              _mm256_cvtepu8_epi16(m128[i + 3]));
        pv3 = _mm256_mullo_epi32(pv3, _mm256_and_si256(t2, _mm256_set1_epi32(0xFFFF)));
        pv4 = _mm256_mullo_epi32(pv4, _mm256_srli_epi32(t2, 16));
        const __m256i t3 = _mm256_mullo_epi16(_mm256_cvtepu8_epi16(m128[i + 4]),
                                              _mm256_cvtepu8_epi16(m128[i + 5]));
        pv5 = _mm256_mullo_epi32(pv5, _mm256_and_si256(t3, _mm256_set1_epi32(0xFFFF)));
        pv6 = _mm256_mullo_epi32(pv6, _mm256_srli_epi32(t3, 16));
        const __m256i t4 = _mm256_mullo_epi16(_mm256_cvtepu8_epi16(m128[i + 6]),
                                              _mm256_cvtepu8_epi16(m128[i + 7]));
        pv7 = _mm256_mullo_epi32(pv7, _mm256_and_si256(t4, _mm256_set1_epi32(0xFFFF)));
        pv8 = _mm256_mullo_epi32(pv8, _mm256_srli_epi32(t4, 16));
    }
    pv1 = _mm256_mullo_epi32(pv1, pv2);
    pv3 = _mm256_mullo_epi32(pv3, pv4);
    pv5 = _mm256_mullo_epi32(pv5, pv6);
    pv7 = _mm256_mullo_epi32(pv7, pv8);
    pv1 = _mm256_mullo_epi32(pv1, pv3);
    pv5 = _mm256_mullo_epi32(pv5, pv7);
    pv1 = _mm256_mullo_epi32(pv1, pv5);
    __m128i hi = _mm256_extracti128_si256(pv1, 1);
    __m128i lo = _mm256_extracti128_si256(pv1, 0);
    lo = _mm_mullo_epi32(hi, lo);
    // _mm_extract_epi32 returns int. Convert each lane before multiplying,
    // because the product overflows and signed overflow is undefined behaviour.
    const auto prod3 = static_cast<unsigned>(_mm_extract_epi32(lo, 0)) *
                       static_cast<unsigned>(_mm_extract_epi32(lo, 1)) *
                       static_cast<unsigned>(_mm_extract_epi32(lo, 2)) *
                       static_cast<unsigned>(_mm_extract_epi32(lo, 3));
    timepoint end3 = time_clk::now();
    duration d3 = end3 - start3;
    cout << "      AVX2: res = " << prod3 << ", time = " << d3.count() / ::size << " ns\n";

    // 4) Count the 3s, 5s and 7s, then raise each base to its count with bexp.
    //    Two qwords are merged so that each of the 16 nibbles holds one byte
    //    value (relies on bytes < 8).
    auto qwords = reinterpret_cast<uint64_t*>(bytes);
    timepoint start4 = time_clk::now();
    unsigned acc1 = 1; // product of 16-byte groups that contain an even byte
    unsigned acc3 = 0;
    unsigned acc5 = 0;
    unsigned acc7 = 0;
    for (size_t i = 0; i < ::size / 8; i += 2)
    {
        auto compr = (qwords[i] << 4) | qwords[i + 1];
        constexpr uint64_t lsb = 0x1111111111111111;
        if ((compr & lsb) != lsb) // if there is at least one even value
        {
            // Slow path: multiply these 16 bytes in directly. accumulate()
            // already starts from acc1, so assign instead of `*=`, which
            // would square the running product.
            auto b = reinterpret_cast<uint8_t*>(qwords + i);
            acc1 = accumulate(b, b + 16, acc1, multiplies<unsigned>{});
            if (!acc1)
                break; // the whole product is now 0
        }
        else
        {
            // Each nibble is 1, 3, 5 or 7: bit 1 marks 3 or 7, bit 2 marks 5 or 7.
            const auto b2 = compr & 0x2222222222222222;
            const auto b4 = compr & 0x4444444444444444;
            const auto b24 = b4 & (b2 * 2); // nibbles with both bits set, i.e. 7
            const unsigned c7 = __builtin_popcountll(b24);
            acc3 += __builtin_popcountll(b2) - c7;
            acc5 += __builtin_popcountll(b4) - c7;
            acc7 += c7;
        }
    }
    const auto prod4 = acc1 * bexp<3>(acc3) * bexp<5>(acc5) * bexp<7>(acc7);
    timepoint end4 = time_clk::now();
    duration d4 = end4 - start4;
    cout << "binary exp: res = " << prod4 << ", time = " << d4.count() / ::size << " ns\n";

    // 5) Two byte products per 32-bit multiply, and two of those per 64-bit
    //    multiply. With bytes < 8, the kept lanes are never polluted by the
    //    cross terms.
    auto words = reinterpret_cast<uint16_t*>(bytes);
    timepoint start5 = time_clk::now();
    auto prod5 = 1u;
    for (size_t i = 0; i < ::size / 2; i += 4)
    {
        const auto t1 = uint32_t(words[i + 0]) * words[i + 1];
        const auto t2 = uint32_t(words[i + 2]) * words[i + 3];
        const auto t3 = uint64_t(t1 & 0xFF00FF) * (t2 & 0xFF00FF);
        prod5 *= uint32_t(t3 >> 32) * uint32_t(t3 & 0xFFFF);
    }
    timepoint end5 = time_clk::now();
    duration d5 = end5 - start5;
    cout << "2 mul in 1: res = " << prod5 << ", time = " << d5.count() / ::size << " ns\n";

    // 6) For odd bytes < 8, bits 1-2 fully determine the value. _pext_u64
    //    therefore turns a qword of 8 bytes into a 16-bit index into the
    //    product table built by fillMtable().
    fillMtable();
    timepoint start6 = time_clk::now();
    unsigned a0 = 1; // product of 16-byte groups that contain an even byte
    unsigned a1 = 1;
    unsigned a2 = 1;
    for (size_t i = 0; i < ::size / 8; i += 2)
    {
        constexpr uint64_t lsb = 0x0101010101010101;
        if ((qwords[i] & lsb) != lsb || (qwords[i + 1] & lsb) != lsb)
        { // if there is at least one even value
            // Slow path with its own accumulator, folded into prod6 below.
            auto b = reinterpret_cast<uint8_t*>(qwords + i);
            a0 = accumulate(b, b + 16, a0, multiplies<unsigned>{});
            if (!a0)
                break; // the whole product is now 0
        }
        else
        {
            a1 *= mtable[_pext_u64(qwords[i + 0], 0x0606060606060606)];
            a2 *= mtable[_pext_u64(qwords[i + 1], 0x0606060606060606)];
        }
    }
    const auto prod6 = a0 * a1 * a2;
    timepoint end6 = time_clk::now();
    duration d6 = end6 - start6;
    cout << "mult table: res = " << prod6 << ", time = " << d6.count() / ::size << " ns\n";

    return 0;
}
