#include <Eigen\Dense>
#include <Eigen\StdVector>

#include <time.h>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

// Single-precision 3-component Eigen vector used throughout the renderer for
// points, directions, normals and RGB colors/radiance values.
using Vec3f = Eigen::Vector3f;

/// Small, fast 32-bit linear congruential generator.
///
/// State update: seed' = 214013 * seed + 2531011 (mod 2^32, via unsigned wrap-around).
/// Each call advances the state and returns a float uniformly distributed in [0, 1).
class RandomLCG
{
private:
	unsigned int m_seed;

public:
	/// @param seed initial state; equal seeds produce identical sequences.
	RandomLCG(unsigned seed = 0) :
		m_seed(seed)
	{}

	/// Advances the generator and returns the next sample in [0, 1).
	float operator()()
	{
		m_seed = 214013u * m_seed + 2531011u;
		// Keep only the top 24 bits. They fit a float mantissa exactly, so the result is
		// strictly below 1. Converting the full 32-bit state to float instead rounds states
		// near 2^32 up to 2^32 and yields exactly 1.0f. The high bits of a power-of-two
		// modulus LCG are also its best-distributed bits.
		return static_cast<float>(m_seed >> 8) * (1.0f / 16777216.0f);
	}
};

/// A half-line: starts at `origin` and heads along `direction`.
/// The constructor stores a normalized copy of the given direction, so for any
/// non-zero input vector `direction` has unit length.
class Ray
{
public:
	Vec3f origin;     ///< start point
	Vec3f direction;  ///< heading, normalized on construction

	EIGEN_MAKE_ALIGNED_OPERATOR_NEW

	Ray(Vec3f const & ori, Vec3f const & dir)
		: origin(ori)
		, direction(dir.normalized())
	{
	}
};

/// Surface description attached to a sphere.
/// Plain aggregate, built with brace initialization such as `{ Material::Type::DIFFUSE, 0.0f }`.
class Material
{
public:
	/// Scattering model that radiance() applies at a hit.
	enum class Type
	{
		DIFFUSE,     ///< cosine-weighted hemisphere bounce
		SPECULAR,    ///< mirror reflection
		REFRACTION   ///< dielectric: Fresnel-weighted reflection plus refraction
	};

	Type type;            ///< which scattering model to use
	float fresnel_ratio;  ///< refractive-index ratio, only read for REFRACTION (radiance() uses it or its
	                      ///< reciprocal depending on the crossing direction); main() passes 0 otherwise
};

class Sphere
{
public:
	Vec3f m_center;
	float m_radius;

	Vec3f m_color;
	Vec3f m_emission;
	Material m_mateiral;

	float m_sqr_radius;
	float m_max_color_component;
	Vec3f m_color_ratio;

	EIGEN_MAKE_ALIGNED_OPERATOR_NEW

public:
	Sphere(Vec3f && center, float rad,
		Vec3f && emission, Vec3f && color, Material && material) :
		m_center(center), m_color(color), m_emission(emission), m_radius(rad), m_mateiral(material)
	{
		m_sqr_radius = m_radius * m_radius;
		m_max_color_component = m_color.maxCoeff();
		m_color_ratio = (1.0f / m_max_color_component) * m_color;
	}

	bool intersect(Ray const & ray, float & root, float threshold = 1e-20) const
	{
		root = 0.0f;
		// Solve t^2*d.d + 2*t*(o-p).d + (o-p).(o-p)-R^2 = 0
		Vec3f op = m_center - ray.origin;
		float b = op.transpose() * ray.direction;
		float det = b * b - op.squaredNorm() + m_sqr_radius;
		float const eps = 1e-4;

		if (det < 0) return false;

		float dets = (std::sqrt)(det);

		if (b - dets > eps)
		{
			root = b - dets;
			return true;
		}
		else if (b + dets > eps)
		{
			root = b + dets;
			return true;
		}

		return false;
	}
};

/// Record of the closest ray/sphere hit, filled in by Scene::intersect.
class Intersection
{
public:
	/// Crossing classification. Scene::intersect assigns GO_INWARD when
	/// normal.dot(ray direction) > 0 and GO_OUTWARD otherwise.
	enum class Type { GO_INWARD, GO_OUTWARD };

	float root = 0.0f;             ///< distance along the ray to the hit point
	Vec3f position;                ///< hit point
	Vec3f normal;                  ///< unit sphere normal at the hit, pointing away from the center
	Vec3f normal_oriented;         ///< `normal`, possibly sign-flipped; orientation chosen by Scene::intersect
	Type type = Type::GO_OUTWARD;  ///< see Type
	Vec3f color;                   ///< m_color of the hit sphere
	Material material;             ///< material of the hit sphere
	Sphere const * obj = nullptr;  ///< non-owning pointer to the hit sphere; nullptr until a hit is recorded

	EIGEN_MAKE_ALIGNED_OPERATOR_NEW
};

class Scene
{
private:
	std::vector<Sphere, Eigen::aligned_allocator<Sphere> > m_components;

public:
	void add(Sphere const & sphere)
	{
		m_components.push_back(std::move(sphere));
	}

	bool const & intersect(Ray const & ray, Intersection & intersection) const
	{
		bool is_intersect = false;
		float root = (std::numeric_limits<float>::max)();
		Sphere const * pivot = nullptr;

		for (Sphere const & com : m_components)
		{
			float cur_root = 0.0f;
			if (!com.intersect(ray, cur_root)) continue;
			if (root <= cur_root) continue;

			pivot = &com;
			root = cur_root;
			is_intersect = true;
		}

		if (is_intersect)
		{
			//std::cout << "intersected" << std::endl;
			intersection.root = root;
			intersection.position = ray.origin + root * ray.direction;
			intersection.normal = (intersection.position - pivot->m_center).normalized();
			intersection.color = pivot->m_color;
			intersection.type = (intersection.normal.transpose() * ray.direction > 0) ?
				Intersection::Type::GO_INWARD : Intersection::Type::GO_OUTWARD;
			intersection.normal_oriented = intersection.normal * (intersection.type == Intersection::Type::GO_INWARD ? 1.0f : -1.0f);
			intersection.obj = pivot;
			intersection.material = pivot->m_mateiral;

		}

		//std::cout << "return value " << is_intersect << std::endl;
		return is_intersect;
	}
};

/// Monte Carlo estimate of the radiance arriving at ray.origin from along ray.direction.
///
/// @param scene      scene to trace against
/// @param ray        ray being traced (unit-length direction; Ray normalizes it)
/// @param depth      bounces already taken on this path (0 for a camera ray)
/// @param min_depth  bounce count from which Russian roulette may end the path
/// @param max_depth  bounce count at which the path always ends
/// @param rand       random source, advanced by every sample drawn here
/// @return           RGB radiance estimate; zero when the ray leaves the scene
Vec3f radiance(Scene const & scene, Ray const & ray, int const depth,
	int const min_depth, int const max_depth, RandomLCG & rand)
{
	constexpr float kPi = 3.14159265358979f;  // avoids relying on the non-standard M_PI

	Intersection inte;
	if (!scene.intersect(ray, inte))
		return Vec3f::Zero();

	Sphere const * sphere = inte.obj;
	Vec3f const emission = sphere->m_emission;
	int const new_depth = depth + 1;

	// Path termination: an ended path contributes only the surface's own emission, not its
	// reflectance color. Past min_depth, Russian roulette keeps the path with probability
	// p (the largest color component) and divides the throughput by p to stay unbiased.
	if (new_depth >= max_depth)
		return emission;

	Vec3f f = inte.color;
	if (new_depth >= min_depth)
	{
		float const p = sphere->m_max_color_component;
		if (!(rand() < p))
			return emission;
		f *= 1.0f / p;  // p > 0 here, because rand() >= 0 and rand() < p
	}

	// Orientation is derived from the geometric normal here rather than from
	// inte.type / inte.normal_oriented, so this function does not depend on their conventions.
	Vec3f const n = inte.normal;                               // points away from the sphere center
	float const cos_in = n.dot(ray.direction);
	bool const entering = cos_in < 0.0f;                       // ray arrives from outside the sphere
	Vec3f const nl = entering ? Vec3f(n) : Vec3f(-n);          // normal on the side the ray came from
	Vec3f const reflect_dir = ray.direction - n * (2.0f * cos_in);

	switch (inte.material.type)
	{
	case Material::Type::DIFFUSE:
	{
		// Cosine-weighted direction in the hemisphere around nl, built in an orthonormal frame (u, v, w).
		Vec3f const w = nl;
		Vec3f const wo = (w.x() < -0.1f || w.x() > 0.1f) ? Vec3f(0.0f, 1.0f, 0.0f) : Vec3f(1.0f, 0.0f, 0.0f);
		Vec3f const u = wo.cross(w).normalized();
		Vec3f const v = w.cross(u);

		float const phi = 2.0f * kPi * rand();
		float const r2 = rand();
		float const r2s = std::sqrt(r2);
		Vec3f const diffuse_dir =
			(u * (std::cos(phi) * r2s) + v * (std::sin(phi) * r2s) + w * std::sqrt(1.0f - r2)).normalized();

		return emission + f.cwiseProduct(radiance(scene,
			Ray{ inte.position, diffuse_dir }, new_depth, min_depth, max_depth, rand));
	}
	case Material::Type::SPECULAR:
		return emission + f.cwiseProduct(radiance(scene,
			Ray{ inte.position, reflect_dir }, new_depth, min_depth, max_depth, rand));
	case Material::Type::REFRACTION:
	{
		// fresnel_ratio is the sphere's index relative to its surroundings:
		// a ray entering the sphere uses 1/ratio, a ray leaving it uses ratio.
		float const ratio = inte.material.fresnel_ratio;
		float const nnt = entering ? 1.0f / ratio : ratio;
		float const ddn = ray.direction.dot(nl);  // <= 0 by construction of nl
		float const cos2t = 1.0f - nnt * nnt * (1.0f - ddn * ddn);

		// Total internal reflection: all energy goes to the mirror direction.
		if (cos2t < 0.0f)
			return emission + f.cwiseProduct(radiance(scene,
				Ray{ inte.position, reflect_dir }, new_depth, min_depth, max_depth, rand));

		// Snell's law; the normal term points to the side the ray is heading into.
		Vec3f const refraction_dir =
			(ray.direction * nnt - n * ((entering ? 1.0f : -1.0f) * (ddn * nnt + std::sqrt(cos2t)))).normalized();

		// Schlick's approximation: Re = R0 + (1 - R0)(1 - cos)^5, with R0 = ((n1 - n2) / (n1 + n2))^2.
		float const a = nnt - 1.0f;
		float const b = nnt + 1.0f;
		float const R0 = (a * a) / (b * b);
		float const c = 1.0f - (entering ? -ddn : refraction_dir.dot(n));
		float const Re = R0 + (1.0f - R0) * c * c * c * c * c;
		float const Tr = 1.0f - Re;

		Vec3f res;
		if (new_depth > 2)
		{
			// Deeper in the path, follow a single branch chosen with probability P, then reweight it.
			float const P = 0.25f + 0.5f * Re;
			if (rand() < P)
				res = (Re / P) *
				radiance(scene, Ray{ inte.position, reflect_dir }, new_depth, min_depth, max_depth, rand);
			else
				res = (Tr / (1.0f - P)) *
				radiance(scene, Ray{ inte.position, refraction_dir }, new_depth, min_depth, max_depth, rand);
		}
		else
		{
			// Near the camera, follow both branches. They are separate statements so the order in
			// which they consume random numbers is fixed; the operand order of '+' is unspecified.
			Vec3f const reflected =
				radiance(scene, Ray{ inte.position, reflect_dir }, new_depth, min_depth, max_depth, rand);
			Vec3f const refracted =
				radiance(scene, Ray{ inte.position, refraction_dir }, new_depth, min_depth, max_depth, rand);
			res = Re * reflected + Tr * refracted;
		}

		return emission + f.cwiseProduct(res);
	}
	default:
		break;
	}

	std::cout << "ERROR : reach invalid code section" << std::endl;
	return Vec3f::Zero();
}

/// Limits `x` to the unit interval [0, 1]. In-range values, and NaN, pass through unchanged.
inline float clamp(float x) {
	return x < 0.0f ? 0.0f : (x > 1.0f ? 1.0f : x);
}

/// Maps a linear intensity to an 8-bit display value. It clamps to [0, 1], gamma-encodes
/// with exponent 1/2.2, scales to 0..255 and rounds to the nearest integer.
inline int toInt(float x) {
	double const encoded = pow(clamp(x), 1 / 2.2);
	return static_cast<int>(encoded * 255 + .5);
}

int main(int argc, char *argv[]) {
	clock_t start = clock();

	Scene scene;

	Sphere spheres[] = {
		Sphere(Vec3f(1e5f + 1.0f, 40.8f, 81.6f),	1e5f,	Vec3f::Zero(),	Vec3f(.75f, .25f, .25f),	{ Material::Type::DIFFUSE, 0.0f }),//Left
		Sphere(Vec3f(-1e5f + 99.0f, 40.8f, 81.6f),	1e5f,	Vec3f::Zero(),	Vec3f(.25f, .25f, .75f),	{ Material::Type::DIFFUSE, 0.0f }),//Rght
		Sphere(Vec3f(50.0f, 40.8f, 1e5f),			1e5f,	Vec3f::Zero(),	Vec3f(.75f, .75f, .75f),	{ Material::Type::DIFFUSE, 0.0f }),//Back
		Sphere(Vec3f(50.0f, 40.8f, -1e5f + 170.0f),	1e5f,	Vec3f::Zero(),	Vec3f::Zero(),				{ Material::Type::DIFFUSE, 0.0f }),//Frnt
		Sphere(Vec3f(50.0f, 1e5f, 81.6f),			1e5f,	Vec3f::Zero(),	Vec3f(.75f, .75f, .75f),	{ Material::Type::DIFFUSE, 0.0f }),//Botm
		Sphere(Vec3f(50.0f, -1e5f + 81.6f, 81.6f),	1e5f,	Vec3f::Zero(),	Vec3f(.75f, .75f, .75f),	{ Material::Type::DIFFUSE, 0.0f }),//Top
		Sphere(Vec3f(27.0f, 16.5f, 47.0f),			16.5f,	Vec3f::Zero(),	Vec3f(1.0f, 1.0f, 1.0f) * .999f,	{ Material::Type::SPECULAR, 0.0f }),//Mirr
		Sphere(Vec3f(73.0f, 16.5f, 78.0f),			16.5f,	Vec3f::Zero(),	Vec3f(1.0f, 1.0f, 1.0f) * .999f,	{ Material::Type::REFRACTION, 2.6f }),//Glas
		Sphere(Vec3f(50.0f, 681.6f - 0.27f, 81.6f),	600.0f,	Vec3f(12.0f, 12.0f, 12.0f), Vec3f::Zero(),	{ Material::Type::DIFFUSE, 0.0f }) //Lite
	};

	for (auto sph : spheres)
		scene.add(sph);

	const int w = 256;
	const int h = 256;

	const int samps = argc == 2 ? atoi(argv[1]) / 4 : 100; // # samples

	const Ray cam(Vec3f(50.0f, 52.0f, 295.6f), Vec3f(0.0f, -0.042612f, -1.0f)); // cam pos, dir
	const Vec3f cx(w * .5135f / h, 0.0f, 0.0f);
	const Vec3f cy = cx.cross(cam.direction).normalized() * .5135f;
	
	Vec3f * buffer = new Vec3f[w * h];

//#pragma omp parallel for schedule(dynamic, 1)       // OpenMP
	// Loop over image rows
	for (int y = 0; y < h; y++) 
	{
		fprintf(stderr, "\rRendering (%d spp) %5.2f%%", samps * 4, 100.*y / (h - 1));
		RandomLCG rand(y);

		// Loop cols
		for (int x = 0; x < w; x++) 
		{
			int const i = (h - y - 1) * w + x;
			//int const i = y * w + x;
			buffer[i] = Vec3f::Zero();
			
			// 2x2 subpixel 
			for (int sy = 0; sy < 2; sy++) for (int sx = 0; sx < 2; sx++)
			{
				Vec3f r = Vec3f::Zero();
				
				for (int s = 0; s < samps; s++) 
				{
					float r1 = 2.0f * rand();
					float r2 = 2.0f * rand();
					float dx = r1 < 1.0f ? sqrt(r1) - 1.0f : 1.0f - sqrt(2.0f - r1);
					float dy = r2 < 1.0f ? sqrt(r2) - 1.0f : 1.0f - sqrt(2.0f - r2);

					Vec3f d = cx * (((sx + 0.5f + dx) / 2.0f + x) / w - 0.5f) +
						cy * (((sy + 0.5f + dy) / 2.0f + y) / h - 0.5f) + cam.direction;

					r = r + radiance(scene, Ray(cam.origin + d * 140.0f, d), 0, 5, 30, rand) * (1.0f / samps);
				}
				buffer[i] = buffer[i] + Vec3f(clamp(r.x()), clamp(r.y()), clamp(r.z()));
			}
			buffer[i] = buffer[i] * 0.25f;

			if (x % 16 == 0 && y % 16 == 0)
				std::cout << buffer[i] << std::endl;

		}
	}

	printf("\n%f sec\n", (float)(clock() - start) / CLOCKS_PER_SEC);

	FILE *f = fopen("myimage4.ppm", "w"); // Write image to PPM file.
	fprintf(f, "P3\n%d %d\n%d\n", w, h, 255);
	for (int i = 0; i < w * h; i++)
		fprintf(f, "%d %d %d\n", toInt(buffer[i][0]), toInt(buffer[i][1]), toInt(buffer[i][2]));
	fclose(f);

	return 0;
}
