#include <iostream>

uint32_t node_id = 0;
uint32_t failed = 0;

uint32_t popcnt(const uint32_t x)
{ return __builtin_popcount(x); }

uint32_t getMSBindex(const uint32_t x)
{ return 31 - __builtin_clz(x); }

uint32_t getMSBmask(const uint32_t x)
{ return 1 << getMSBindex(x); }

uint32_t notSimpleCase(const uint32_t x)
{ return ((x+1) & x) && ((x+2) & (x+1)); }

/*
uint32_t parent(const uint32_t node)
{
    uint32_t x = node;

    while (notSimpleCase(x))
    {
        x &= ~getMSBmask(x);
        ++x;
    }

    return node + 1 + (((x+1) & x)? 0: x);
}

uint32_t parent(const uint32_t node)
{
    uint32_t x = node;
    uint32_t mid = getMSBindex(node) / 2;

    while (notSimpleCase(x) && mid >= 3)
    {
        const uint32_t mask = (1 << mid) - 1;
        const uint32_t y = (x & mask) + popcnt(x & ~mask);

        if ((y & ~mask) || y == mask) // overflow
            break;

        x = y;
        mid /= 2;
    }

    while (notSimpleCase(x))
    {
        x &= ~getMSBmask(x);
        ++x;
    }

    return node + 1 + (((x+1) & x)? 0: x);
}
*/

uint32_t parent(const uint32_t node)
{
    uint32_t x = node;
    uint32_t bit = x;

    while ((x & bit) && notSimpleCase(x))
    {
        const uint32_t y = x + popcnt(x);
        bit = getMSBmask(y & ~x);
        const uint32_t mask = (bit << 1) - 1;
        const uint32_t z = (x & mask) + popcnt(x & ~mask);

        if (z == mask && (x & (bit << 1)))
            return node + 1;

        x = z;
    }

    if (notSimpleCase(x))
        return node + 1;
    else
        return node + 1 + (((x+1) & x)? 0: x);
}

uint32_t test(const int32_t depth)
{
    if (depth)
    {
        const uint32_t left = test(depth - 1);
        const uint32_t right = test(depth - 1);
        ++node_id;
        if (left != node_id) ++failed;
        if (right != node_id) ++failed;
    }
    else
    {
        ++node_id;
    }

    return parent(node_id);
}

int main()
{
    test(23);
    std::cout << "nodes: " << node_id << '\n';
    std::cout << "failed: " << failed << '\n';
}
