#include <algorithm>
#include <cmath>
#include <stdint.h>

// Pi as a single-precision (float) literal.
#define PI_F 3.14159265358979f

// Two-component vector of T with publicly accessible components x and y.
template<typename T> class Vec2Base
{
public:

    // Value-initializes both components (zero for arithmetic T) so that a
    // default-constructed vector never holds indeterminate values.
    Vec2Base() : x(), y() {}
    // Broadcasts a to both components.
    Vec2Base(T a) : x(a), y(a) {}
    // Sets x = a and y = b.
    Vec2Base(T a, T b) : x(a), y(b) {}

    T x, y;
};

// Three-component vector of T with publicly accessible components x, y, z,
// a component-wise division operator and a dot product.
template<typename T> class Vec3Base
{
public:

    // Value-initializes all components (zero for arithmetic T) so that a
    // default-constructed vector never holds indeterminate values.
    Vec3Base() : x(), y(), z() {}
    // Broadcasts a to every component. Kept implicit: this conversion is what
    // lets a scalar be used as an operand of operator/ (vector / scalar).
    Vec3Base(T a) : x(a), y(a), z(a) {}
    // Sets x = a, y = b and z = c.
    Vec3Base(T a, T b, T c) : x(a), y(b), z(c) {}

    // Component-wise quotient a / b. Components of b are not checked for zero.
    friend Vec3Base<T> operator/(const Vec3Base& a, const Vec3Base& b)
    {
        return Vec3Base<T>(a.x / b.x, a.y / b.y, a.z / b.z);
    }

    // Dot product of a and b.
    friend T Dot(const Vec3Base& a, const Vec3Base& b)
    {
        return a.x * b.x + a.y * b.y + a.z * b.z;
    }

    T x, y, z;
};

// Single-precision aliases of the vector templates.
using Vec2f = Vec2Base<float>;
using Vec3f = Vec3Base<float>;

// Square root that treats its argument as zero unless it is strictly greater
// than zero, so negative (and NaN) inputs yield the square root of zero.
template<typename T>
T SafeSqrt(const T& a)
{
    const T zero(0);
    const T clamped = (zero < a) ? a : zero;
    return std::sqrt(clamped);
}

// Returns T(-1) when the sign bit of a is set (this includes negative zero)
// and T(1) otherwise.
template<typename T>
T SignNum(const T& a)
{
    if (std::signbit(a))
    {
        return T(-1);
    }
    return T(1);
}

// Returns a with every component divided by the Euclidean length of a.
// There is no guard for a zero-length input: it divides by zero.
Vec3f Normalize(const Vec3f& a)
{
    const float length = std::sqrt(Dot(a, a));
    return Vec3f(a.x / length, a.y / length, a.z / length);
}

// Generates a 2D slope for the incidence angle aThetaI from the two sample
// values in aSample and writes it to aSlope.
//
//   aSlope  - output; both components are overwritten.
//   aThetaI - incidence angle, passed to std::tan (radians); values below
//             0.0001 take the normal-incidence path.
//   aSample - pair of sample values; presumably uniform in [0, 1] - confirm
//             against callers.
void CrashFooSample(
    Vec2f       &aSlope,
    const float  aThetaI,
    const Vec2f &aSample)
{
    if (aThetaI < 0.0001f)
    {
        // Normal incidence: handled separately to avoid a division by zero in
        // the general path. The first sample is clamped so that (1 - u) stays
        // positive.
        const float u = std::min(aSample.x, 0.9999f);
        const float r = SafeSqrt(u / (1 - u));
        const float angle = 2 * PI_F * aSample.y;
        const float sinAngle = std::sin(angle);
        const float cosAngle = std::cos(angle);
        aSlope = Vec2f(r * cosAngle, r * sinAngle);
        return;
    }

    const float tanTheta = std::tan(aThetaI);
    const float cotTheta = 1.0f / tanTheta;
    const float g1 =
        2.0f / (1.0f + SafeSqrt(1.0f + 1.0f / (cotTheta * cotTheta)));

    // X dimension: marginalized PDF, sampled directly through the inverse CDF.
    float a = 2.0f * aSample.x / g1 - 1.0f;
    if (std::abs(a) == 1.0f)
    {
        // Nudge a away from +-1 to avoid a division by zero in (a * a - 1).
        a -= SignNum(a) * 1e-4f;
    }
    const float b = tanTheta;
    const float invDenom = 1.0f / (a * a - 1.0f);
    const float disc = SafeSqrt(b * b * invDenom * invDenom - (a * a - b * b) * invDenom);
    const float rootMinus = b * invDenom - disc;
    const float rootPlus = b * invDenom + disc;
    const bool useMinusRoot = (a < 0.0f) || (rootPlus > (1.0f / tanTheta));
    aSlope.x = useMinusRoot ? rootMinus : rootPlus;

    // Y dimension: the conditional PDF is symmetrical in y, so only one
    // half-space is sampled; the half is picked from which side of 0.5 the
    // second sample falls on, and the sample is remapped accordingly.
    const bool positiveHalf = aSample.y > 0.5f;
    const float sign = positiveHalf ? 1.0f : -1.0f;
    const float t = positiveHalf
        ? 2.0f * (aSample.y - 0.5f)
        : 2.0f * (0.5f - aSample.y);

    // The CDF is not directly invertible, so a rational fit of CDF^-1 is
    // evaluated instead (the improved fit from the Mitsuba renderer rather
    // than the original fit from the paper).
    const float numerator =
        t * (t * (t * -static_cast<float>(0.365728915865723)
                  + static_cast<float>(0.790235037209296))
             - static_cast<float>(0.424965825137544))
        + static_cast<float>(0.000152998850436920);
    const float denominator =
        t * (t * (t * (t * static_cast<float>(0.169507819808272)
                       - static_cast<float>(0.397203533833404))
                  - static_cast<float>(0.232500544458471))
             + static_cast<float>(1))
        - static_cast<float>(0.539825872510702);
    const float z = numerator / denominator;
    aSlope.y = sign * z * std::sqrt(1.0f + aSlope.x * aSlope.x);
}

// Builds a normalized direction from aVec3, derives a polar and an azimuthal
// angle from it, draws a slope with CrashFooSample for the polar angle,
// rotates that slope by the azimuth, scales it by aFloat and returns
// (slope.x, slope.y, 1).
//
// NOTE(review): only aVec3.x is read - it fills all three components of the
// vector handed to Normalize (the last one clamped to >= 0), while aVec3.y
// and aVec3.z are never used. Confirm that this is intended.
Vec3f CrashFoo(
    const Vec3f &aVec3,
    const float  aFloat,
    const Vec2f &aVec2)
{
    const float lastClamped = std::max(aVec3.x, 0.0f);
    const Vec3f dir = Normalize(Vec3f(aVec3.x, aVec3.x, lastClamped));

    // Directions with z at or above 0.999 (and any non-comparable z) keep
    // both angles at zero, which selects the normal-incidence path below.
    const bool offAxis = dir.z < 0.999f;
    const float theta = offAxis ? std::acos(dir.z) : 0.0f;
    const float phi = offAxis ? std::atan2(dir.y, dir.x) : 0.0f;

    Vec2f sampled;
    CrashFooSample(sampled, theta, aVec2);

    // Rotate the sampled slope by phi.
    const float s = std::sin(phi);
    const float c = std::cos(phi);
    const float rotatedX = sampled.x * c - sampled.y * s;
    const float rotatedY = sampled.x * s + sampled.y * c;

    return Vec3f(rotatedX * aFloat, rotatedY * aFloat, 1.0f);
}

int32_t main(int32_t argc, const char *argv[])
{
    Vec3f vec3{ 0.00628005248f, -0.999814332f, 0.0182171166f };
    Vec2f vec2{ 0.947231591f, 0.0522233732f };
    float floatVal{ 0.010f };

    Vec3f vecResult = CrashFoo(vec3, floatVal, vec2);

    return (int32_t)vecResult.x;
}
