#ifdef _OPENMP
#include <omp.h>
#endif
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>

/* Number of ints to generate and sort. Override at build time with
   -DELEM_COUNT=N; the default is 2^29. Non-positive overrides stop the
   build with an #error. */
#if !defined(ELEM_COUNT)
#  define ELEM_COUNT (1 << 29)
#elif ELEM_COUNT <= 0
#  error Don't you think you're smarter than me?
#endif

/* Return the next pseudo-random value for the input array, taken from
   rand(). When MAX_ELEMENT is defined at build time the value is reduced
   modulo MAX_ELEMENT; otherwise the raw rand() result is returned. */
int next_element() {
    int value = rand();
#ifdef MAX_ELEMENT
    /* Spelled out (not %=) so MAX_ELEMENT expands exactly as in
       `rand() % MAX_ELEMENT`. */
    value = value % MAX_ELEMENT;
#endif
    return value;
}

/* Print the first ELEM_COUNT elements of arr as "[ a, b, c ]" followed by a
   newline. arr must point to at least ELEM_COUNT ints. ELEM_COUNT is
   rejected at compile time when non-positive, so arr[0] always exists.
   The array is only read, hence the const-qualified parameter. */
void print_array(const int* arr) {
    printf("[ %d", arr[0]);
    /* size_t index (as main sizes its buffers) so large ELEM_COUNT
       overrides cannot overflow a plain int counter. */
    for (size_t i = 1; i < (size_t)ELEM_COUNT; ++i) {
        printf(", %d", arr[i]);
    }
    printf(" ]\n");
}

/*
 * Merge the two adjacent sorted runs src[lo, mid) and src[mid, hi) into
 * dst[lo, hi). Works purely on indices, so no pointer past the end of
 * either buffer is ever formed.
 */
static void merge_runs(const int* src, int* dst, ptrdiff_t lo, ptrdiff_t mid, ptrdiff_t hi) {
    ptrdiff_t a = lo;
    ptrdiff_t b = mid;
    ptrdiff_t out = lo;
    while (a < mid && b < hi) {
        /* Taking from the left run on ties keeps the merge stable. */
        if (src[a] <= src[b]) {
            dst[out++] = src[a++];
        } else {
            dst[out++] = src[b++];
        }
    }
    /* One run is exhausted; copy whatever remains of the other. */
    while (a < mid) {
        dst[out++] = src[a++];
    }
    while (b < hi) {
        dst[out++] = src[b++];
    }
}

/*
 * One bottom-up merge pass over src[0, n): each pair of adjacent runs of
 * length `width` is merged into the same positions of dst. The final run may
 * be shorter than `width` or have no partner; its bounds are clamped to n
 * (a partnerless run is simply copied). Pairs touch disjoint ranges, so the
 * loop is shared across OpenMP threads when OpenMP is enabled.
 */
static void merge_pass(const int* src, int* dst, ptrdiff_t n, ptrdiff_t width) {
    /* Iterate over pair numbers instead of element offsets so the loop
       counter never steps past n (no signed overflow near the top of the
       index range). Assumes n >= 1 and 1 <= width < n. */
    const ptrdiff_t pairs = (n - 1) / (2 * width) + 1;
#pragma omp parallel for
    for (ptrdiff_t p = 0; p < pairs; ++p) {
        ptrdiff_t lo = p * 2 * width;
        /* Compare the remaining length before adding, so bounds never
           exceed n. */
        ptrdiff_t mid = (n - lo > width) ? lo + width : n;
        ptrdiff_t hi = (n - mid > width) ? mid + width : n;
        merge_runs(src, dst, lo, mid, hi);
    }
}

/*
 * Fill an array of ELEM_COUNT pseudo-random ints, sort it with a bottom-up
 * merge sort (parallelised per pass with OpenMP when available), and report
 * the wall-clock time spent sorting.
 */
int main(void) {
    const ptrdiff_t n = (ptrdiff_t)ELEM_COUNT;
    int* from = malloc((size_t)n * sizeof *from);
    int* to = malloc((size_t)n * sizeof *to);
    if (!from || !to) {
        fprintf(stderr, "Buy and install more RAM, looser!\n");
        /* One of the two allocations may have succeeded; free(NULL) is a no-op. */
        free(from);
        free(to);
        return EXIT_FAILURE;
    }
    for (ptrdiff_t i = 0; i < n; ++i) {
        from[i] = next_element();
    }
#ifdef DEBUG
    printf("Going to sort array:\n");
    print_array(from);
#else
    printf("Array prepared\n");
#endif
    struct timeval start, end;
    gettimeofday(&start, NULL);
    /* Merge runs of width 1, 2, 4, ... pairwise, ping-ponging between the
       two buffers, until a single run spans the whole array. */
    for (ptrdiff_t width = 1; width < n; width *= 2) {
        merge_pass(from, to, n, width);
        int* t = from;
        from = to;
        to = t;
    }
    gettimeofday(&end, NULL);
#ifdef DEBUG
    printf("Array sorted:\n");
    print_array(from);
#endif
    printf("Time taken: %fs\n",
        ((double)(end.tv_sec - start.tv_sec) + (double)(end.tv_usec - start.tv_usec) / 1000000.0));
#ifdef _OPENMP
    printf("Threads: %d\n", omp_get_max_threads());
#endif
    /* After the final swap `from` holds the sorted data; release both buffers. */
    free(from);
    free(to);
    return EXIT_SUCCESS;
}