#include <algorithm>
#include <bitset>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <vector>

// Unit steps along the four diagonals, stored as (row delta, column delta).
static const std::pair<int, int> kBishopDirections[] = {
	{1, 1},    // row + 1, col + 1
	{-1, 1},   // row - 1, col + 1
	{1, -1},   // row + 1, col - 1
	{-1, -1},  // row - 1, col - 1
};

// Returns true when (row, col) lies on the 8x8 board, i.e. both coordinates
// are in the range [0, 7].
static bool valid_square(int row, int col) {
	const bool row_on_board = 0 <= row && row < 8;
	const bool col_on_board = 0 <= col && col < 8;
	return row_on_board && col_on_board;
}

// Collects the square indices (row * 8 + col) reached by stepping away from
// `square` along every entry of `directions`.
//
// `directions` is any iterable whose elements expose `.first` (row delta) and
// `.second` (column delta). The starting square itself is never included, and
// neither is the last on-board square of each ray: a square is only recorded
// while the step after it still lands on the board.
//
// Returns the collected indices as a std::vector<int> sorted ascending.
template <typename TCont>
auto relevant_squares(int square, const TCont& directions) {
	const int origin_row = square / 8;
	const int origin_col = square % 8;

	std::vector<int> collected;
	for (const auto& step : directions) {
		auto row = origin_row + step.first;
		auto col = origin_col + step.second;
		// Stop one square short of the edge: (row, col) is kept only if the
		// following square along this ray is still a valid board square.
		for (; valid_square(row + step.first, col + step.second);
		     row += step.first, col += step.second) {
			collected.push_back(row * 8 + col);
		}
	}

	std::sort(collected.begin(), collected.end());
	return collected;
}

// Invokes `op` once for every subset of `squares`, passing the subset as a
// 64-bit bitboard (bit `sq` set for each selected square).
//
// Subsets are enumerated by counting from 0 to 2^squares.size() - 1; bit k of
// the counter decides whether the k-th element of `squares` (in iteration
// order) is part of the subset. An empty `squares` yields a single call with
// an empty bitboard.
template <typename TCont, typename TOp>
void for_each_bitboard(const TCont& squares, TOp op) {
	const auto subset_count = 1ull << squares.size();
	for (auto subset = 0ull; subset < subset_count; ++subset) {
		auto occupancy = 0ull;
		// Walks over the counter's bits in step with the squares.
		auto selector = 1ull;
		for (const auto sq : squares) {
			if (subset & selector) {
				occupancy |= 1ull << sq;
			}
			selector <<= 1;
		}
		op(occupancy);
	}
}

// Maps `bitboard` to a table index using `magic`: the two are multiplied and
// the product is shifted right by the amount stored in the magic's bits 56
// and above (`magic >> 56`).
template <typename TMagic, typename TBitboard>
auto index(TMagic magic, TBitboard bitboard) {
	const auto product = bitboard * magic;
	return product >> (magic >> 56);
}

// Writes `bitboard` to std::cout as eight lines of eight '0'/'1' characters.
//
// The row held in bits 56..63 is printed first and the row in bits 0..7 last;
// within a row the lowest bit is the leftmost character. Rows are separated
// by newlines, and no newline follows the final row.
template <typename TBitboard>
void print_board(TBitboard bitboard) {
	for (int rank = 7; rank >= 0; --rank) {
		// std::bitset streams its highest bit first, so file 0 is stored in
		// bit 7 to make it appear on the left.
		std::bitset<8> line;
		for (int file = 0; file < 8; ++file) {
			line[7 - file] = ((bitboard >> (rank * 8 + file)) & 1) != 0;
		}
		std::cout << line;
		if (rank != 0) {
			std::cout << "\n";
		}
	}
}

// Exploration driver for one hard-coded square and one hard-coded magic.
//
// Prints, in order:
//   1. the relevant squares of the chosen square along kBishopDirections;
//   2. for each relevant square s, the original bit indices (63 - s .. 0)
//      that are still inside a 64-bit word after a left shift by s;
//   3. the magic constant byte by byte, most significant byte first;
//   4. the magic shifted left by each relevant square, bit by bit;
//   5. every subset of the relevant squares as a board, followed by the
//      index that index() computes for it with the magic.
int main() {
	// const auto square = 2; // C1
	const auto square = 62; // G8
	const auto squares = relevant_squares(square, kBishopDirections);
	std::cout << "Relevant squares:";
	for (const auto sq : squares) {
		std::cout << " " << sq;
	}
	std::cout << "\n";

	// Bit indices of a 64-bit word, most significant first: 63, 62, ..., 0.
	std::vector<int> bits(64);
	std::iota(bits.rbegin(), bits.rend(), 0);
	for (const auto shift : squares) {
		std::cout << "<<" << std::setw(2) << shift << ": ";
		// Skipping the first `shift` entries leaves the indices that are not
		// shifted out of the word.
		for (auto it = bits.begin() + shift; it != bits.end(); ++it) {
			std::cout << std::setw(3) << *it;
		}
		std::cout << "\n";
	}
	std::cout << "\n";

	// The `ull` suffix keeps the magic unsigned: an unsuffixed literal of this
	// size is deduced as a signed type, and left-shifting it past the sign
	// bit below would be undefined behaviour before C++20.
	const auto magic = 0x3c007f2491c5c260ull;
	std::cout << "Magic:";
	for (int i = 56; i >= 0; i -= 8) {
		// std::bitset<8> keeps only the low eight bits of the shifted value.
		std::cout << std::setw(9) << std::bitset<8>(magic >> i);
		std::cout << "(" << i << ")";
	}
	std::cout << "\n";
	for (const auto shift : squares) {
		std::cout << "<<" << std::setw(2) << shift << ": ";
		const auto shifted_bits = std::bitset<64>(magic << shift).to_string();
		for (const auto bit : shifted_bits) {
			std::cout << std::setw(2) << bit;
		}
		std::cout << "\n";
	}
	std::cout << "\n";

	for_each_bitboard(squares, [&](auto bitboard) {
		print_board(bitboard);
		std::cout << " ==> ";
		std::cout << std::setw(3) << std::dec << index(magic, bitboard);
		std::cout << "\n\n";
	});
}