#include <stdio.h>
#include <stdlib.h>

// Local fixed-width integer aliases used for the BMP fields and pixel bytes.
// NOTE(review): these presume an 8-bit char, 16-bit short and 32-bit int on
// the target platform — confirm before porting.
typedef unsigned char uint8_t;   // one raw byte
typedef unsigned short uint16_t; // 16-bit header fields
typedef unsigned int uint32_t;   // 32-bit header fields and sizes

// Define structures
// In-memory copy of the 14-byte BMP file header.
// The signature and the three integers are transferred by separate
// fread/fwrite calls in this file (2 bytes, then 12 bytes starting at `size`),
// so the three integer fields must stay adjacent and in this order.
typedef struct {
    char name[2];           // signature bytes; the loader accepts only 'B','M'
    uint32_t size;          // file size as recorded in the header
    uint32_t reserved;      // carried through unchanged; not interpreted here
    uint32_t image_offset;  // byte offset from the start of the file to the pixel data
} BMP_HEADER;

// In-memory copy of the 40-byte DIB (bitmap info) header.
// The struct is read from and written to the file as one raw 40-byte block,
// so the field order and widths below must not change.
typedef struct {
    uint32_t size;                  // header length; the loader requires 40
    uint32_t width;                 // image width in pixels
    uint32_t height;                // image height in pixels (number of rows)
    uint16_t number_color_plane;    // not interpreted by the code in this file
    uint16_t number_bit_per_pixel;  // bit depth; the pixel loops handle 24 and 32
    uint32_t compression_method;    // the loader requires 0
    uint32_t image_size;            // pixel-data size field; 0 means "derive from width/height"
    uint32_t horizontal_resolution; // not interpreted by the code in this file
    uint32_t vertical_resolution;   // not interpreted by the code in this file
    uint32_t number_color;          // not interpreted by the code in this file
    uint32_t ignored;               // not interpreted by the code in this file
} DIB_HEADER;

// One decoded pixel, one byte per field.
// For 32-bit images the fields y, r, g, b hold the pixel's four bytes in
// ascending significance of the 32-bit value (y = lowest byte); for 24-bit
// images r, g, b hold the three bytes and y is not taken from the file.
// NOTE(review): BMP files normally store channels as blue, green, red(, alpha),
// so the field names may not match the actual colour channels — confirm.
typedef struct {
    uint8_t y;  // extra byte of a 32-bit pixel (lowest byte of the value)
    uint8_t r;  // lowest byte of a 24-bit pixel, second byte of a 32-bit pixel
    uint8_t g;  // next byte
    uint8_t b;  // highest byte
} PIXEL;

// A complete bitmap held in memory: both headers plus the decoded pixels.
typedef struct {
    BMP_HEADER header;  // file header
    DIB_HEADER dib;     // bitmap info header (dimensions, bit depth, ...)
    PIXEL *image;       // malloc'ed pixel array indexed as row * dib.width + column;
                        // rows are kept in the order they appear in the file
} BMP_FILE;

// Function prototypes
// Load the BMP file `file_name` into `*input_file`.
// Returns 0 on success; a non-zero value indicates failure.
int openbmpfile(const char *file_name, BMP_FILE *input_file);

// Write the bitmap described by `*input_file` to the file `file_name`.
// Returns 0 on success; a non-zero value indicates failure.
int createbmpfile(const char *file_name, BMP_FILE *input_file);

// Build in `*output` a copy of `*bmp` resized by the factor `times`
// (both dimensions) using bilinear interpolation. The pixel array stored in
// `output->image` is heap-allocated and must be released with free().
// Returns 0 on success; a non-zero value indicates failure.
int image_scale(BMP_FILE *bmp, BMP_FILE *output, float times);

// Load an uncompressed 24- or 32-bit BMP file into *input_file.
//
// file_name:  path of the file to read.
// input_file: receives the file header, the DIB header and a malloc'ed pixel
//             array of dib.width * dib.height entries (row-major, rows in file
//             order). The caller releases input_file->image with free().
//
// Returns 0 on success and 1 on any failure (file cannot be opened, truncated
// or malformed file, unsupported format, out of memory). On failure
// input_file->image is NULL, so it is always safe to pass to free().
//
// NOTE(review): header integers are read in host byte order, which matches the
// little-endian on-disk format only on little-endian hosts — confirm if porting.
int openbmpfile(const char* file_name, BMP_FILE* input_file) {
    FILE* bmp_file;
    uint8_t* row = NULL;
    uint32_t width, height, bytes_per_pixel, row_index, col;
    size_t row_bytes, padding_bytes, pixel_count;
    int status = 1;

    // Give the caller a well-defined pointer even when loading fails.
    input_file->image = NULL;

    bmp_file = fopen(file_name, "rb");
    if (bmp_file == NULL) return 1;

    // Read the file header: 2 signature bytes, then the three 32-bit fields.
    if (fread(input_file->header.name, 2, 1, bmp_file) != 1) goto cleanup;
    if (fread(&(input_file->header.size), 12, 1, bmp_file) != 1) goto cleanup;
    if ((input_file->header.name[0] != 'B') || (input_file->header.name[1] != 'M')) goto cleanup;

    // Read the DIB header; only the 40-byte uncompressed variant is supported.
    if (fread(&(input_file->dib), 40, 1, bmp_file) != 1) goto cleanup;
    if ((input_file->dib.size != 40) || (input_file->dib.compression_method != 0)) goto cleanup;

    // Other bit depths would leave the pixel array undecoded, so reject them.
    if ((input_file->dib.number_bit_per_pixel != 24) &&
        (input_file->dib.number_bit_per_pixel != 32)) goto cleanup;
    bytes_per_pixel = input_file->dib.number_bit_per_pixel / 8;

    width = input_file->dib.width;
    height = input_file->dib.height;
    // Empty images are rejected. Values with the top bit set are negative in
    // the BMP format (e.g. top-down bitmaps), which the unsigned fields of
    // DIB_HEADER cannot represent.
    if ((width == 0) || (height == 0)) goto cleanup;
    if ((width > 0x7FFFFFFFu) || (height > 0x7FFFFFFFu)) goto cleanup;
    // Make sure width * height * sizeof(PIXEL) fits in size_t.
    if (height > ((size_t)-1 / sizeof(PIXEL)) / width) goto cleanup;

    // The pixel count comes from the dimensions; dib.image_size is a byte
    // count that includes row padding and is not needed here.
    pixel_count = (size_t)width * height;
    row_bytes = (size_t)width * bytes_per_pixel;
    // Each stored row is padded to a multiple of 4 bytes.
    padding_bytes = (4 - (row_bytes % 4)) % 4;

    // Go to the initial position of image data
    if (fseek(bmp_file, (long)input_file->header.image_offset, SEEK_SET) != 0) goto cleanup;

    input_file->image = malloc(pixel_count * sizeof *input_file->image);
    row = malloc(row_bytes);
    if ((input_file->image == NULL) || (row == NULL)) goto cleanup;

    // Read image data one row at a time, skipping the padding after each row.
    for (row_index = 0; row_index < height; row_index++) {
        PIXEL* out = input_file->image + (size_t)row_index * width;

        if (fread(row, 1, row_bytes, bmp_file) != row_bytes) goto cleanup;
        for (col = 0; col < width; col++) {
            const uint8_t* src = row + (size_t)col * bytes_per_pixel;

            if (bytes_per_pixel == 3) {
                out[col].y = 0;  // no fourth byte in 24-bit pixels
                out[col].r = src[0];
                out[col].g = src[1];
                out[col].b = src[2];
            } else {
                out[col].y = src[0];
                out[col].r = src[1];
                out[col].g = src[2];
                out[col].b = src[3];
            }
        }
        if ((padding_bytes != 0) &&
            (fseek(bmp_file, (long)padding_bytes, SEEK_CUR) != 0)) goto cleanup;
    }
    status = 0;

cleanup:
    free(row);
    if (status != 0) {
        free(input_file->image);
        input_file->image = NULL;
    }
    fclose(bmp_file);
    return status;
}

// Write the bitmap in *input_file to `file_name` as an uncompressed BMP.
//
// file_name:  path of the file to create (an existing file is overwritten).
// input_file: bitmap to write; dib.width, dib.height, dib.number_bit_per_pixel
//             (24 or 32) and header.image_offset describe the layout, and
//             image must hold dib.width * dib.height pixels. *input_file is
//             not modified.
//
// The size fields stored in the file (header.size, dib.image_size) are
// recomputed from the dimensions so the output is self-consistent even when
// the values in *input_file are stale. Progress is logged to stdout.
//
// Returns 0 on success and 1 on failure (invalid bitmap description, file
// cannot be created, out of memory, or a write error).
int createbmpfile(const char* file_name, BMP_FILE* input_file) {
    static const uint8_t zero_padding[3] = {0, 0, 0};
    FILE* file;
    BMP_HEADER header;
    DIB_HEADER dib;
    uint8_t* row = NULL;
    uint32_t width, height, bytes_per_pixel, row_index, col;
    size_t row_bytes, padding_bytes, stride;
    int status = 1;

    if ((input_file->dib.number_bit_per_pixel != 24) &&
        (input_file->dib.number_bit_per_pixel != 32)) {
        printf("[log] [error] bit per pixel error\n");
        return 1;
    }
    bytes_per_pixel = input_file->dib.number_bit_per_pixel / 8;

    width = input_file->dib.width;
    height = input_file->dib.height;
    if ((input_file->image == NULL) || (width == 0) || (height == 0)) {
        printf("[log] [error] empty image\n");
        return 1;
    }
    // The pixel data must start after the 14-byte file header and the 40-byte
    // DIB header, otherwise it would overwrite them.
    if (input_file->header.image_offset < 54u) {
        printf("[log] [error] image offset error\n");
        return 1;
    }

    // Calculate the number of bytes per row and the padding needed to align
    // every row to 4 bytes; all sizes must fit the 32-bit header fields.
    if (width > (0xFFFFFFFFu - 3u) / bytes_per_pixel) return 1;
    row_bytes = (size_t)width * bytes_per_pixel;
    padding_bytes = (4 - (row_bytes % 4)) % 4;
    stride = row_bytes + padding_bytes;
    if (height > (0xFFFFFFFFu - input_file->header.image_offset) / stride) return 1;

    // Headers as they will appear in the file, with consistent size fields.
    header = input_file->header;
    dib = input_file->dib;
    dib.image_size = (uint32_t)(stride * height);
    header.size = header.image_offset + dib.image_size;

    file = fopen(file_name, "wb");
    if (file == NULL) return 1;
    printf("[log] start write file\n");

    row = malloc(row_bytes);
    if (row == NULL) goto cleanup;

    // Write header and dib of bmp file
    if (fwrite(header.name, 2, 1, file) != 1) goto cleanup;
    if (fwrite(&(header.size), 12, 1, file) != 1) goto cleanup;
    if (fwrite(&dib, 40, 1, file) != 1) goto cleanup;

    // Go to the position of image data
    if (fseek(file, (long)header.image_offset, SEEK_SET) != 0) goto cleanup;

    printf("[log] size of image: %u\n", (unsigned int)(width * height));

    // Write image data row by row; the padding follows the last pixel of
    // each row.
    for (row_index = 0; row_index < height; row_index++) {
        const PIXEL* src = input_file->image + (size_t)row_index * width;
        uint8_t* dst = row;

        for (col = 0; col < width; col++) {
            // Bytes are emitted individually (y first for 32-bit pixels, then
            // r, g, b), which avoids shifting into the sign bit of an int.
            if (bytes_per_pixel == 4) *dst++ = src[col].y;
            *dst++ = src[col].r;
            *dst++ = src[col].g;
            *dst++ = src[col].b;
        }
        if (fwrite(row, 1, row_bytes, file) != row_bytes) goto cleanup;
        if ((padding_bytes != 0) &&
            (fwrite(zero_padding, 1, padding_bytes, file) != padding_bytes)) goto cleanup;
    }
    status = 0;

cleanup:
    free(row);
    // fclose flushes buffered data, so its result decides success as well.
    if (fclose(file) != 0) status = 1;
    if (status == 0)
        printf("[log] write file success\n");
    else
        printf("[log] [error] write file failed\n");

    return status;
}

int image_scale(BMP_FILE* bmp, BMP_FILE* output, float times) {
    uint32_t height = bmp->dib.height * times;
    uint32_t width = bmp->dib.width * times;
    uint32_t origin_width = bmp->dib.width;
    uint32_t origin_height = bmp->dib.height;
    uint16_t bits = bmp->dib.number_bit_per_pixel;

    printf("Height: %u Width: %u\n", height, width);

    // Calculate image data size
    uint32_t size_image_content, i, j;
    uint32_t p1, p2;
    float d1, d2, d3, d4;
    size_image_content = height * width;

    // Copy header and dib
    output->dib = bmp->dib;
    output->header = bmp->header;

    // Change height and width
    output->dib.height = height;
    output->dib.width = width;
    output->header.size = height * width * bmp->dib.number_bit_per_pixel + 54;

    // Generate dynamic array for image data
    output->image = (PIXEL*)malloc(size_image_content * sizeof(PIXEL));

    // Bilinear interpolation
    for (i = 0; i < height; i++) {
        for (j = 0; j < width; j++) {
            p2 = i / times;
            if (p2 == origin_height - 1) p2 -= 1;
            p1 = j / times;
            if (p1 == origin_width - 1) p1 -= 1;
            d1 = i / times - p2;
            d2 = 1 - d1;
            d3 = j / times - p1;
            d4 = 1 - d3;
            output->image[i * width + j].y = (uint8_t)(
                bmp->image[(p2)*origin_width + p1].y * d2 * d4 +
                bmp->image[(p2 + 1) * origin_width + p1].y * d1 * d4 +
                bmp->image[(p2)*origin_width + p1 + 1].y * d2 * d3 +
                bmp->image[(p2 + 1) * origin_width + p1 + 1].y * d1 * d3);
            output->image[i * width + j].r = (uint8_t)(
                bmp->image[(p2)*origin_width + p1].r * d2 * d4 +
                bmp->image[(p2 + 1) * origin_width + p1].r * d1 * d4 +
                bmp->image[(p2)*origin_width + p1 + 1].r * d2 * d3 +
                bmp->image[(p2 + 1) * origin_width + p1 + 1].r * d1 * d3);
            output->image[i * width + j].g = (uint8_t)(
                bmp->image[(p2)*origin_width + p1].g * d2 * d4 +
                bmp->image[(p2 + 1) * origin_width + p1].g * d1 * d4 +
                bmp->image[(p2)*origin_width + p1 + 1].g * d2 * d3 +
                bmp->image[(p2 + 1) * origin_width + p1 + 1].g * d1 * d3);
            output->image[i * width + j].b = (uint8_t)(
                bmp->image[(p2)*origin_width + p1].b * d2 * d4 +
                bmp->image[(p2 + 1) * origin_width + p1].b * d1 * d4 +
                bmp->image[(p2)*origin_width + p1 + 1].b * d2 * d3 +
                bmp->image[(p2 + 1) * origin_width + p1 + 1].b * d1 * d3);
        }
    }
    return 0;
}

// Scale `source` by `times` and write the result to `file_name`.
// The temporary scaled image is released before returning.
// Returns 0 on success, non-zero if scaling or writing failed.
static int scale_and_save(BMP_FILE* source, const char* file_name, float times) {
    BMP_FILE scaled;
    int status;

    scaled.image = NULL;
    status = image_scale(source, &scaled, times);
    if (status == 0) status = createbmpfile(file_name, &scaled);
    if (status != 0) fprintf(stderr, "failed to produce %s\n", file_name);
    free(scaled.image);
    return status;
}

// Load hw1.bmp and hw2.bmp and write an enlarged (x1.5) and a reduced (x1/1.5)
// version of each. Returns 0 when every step succeeded and 1 otherwise.
int main(void) {
    BMP_FILE input_file_1, input_file_2;
    char input_file_name_1[] = "hw1.bmp";
    char input_file_name_2[] = "hw2.bmp";
    float scale_up = 1.5;
    float scale_down = 1.0 / 1.5;
    int status = 0;

    // Keep the pointers well-defined so the cleanup below is always safe.
    input_file_1.image = NULL;
    input_file_2.image = NULL;

    // Read file
    if (openbmpfile(input_file_name_1, &input_file_1) != 0) {
        fprintf(stderr, "cannot load %s\n", input_file_name_1);
        status = 1;
    }
    if ((status == 0) && (openbmpfile(input_file_name_2, &input_file_2) != 0)) {
        fprintf(stderr, "cannot load %s\n", input_file_name_2);
        status = 1;
    }

    // Image scale: only attempted when both inputs were loaded. Every output
    // is attempted even if an earlier one failed.
    if (status == 0) {
        if (scale_and_save(&input_file_1, "output1_up.bmp", scale_up) != 0) status = 1;
        if (scale_and_save(&input_file_2, "output2_up.bmp", scale_up) != 0) status = 1;
        if (scale_and_save(&input_file_1, "output1_down.bmp", scale_down) != 0) status = 1;
        if (scale_and_save(&input_file_2, "output2_down.bmp", scale_down) != 0) status = 1;
    }

    free(input_file_1.image);
    free(input_file_2.image);
    return status;
}
