//==============================================================
// Copyright © Intel Corporation
//
// SPDX-License-Identifier: MIT
// =============================================================
#include <CL/sycl.hpp>
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#if FPGA || FPGA_EMULATOR
#include <ext/intel/fpga_extensions.hpp>
#endif

using namespace sycl;

// Size parsed from argv[1] in main (a leftover of the vector-add template);
// defaults to 10000.
size_t vector_size{ 10000 };
// Alias for a plain vector of ints.
using IntVector = std::vector<int>;

// Absolute locations of the KNN input CSV files (features and labels for the
// training and test splits).
const std::string train_data_path{ "C:\\Users\\AMAL.T\\Downloads\\grad project\\k_nearest_neighbors_train_data.csv" };
const std::string train_label_path{ "C:\\Users\\AMAL.T\\Downloads\\grad project\\k_nearest_neighbors_train_label.csv" };
const std::string test_data_path{ "C:\\Users\\AMAL.T\\Downloads\\grad project\\k_nearest_neighbors_test_data.csv" };
const std::string test_label_path{ "C:\\Users\\AMAL.T\\Downloads\\grad project\\k_nearest_neighbors_test_label.csv" };

// Training split: one feature row per sample, one integer label per sample.
std::vector<std::vector<double>> train_data{};
std::vector<int> train_label{};

// Test split, laid out the same way as the training split.
std::vector<std::vector<double>> test_data{};
std::vector<int> test_label{};

// Labels predicted for the test samples.
std::vector<int> knn_res_test_label{};

// Asynchronous exception handler for a sycl::queue.
// Every exception delivered in the list is treated as fatal: its message (when
// it derives from std::exception) is written to stderr and the program is
// terminated. Reporting happens in all builds, not only _DEBUG ones, so that
// release-mode failures remain diagnosable.
static auto exception_handler = [](sycl::exception_list e_list) {
    for (std::exception_ptr const& e : e_list) {
        try {
            std::rethrow_exception(e);
        }
        catch (std::exception const& ex) {
#if _DEBUG
            std::cout << "Failure" << std::endl;
#endif
            std::cerr << "Asynchronous SYCL exception: " << ex.what() << "\n";
            std::terminate();
        }
        catch (...) {
            // Non-std payloads would otherwise escape the handler unreported.
            std::cerr << "Asynchronous exception of unknown type\n";
            std::terminate();
        }
    }
};

// Reads a comma-separated file of numbers and appends one row of doubles per
// non-blank line to data_vec.
//   data_vec: destination; existing contents are kept, new rows are appended.
//   path:     file to read.
// If the file cannot be opened a message is printed and data_vec is unchanged.
// A malformed field makes std::stod throw (std::invalid_argument /
// std::out_of_range).
void read_csv_data(std::vector<std::vector<double>>& data_vec, std::string path) {
    // ifstream opens read-only; std::fstream defaults to in|out and fails on
    // files the process may read but not write.
    std::ifstream file(path);
    if (!file.is_open()) {
        std::cout << "Could not open the file\n";
        return;
    }

    std::string line;
    std::string word;
    while (std::getline(file, line)) {
        // Tolerate Windows line endings and skip blank/whitespace-only lines
        // instead of emitting empty rows or letting stod throw on them.
        if (!line.empty() && line.back() == '\r') line.pop_back();
        if (line.find_first_not_of(" \t") == std::string::npos) continue;

        std::vector<double> row;
        std::stringstream fields(line);
        while (std::getline(fields, word, ',')) {
            row.push_back(std::stod(word));
        }
        data_vec.push_back(std::move(row));
    }
}
// Reads one integer label per non-blank line of the file at path and appends
// them to data_vec.
// If the file cannot be opened a message is printed and data_vec is unchanged.
// A malformed line makes std::stoi throw (std::invalid_argument /
// std::out_of_range).
void read_csv_label(std::vector<int>& data_vec, std::string path) {
    // Read-only open (std::fstream would request in|out access).
    std::ifstream file(path);
    if (!file.is_open()) {
        std::cout << "Could not open the file\n";
        return;
    }

    std::string line;
    while (std::getline(file, line)) {
        // Skip blank lines (including a lone '\r' from CRLF files) so stoi does
        // not throw on them.
        if (!line.empty() && line.back() == '\r') line.pop_back();
        if (line.find_first_not_of(" \t") == std::string::npos) continue;
        data_vec.push_back(std::stoi(line));
    }
}

// Host reference implementation: returns the squared Euclidean distance from
// curr_test to every row of dataset (one entry per row, in row order), and
// prints the elapsed wall-clock time to stdout.
// Only the first min(row.size(), curr_test.size()) coordinates of each row are
// compared, so a row longer than the query never indexes past curr_test.
std::vector<double> distance_calculation(const std::vector<std::vector<double>>& dataset, const std::vector<double>& curr_test) {
    const auto start = std::chrono::high_resolution_clock::now();

    std::vector<double> res;
    res.reserve(dataset.size());
    for (const auto& row : dataset) {
        const std::size_t dims = std::min(row.size(), curr_test.size());
        double dis = 0;
        for (std::size_t j = 0; j < dims; ++j) {
            const double diff = curr_test[j] - row[j];
            dis += diff * diff;
        }
        res.push_back(dis);
    }

    const auto finish = std::chrono::high_resolution_clock::now();
    const std::chrono::duration<double> elapsed = finish - start;
    std::cout << "Elapsed time: " << elapsed.count() << " s\n";
    return res;
}

// Computes, on the device behind queue q, the squared Euclidean distance from
// curr_test to every row of dataset. Returns one distance per row, in row order.
// Precondition: every row has exactly curr_test.size() features; otherwise
// std::invalid_argument is thrown (before anything is submitted to the device),
// since mismatched rows would make the kernel read out of bounds.
std::vector<double> distance_calculation_FPGA(queue& q, const std::vector<std::vector<double>>& dataset, const std::vector<double>& curr_test) {
    const std::size_t num_rows = dataset.size();
    const std::size_t dims = curr_test.size();
    std::vector<double> res(num_rows, 0.0);

    // Flatten the rows into one contiguous row-major array for the device.
    std::vector<double> linear_dataset;
    linear_dataset.reserve(num_rows * dims);
    for (const auto& row : dataset) {
        if (row.size() != dims) {
            throw std::invalid_argument("distance_calculation_FPGA: row dimension does not match test sample");
        }
        linear_dataset.insert(linear_dataset.end(), row.begin(), row.end());
    }

    // No rows or no features: all distances are zero; avoid creating
    // zero-sized buffers and launching an empty kernel.
    if (linear_dataset.empty()) return res;

    {
        // The buffers live only in this scope: a buffer constructed over host
        // memory writes its contents back to that memory when destroyed, so
        // res is guaranteed to hold the device results once the scope closes.
        buffer dataset_buf(linear_dataset);
        buffer curr_test_buf(curr_test);
        buffer res_buf(res.data(), range<1>{ num_rows });

        q.submit([&](handler& h) {
            accessor a(dataset_buf, h, read_only);
            accessor b(curr_test_buf, h, read_only);
            // write_only + no_init: initial contents are undefined, so the
            // kernel must never read dif; it stores each result exactly once.
            accessor dif(res_buf, h, write_only, no_init);

            h.parallel_for(range<1>{ num_rows }, [=](id<1> i) {
                const std::size_t base = i[0] * dims;
                double sum = 0.0;
                for (std::size_t j = 0; j < dims; ++j) {
                    const double diff = b[j] - a[base + j];
                    sum += diff * diff;
                }
                dif[i] = sum;
            });
        }).wait();
    }

    return res;
}


int main(int argc, char* argv[]) {

    if (argc > 1) vector_size = std::stoi(argv[1]);
#if FPGA_EMULATOR
    ext::intel::fpga_emulator_selector d_selector;
#elif FPGA
    ext::intel::fpga_selector d_selector;
#else
    default_selector d_selector;
#endif


    try {
        queue q(d_selector, exception_handler);

        // my knn project
             // this is on the host definitelly
        read_csv_data(train_data, train_data_path);
        read_csv_data(test_data, test_data_path);
        read_csv_label(train_label, train_label_path);
        read_csv_label(test_label, test_label_path);
        //


        int k = sqrt((int)train_data.size());
        std::unordered_map<int, int>freq_label;
        std::vector<std::pair<double, int>> knn_label;

        // Print out the device information used for the kernel code.
        std::cout << "Running on device: "
            << q.get_device().get_info<info::device::name>() << "\n";


        // std::vector<double> knn_dist = distance_calculation(train_data, test_data[0]);

        for (int i = 0; i < (int)test_data.size(); ++i) {
            //   std::cout << "it is alright " << i<< std::endl;
            std::vector<double> knn_dist = distance_calculation_FPGA(q, train_data, test_data[i]);

            for (int j = 0; j < knn_dist.size(); ++j) {
                knn_label.push_back(std::make_pair(knn_dist[j], train_label[j]));
            }
            sort(knn_label.begin(), knn_label.end());

            int mx = -1, label = -1;
            for (int j = 0; j < k; ++j) {
                //               std::cout << knn_label[j].first << " " << knn_label[j].second << std::endl;
                int cur_label = ++freq_label[knn_label[j].second];
                if (cur_label > mx) {
                    mx = cur_label;
                    label = knn_label[j].second;
                }
            }
            knn_label.clear();
            freq_label.clear();
            knn_res_test_label.push_back(label);
        }
        int corr = 0;
        for (int i = 0; i < test_label.size(); ++i) {
            //   std::cout << "real: " << test_label[i] << " predict: " << knn_res_test_label[i] << std::endl;
            if (test_label[i] == knn_res_test_label[i])
                corr++;

        }
        double accuracy = corr * 1.0 / test_label.size();
        std::cout << corr << " " << accuracy << std::endl;
        //
        // Print out the device information used for the kernel code.
        std::cout << "Running on device: "
            << q.get_device().get_info<info::device::name>() << "\n";

    }
    catch (exception const& e) {
        std::cout << "An exception is caught for vector add.\n";
        std::terminate();
    }



    std::cout << "Vector add successfully completed on device.\n";
    return 0;
}