#include <stdio.h>
#include <stdlib.h>
#include <math.h>


/* One node of a binary tree holding an int payload.  A NULL child pointer
   means that child is absent. */
typedef struct Node_t Node;
struct Node_t
{
   int data;      /* payload stored in this node */
   Node* left;    /* left subtree, or NULL */
   Node* right;   /* right subtree, or NULL */
};

/*
 * Structural equality of two trees.
 * Returns 1 when both trees have the same shape and equal data at every
 * corresponding node (two empty trees compare equal), 0 otherwise.
 */
int treeCompare(Node* n1, Node* n2)
{
   /* If either side is empty, they match only when both are empty. */
   if (n1 == NULL || n2 == NULL)
      return n1 == n2;
   if (n1->data != n2->data)
      return 0;
   if (!treeCompare(n1->left, n2->left))
      return 0;
   return treeCompare(n1->right, n2->right);
}



/* Utility: number of nodes in the tree rooted at n; an empty (NULL) tree has 0. */
int countNodes(Node* n)
{
   return n ? 1 + countNodes(n->left) + countNodes(n->right) : 0;
}



/*
 * Recursive worker for bencode(); n must be non-NULL.
 *
 * For every node of the subtree it:
 *   - sets one "has children" bit, in pre-order.  *flag is the absolute bit
 *     index of the next bit: bit b lives in word b/32 of code, at position
 *     b%32.
 *   - stores the node's data at code[*pos], in in-order.
 * Both cursors are advanced.  Returns code.
 *
 * Only the left pointer decides whether a node has children, so the tree must
 * be full (every node has zero or two children).  Otherwise, a right-only node
 * silently loses its subtree, and a left-only node dereferences a NULL right
 * pointer.
 */
int* bencodeHelper(Node* n, int* code, int* pos, int* flag)
{
   /* Set flag bits through unsigned int: `1 << 31` on a signed int is
      undefined behaviour.  Accessing an int object through unsigned int* is
      permitted by the C aliasing rules. */
   unsigned int* bits = (unsigned int*)code;
   int hasKids = (n->left != NULL);

   if (hasKids)
      bits[*flag / 32] |= 1u << (*flag & 31);
   *flag += 1;

   if (hasKids) bencodeHelper(n->left, code, pos, flag);
   code[*pos] = n->data;
   *pos += 1;
   if (hasKids) bencodeHelper(n->right, code, pos, flag);
   return code;
}


/* Returns 1 if every node of the tree has either zero or two children. */
static int bencodeIsFull(Node* n)
{
   if (!n) return 1;
   if ((n->left == NULL) != (n->right == NULL)) return 0;
   return bencodeIsFull(n->left) && bencodeIsFull(n->right);
}

/*
 * Serialize tree h into a newly calloc'd int array.  The element count is
 * written to *sizeOut; the caller frees the array.
 *
 * Layout (nflags = ceil(nnodes / 32)):
 *   code[0]                 index of the first data word (nflags + 1), or 0
 *                           for an empty tree
 *   code[1 .. nflags]       bitmap, one "has children" bit per node, pre-order
 *   code[nflags+1 .. end]   node data, in-order
 *
 * The format can only describe full trees (every node has zero or two
 * children).  For any other tree, or if allocation fails, returns NULL and
 * sets *sizeOut to 0.
 */
int* bencode(Node* h, int* sizeOut)
{
   int nnodes, nflags, pos, flag;
   int* out;

   *sizeOut = 0;
   if (!bencodeIsFull(h)) return NULL;

   nnodes = countNodes(h);
   nflags = (nnodes + 31) / 32;   /* integer ceil(nnodes / 32) */
   pos = nflags + 1;              /* data starts right after the bitmap */
   flag = 32;                     /* bit 32 == first bit of word 1 (word 0 is the header) */

   out = calloc((size_t)(1 + nnodes + nflags), sizeof *out);
   if (!out) return NULL;
   *sizeOut = 1 + nnodes + nflags;

   if (!h) return out;            /* out[0] == 0 marks the empty tree */
   out[0] = nflags + 1;
   return bencodeHelper(h, out, &pos, &flag);
}



/* Free a (possibly partial) tree built by bdecodeHelper; NULL is a no-op. */
static void bdecodeFreeTree(Node* n)
{
   if (!n) return;
   bdecodeFreeTree(n->left);
   bdecodeFreeTree(n->right);
   free(n);
}

/*
 * Recursive worker for bdecode(): rebuilds one node and its subtree.
 *
 * *flag is the absolute index of the next "has children" bit; bit b is in
 * word b/32 of code, at position b%32.  Flags are consumed in pre-order.
 * *pos is the index of the next data word; data is consumed in in-order.
 * Both cursors are advanced.
 *
 * Returns the new calloc'd node.  Returns NULL if a node allocation fails;
 * any partially built subtree is freed first.
 */
Node* bdecodeHelper(int* code, int* pos, int* flag)
{
   /* Read flag words as unsigned so the shift is well defined even when
      bit 31 is set (right-shifting a negative int is implementation-defined). */
   const unsigned int* bits = (const unsigned int*)code;
   Node* n = calloc(1, sizeof *n);
   int hasKids;

   if (!n) return NULL;
   hasKids = (int)((bits[*flag / 32] >> (*flag & 31)) & 1u);
   *flag += 1;

   if (hasKids)
   {
      n->left = bdecodeHelper(code, pos, flag);
      if (!n->left) { free(n); return NULL; }
   }
   n->data = code[*pos];
   *pos += 1;
   if (hasKids)
   {
      n->right = bdecodeHelper(code, pos, flag);
      if (!n->right) { bdecodeFreeTree(n); return NULL; }
   }
   return n;
}

/*
 * Rebuild a tree from an array in the layout produced by bencode().
 *
 * Returns NULL for a NULL array or for the encoding of an empty tree
 * (code[0] == 0).  Otherwise returns the tree built by bdecodeHelper, starting
 * with the bitmap at word 1 (bit 32) and the data at word code[0].
 *
 * The array length is not known here, so the input is trusted.  A malformed
 * array can be read out of bounds.
 */
Node* bdecode(int* code)
{
   int flag = 32;
   int pos;

   if (!code) return NULL;   /* e.g. bencode() failed */
   pos = code[0];
   if (!pos) return NULL;
   return bdecodeHelper(code, &pos, &flag);
}


/* Release a subtree produced by makeRandomTree; NULL is a no-op. */
static void makeRandomTreeFree(Node* n)
{
   if (!n) return;
   makeRandomTreeFree(n->left);
   makeRandomTreeFree(n->right);
   free(n);
}

/*
 * Build a random full binary tree (every node has zero or two children) for
 * testing.  Each node's data is rand().
 *
 * A node gets two children with probability freq, while the depth budget
 * maxDepth lasts.  maxDepth == 0 forces a leaf.  A negative maxDepth never
 * counts down to 0, so it effectively removes the depth limit.
 *
 * The result depends on the rand() seed.  Returns NULL if the root cannot be
 * allocated.  If a child subtree cannot be allocated, the node is kept as a
 * leaf so the tree stays full.
 */
Node* makeRandomTree(float freq, int maxDepth)
{
   Node* n = calloc(1, sizeof *n);
   if (!n) return NULL;

   n->data = rand();
   if (maxDepth-- && (rand() / (float)RAND_MAX < freq))
   {
      n->left = makeRandomTree(freq, maxDepth);
      n->right = makeRandomTree(freq, maxDepth);
      if (!n->left || !n->right)
      {
         /* A node with a single child would break the full-tree shape this
            generator promises; drop both subtrees instead. */
         makeRandomTreeFree(n->left);
         makeRandomTreeFree(n->right);
         n->left = n->right = NULL;
      }
   }
   return n;
}

/* Release every node of a calloc'd tree; NULL is a no-op. */
static void freeTree(Node* n)
{
   if (!n) return;
   freeTree(n->left);
   freeTree(n->right);
   free(n);
}

/*
 * Self-test: build a random tree, encode it, dump the encoding as hex, decode
 * it and check the round trip reproduces the original tree.
 * Exits with EXIT_SUCCESS when the trees match, EXIT_FAILURE otherwise.
 */
int main(int argc, char* argv[])
{
   Node* head;
   Node* dup;
   int* code;
   int i, sz, nnodes, ok;

   (void)argc;
   (void)argv;

   head = makeRandomTree(0.79, 8);
   if (!head)
   {
      fputs("allocation failed\n", stderr);
      return EXIT_FAILURE;
   }
   nnodes = countNodes(head);
   code = bencode(head, &sz);
   if (!code)
   {
      fputs("encoding failed\n", stderr);
      freeTree(head);
      return EXIT_FAILURE;
   }

   printf("Tree with %d nodes encodes to %d ints\n", nnodes, sz);
   for (i = 0; i < sz; ++i)
   {
      /* %x takes unsigned int; a flag word with bit 31 set is negative as int. */
      printf("%.4x, ", (unsigned int)code[i]);
   }
   puts("\n");

   dup = bdecode(code);
   ok = treeCompare(head, dup);
   if (ok) { puts("Trees Compare Exactly!\n"); }
   else { puts("FAILURE\n"); }

   freeTree(dup);
   freeTree(head);
   free(code);
   return ok ? EXIT_SUCCESS : EXIT_FAILURE;
}
