#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include <emmintrin.h>
#include <math.h>

/*
 * State for four independent MD5 computations run side by side, one per
 * 32-bit lane of each __m128i.
 */
struct md5x4_context {
    /* Round constants K[0..63], each broadcast to all lanes by md5x4_pre_init. */
    __m128i k[64];
    /* Initial chaining values A, B, C, D, broadcast to all lanes. */
    __m128i iv[4];
    /* All-ones mask, set by md5x4_pre_init. */
    __m128i ff;

    /* Running chaining state; lane j belongs to message j. */
    __m128i a;
    __m128i b;
    __m128i c;
    __m128i d;
};

/*
 * Fills the constant part of the context: K[i] = floor(|sin(i + 1)| * 2^32)
 * for i = 0..63, the four standard MD5 initial words, and an all-ones mask.
 * Every value is broadcast to all four lanes.
 */
void md5x4_pre_init(struct md5x4_context* ctx) {
    static const uint32_t initial_words[4] = {
        0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476,
    };
    const double scale = (double)0x100000000ull; /* 2^32 */

    for (uint32_t n = 1; n <= 64; ++n) {
        uint32_t constant = floor(fabs(sin(n)) * scale);
        ctx->k[n - 1] = _mm_set1_epi32(constant);
    }

    for (int w = 0; w < 4; ++w) {
        ctx->iv[w] = _mm_set1_epi32(initial_words[w]);
    }

    ctx->ff = _mm_set1_epi32(0xFFFFFFFF);
}

/*
 * Resets the running state of all four lanes to the initial chaining values
 * stored in ctx->iv (requires md5x4_pre_init to have run).
 */
void md5x4_init(struct md5x4_context* ctx) {
    const __m128i *start = ctx->iv;

    ctx->d = start[3];
    ctx->c = start[2];
    ctx->b = start[1];
    ctx->a = start[0];
}

/*
 * Shared MD5 step, applied lane-wise: returns b + rotl(a + f + k + x, s),
 * where f is the current round's boolean function of (b, c, d) and s is the
 * rotate amount (expected in 1..31).
 *
 * All helpers here are `static inline`. A plain C99/C11 `inline` definition
 * provides no external definition, so any call the compiler decides not to
 * inline (e.g. at -O0) would fail to link.
 */
static inline __m128i md5x4_fx(int s, __m128i a, __m128i b, __m128i k, __m128i x, __m128i f) {
    f = _mm_add_epi32(_mm_add_epi32(f, a), _mm_add_epi32(k, x));
    /* s is a runtime argument, so use the shift-by-register forms instead of
     * the immediate-count _mm_slli_epi32/_mm_srli_epi32. */
    __m128i left = _mm_sll_epi32(f, _mm_cvtsi32_si128(s));
    __m128i right = _mm_srl_epi32(f, _mm_cvtsi32_si128(32 - s));
    return _mm_add_epi32(b, _mm_or_si128(left, right));
}

/* Round 1 step, F(b, c, d) = (b & c) | (~b & d). */
static inline __m128i md5x4_f1(int s, __m128i a, __m128i b, __m128i c, __m128i d, __m128i k, __m128i x) {
    return md5x4_fx(s, a, b, k, x, _mm_or_si128(_mm_and_si128(b, c), _mm_andnot_si128(b, d)));
}

/* Round 2 step, G(b, c, d) = (b & d) | (c & ~d). */
static inline __m128i md5x4_f2(int s, __m128i a, __m128i b, __m128i c, __m128i d, __m128i k, __m128i x) {
    return md5x4_fx(s, a, b, k, x, _mm_or_si128(_mm_and_si128(d, b), _mm_andnot_si128(d, c)));
}

/* Round 3 step, H(b, c, d) = b ^ c ^ d. */
static inline __m128i md5x4_f3(int s, __m128i a, __m128i b, __m128i c, __m128i d, __m128i k, __m128i x) {
    return md5x4_fx(s, a, b, k, x, _mm_xor_si128(_mm_xor_si128(b, c), d));
}

/* Round 4 step, I(b, c, d) = c ^ (b | ~d); ~d is computed as d ^ all-ones. */
static inline __m128i md5x4_f4(int s, __m128i a, __m128i b, __m128i c, __m128i d, __m128i k, __m128i x) {
    const __m128i all_ones = _mm_set1_epi32(-1);
    return md5x4_fx(s, a, b, k, x, _mm_xor_si128(c, _mm_or_si128(b, _mm_xor_si128(d, all_ones))));
}

/*
 * Runs the MD5 compression function once on four independent 64-byte blocks,
 * one per 32-bit lane: blocks[j] updates lane j of ctx->a..ctx->d.
 *
 * No padding is applied here; each block must already be a complete MD5
 * block. ctx->k must have been filled by md5x4_pre_init and ctx->a..ctx->d
 * set (e.g. by md5x4_init or a previous call).
 *
 * Message words are loaded in host byte order; MD5 defines them as
 * little-endian, which matches x86 (the only target for these intrinsics).
 */
void md5x4_raw_update(struct md5x4_context* ctx, const uint8_t blocks[4][64])
{
    /* x[w] holds message word w of all four blocks (lane j comes from blocks[j]). */
    __m128i x[16];
    /* Transpose in 4x4 tiles: iteration i reads words 4i..4i+3 of every block
     * and rearranges them so each output vector holds one word index across
     * the four blocks. The float-typed shuffles only move bits; no
     * floating-point arithmetic is performed.
     * NOTE(review): casting uint8_t* to const float* may form a pointer that is
     * not 4-byte aligned; _mm_loadu_ps tolerates this on x86, but the cast is
     * formally questionable in ISO C. */
    for (size_t i = 0; i < 4; ++i) {
        __m128 x0 = _mm_loadu_ps((const float*)blocks[0] + i * 4);
        __m128 x1 = _mm_loadu_ps((const float*)blocks[1] + i * 4);
        __m128 x2 = _mm_loadu_ps((const float*)blocks[2] + i * 4);
        __m128 x3 = _mm_loadu_ps((const float*)blocks[3] + i * 4);

        __m128 t0 = _mm_unpacklo_ps(x0, x1);
        __m128 t1 = _mm_unpackhi_ps(x0, x1);
        __m128 t2 = _mm_unpacklo_ps(x2, x3);
        __m128 t3 = _mm_unpackhi_ps(x2, x3);

        x[i * 4 + 0] = _mm_castps_si128(_mm_movelh_ps(t0, t2));
        x[i * 4 + 1] = _mm_castps_si128(_mm_movehl_ps(t2, t0));
        x[i * 4 + 2] = _mm_castps_si128(_mm_movelh_ps(t1, t3));
        x[i * 4 + 3] = _mm_castps_si128(_mm_movehl_ps(t3, t1));
    }

    /* Working copy of the chaining state for this block. */
    __m128i a = ctx->a;
    __m128i b = ctx->b;
    __m128i c = ctx->c;
    __m128i d = ctx->d;

    /* Round 1 (F): shifts 7/12/17/22, message words 0..15 in order. */
    a = md5x4_f1(7, a, b, c, d, ctx->k[0], x[0]);
    d = md5x4_f1(12, d, a, b, c, ctx->k[1], x[1]);
    c = md5x4_f1(17, c, d, a, b, ctx->k[2], x[2]);
    b = md5x4_f1(22, b, c, d, a, ctx->k[3], x[3]);
    a = md5x4_f1(7, a, b, c, d, ctx->k[4], x[4]);
    d = md5x4_f1(12, d, a, b, c, ctx->k[5], x[5]);
    c = md5x4_f1(17, c, d, a, b, ctx->k[6], x[6]);
    b = md5x4_f1(22, b, c, d, a, ctx->k[7], x[7]);
    a = md5x4_f1(7, a, b, c, d, ctx->k[8], x[8]);
    d = md5x4_f1(12, d, a, b, c, ctx->k[9], x[9]);
    c = md5x4_f1(17, c, d, a, b, ctx->k[10], x[10]);
    b = md5x4_f1(22, b, c, d, a, ctx->k[11], x[11]);
    a = md5x4_f1(7, a, b, c, d, ctx->k[12], x[12]);
    d = md5x4_f1(12, d, a, b, c, ctx->k[13], x[13]);
    c = md5x4_f1(17, c, d, a, b, ctx->k[14], x[14]);
    b = md5x4_f1(22, b, c, d, a, ctx->k[15], x[15]);

    /* Round 2 (G): shifts 5/9/14/20, message word index (1 + 5*i) mod 16. */
    a = md5x4_f2(5, a, b, c, d, ctx->k[16], x[1]);
    d = md5x4_f2(9, d, a, b, c, ctx->k[17], x[6]);
    c = md5x4_f2(14, c, d, a, b, ctx->k[18], x[11]);
    b = md5x4_f2(20, b, c, d, a, ctx->k[19], x[0]);
    a = md5x4_f2(5, a, b, c, d, ctx->k[20], x[5]);
    d = md5x4_f2(9, d, a, b, c, ctx->k[21], x[10]);
    c = md5x4_f2(14, c, d, a, b, ctx->k[22], x[15]);
    b = md5x4_f2(20, b, c, d, a, ctx->k[23], x[4]);
    a = md5x4_f2(5, a, b, c, d, ctx->k[24], x[9]);
    d = md5x4_f2(9, d, a, b, c, ctx->k[25], x[14]);
    c = md5x4_f2(14, c, d, a, b, ctx->k[26], x[3]);
    b = md5x4_f2(20, b, c, d, a, ctx->k[27], x[8]);
    a = md5x4_f2(5, a, b, c, d, ctx->k[28], x[13]);
    d = md5x4_f2(9, d, a, b, c, ctx->k[29], x[2]);
    c = md5x4_f2(14, c, d, a, b, ctx->k[30], x[7]);
    b = md5x4_f2(20, b, c, d, a, ctx->k[31], x[12]);

    /* Round 3 (H): shifts 4/11/16/23, message word index (5 + 3*i) mod 16. */
    a = md5x4_f3(4, a, b, c, d, ctx->k[32], x[5]);
    d = md5x4_f3(11, d, a, b, c, ctx->k[33], x[8]);
    c = md5x4_f3(16, c, d, a, b, ctx->k[34], x[11]);
    b = md5x4_f3(23, b, c, d, a, ctx->k[35], x[14]);
    a = md5x4_f3(4, a, b, c, d, ctx->k[36], x[1]);
    d = md5x4_f3(11, d, a, b, c, ctx->k[37], x[4]);
    c = md5x4_f3(16, c, d, a, b, ctx->k[38], x[7]);
    b = md5x4_f3(23, b, c, d, a, ctx->k[39], x[10]);
    a = md5x4_f3(4, a, b, c, d, ctx->k[40], x[13]);
    d = md5x4_f3(11, d, a, b, c, ctx->k[41], x[0]);
    c = md5x4_f3(16, c, d, a, b, ctx->k[42], x[3]);
    b = md5x4_f3(23, b, c, d, a, ctx->k[43], x[6]);
    a = md5x4_f3(4, a, b, c, d, ctx->k[44], x[9]);
    d = md5x4_f3(11, d, a, b, c, ctx->k[45], x[12]);
    c = md5x4_f3(16, c, d, a, b, ctx->k[46], x[15]);
    b = md5x4_f3(23, b, c, d, a, ctx->k[47], x[2]);

    /* Round 4 (I): shifts 6/10/15/21, message word index (7*i) mod 16. */
    a = md5x4_f4(6, a, b, c, d, ctx->k[48], x[0]);
    d = md5x4_f4(10, d, a, b, c, ctx->k[49], x[7]);
    c = md5x4_f4(15, c, d, a, b, ctx->k[50], x[14]);
    b = md5x4_f4(21, b, c, d, a, ctx->k[51], x[5]);
    a = md5x4_f4(6, a, b, c, d, ctx->k[52], x[12]);
    d = md5x4_f4(10, d, a, b, c, ctx->k[53], x[3]);
    c = md5x4_f4(15, c, d, a, b, ctx->k[54], x[10]);
    b = md5x4_f4(21, b, c, d, a, ctx->k[55], x[1]);
    a = md5x4_f4(6, a, b, c, d, ctx->k[56], x[8]);
    d = md5x4_f4(10, d, a, b, c, ctx->k[57], x[15]);
    c = md5x4_f4(15, c, d, a, b, ctx->k[58], x[6]);
    b = md5x4_f4(21, b, c, d, a, ctx->k[59], x[13]);
    a = md5x4_f4(6, a, b, c, d, ctx->k[60], x[4]);
    d = md5x4_f4(10, d, a, b, c, ctx->k[61], x[11]);
    c = md5x4_f4(15, c, d, a, b, ctx->k[62], x[2]);
    b = md5x4_f4(21, b, c, d, a, ctx->k[63], x[9]);

    /* Feed-forward: add this block's result into the chaining state. */
    ctx->a = _mm_add_epi32(ctx->a, a);
    ctx->b = _mm_add_epi32(ctx->b, b);
    ctx->c = _mm_add_epi32(ctx->c, c);
    ctx->d = _mm_add_epi32(ctx->d, d);
}

/*
 * Writes the four 16-byte digests: out[j] receives words a, b, c, d of lane j,
 * each stored in host byte order (little-endian on x86). ctx is only read.
 */
void md5x4_final(struct md5x4_context* ctx, uint8_t out[4][16]) {
    /* Interleave 32-bit words: ab_lo = {a0, b0, a1, b1}, ab_hi = {a2, b2, a3, b3}. */
    const __m128i ab_lo = _mm_unpacklo_epi32(ctx->a, ctx->b);
    const __m128i ab_hi = _mm_unpackhi_epi32(ctx->a, ctx->b);
    const __m128i cd_lo = _mm_unpacklo_epi32(ctx->c, ctx->d);
    const __m128i cd_hi = _mm_unpackhi_epi32(ctx->c, ctx->d);

    /* Join 64-bit halves so that digest[j] = {a_j, b_j, c_j, d_j}. */
    const __m128i digest[4] = {
        _mm_unpacklo_epi64(ab_lo, cd_lo),
        _mm_unpackhi_epi64(ab_lo, cd_lo),
        _mm_unpacklo_epi64(ab_hi, cd_hi),
        _mm_unpackhi_epi64(ab_hi, cd_hi),
    };

    for (int lane = 0; lane < 4; ++lane) {
        _mm_storeu_si128((__m128i *)out[lane], digest[lane]);
    }
}


/*
 * Appends MD5 padding for a message of `length` bytes that fits in a single
 * 64-byte block. It writes the 0x80 marker at block[length] and stores the
 * message length in bits as a 64-bit little-endian integer in bytes 56..63.
 *
 * "unsafe": there is no bounds checking. length must be <= 55, and bytes
 * length+1..55 are left untouched, so the caller must have zeroed them.
 *
 * The length is written byte by byte rather than through a uint64_t* cast.
 * That cast would break strict aliasing, could be misaligned, and would only
 * produce MD5's little-endian encoding on little-endian hosts.
 */
void unsafe_pad_block(uint8_t block[64], size_t length) {
    const uint64_t bit_length = (uint64_t)length * 8u;

    block[length] = 0x80;
    for (size_t i = 0; i < 8; ++i) {
        block[56 + i] = (uint8_t)(bit_length >> (8 * i));
    }
}

/*
 * Demo: hashes four short messages in parallel and prints one hex digest per
 * line. Each message fits in a single 64-byte block, so it is padded in place
 * and compressed with exactly one md5x4_raw_update call.
 */
int main(void) {
    struct md5x4_context ctx;
    md5x4_pre_init(&ctx);

    /* Bytes after each string are zero-initialised, as MD5 padding requires.
     * Every message must be at most 55 bytes to fit in one block. */
    uint8_t data[4][64] = {
        "We are the others",
        "We are the cast-outs",
        "We're the outsiders",
        "But you can't hide us",
    };

    /* Derive lengths from the text rather than hard-coding them, so editing a
     * message cannot silently desynchronise its padding. */
    for (size_t i = 0; i < sizeof data / sizeof data[0]; ++i) {
        unsafe_pad_block(data[i], strlen((const char *)data[i]));
    }

    uint8_t digests[4][16] = {0};

    md5x4_init(&ctx);
    /* C11 does not implicitly convert uint8_t (*)[64] to const uint8_t (*)[64]. */
    md5x4_raw_update(&ctx, (const uint8_t (*)[64])data);
    md5x4_final(&ctx, digests);

    for (size_t i = 0; i < 4; ++i) {
        for (size_t j = 0; j < 16; ++j) {
            printf("%02X ", (unsigned)digests[i][j]);
        }
        printf("\n");
    }

    return 0;
}