// The keys of two nodes in a BST have been swapped; find them and correct the BST.
#include <stdio.h>
#include <stdlib.h>

/* One node of a binary search tree: an integer key plus links to the
   left and right subtrees (NULL when that subtree is empty). */
struct node
{
    int data;             /* the key stored in this node */
    struct node *left;    /* subtree of smaller keys, or NULL */
    struct node *right;   /* subtree of greater-or-equal keys, or NULL */
};

// A utility function to swap two integers
void swap( int* a, int* b )
{
    int t = *a;
    *a = *b;
    *b = t;
}

/* Helper function that allocates a new node with the
   given data and NULL left and right pointers. */
struct node* newNode(int data)
{
    struct node* node = (struct node *)malloc(sizeof(struct node));
    node->data = data;
    node->left = NULL;
    node->right = NULL;
    return(node);
}

// This function does inorder traversal to find out the two swapped nodes.
// It sets three pointers, first, middle and last.  If the swapped nodes are
// adjacent to each other, then first and middle contain the resultant nodes
// Else, first and last contain the resultant nodes
void correctBSTUtil( struct node* root, struct node** first,
                     struct node** middle, struct node** last,
                     struct node** prev )
{
    if( root )
    {
        // Recur for the left subtree
        correctBSTUtil( root->left, first, middle, last, prev );

        // If this node is smaller than the previous node, it's violating
        // the BST rule.
        if (*prev && root->data < (*prev)->data)
        {
            // If this is first violation, mark these two nodes as
            // 'first' and 'middle'
            if ( !*first )
            {
                *first = *prev;
                *middle = root;
            }

            // If this is second violation, mark this node as last
            else
                *last = root;
        }

        // Mark this node as previous
        *prev = root;

        // Recur for the right subtree
        correctBSTUtil( root->right, first, middle, last, prev );
    }
}

// A function to fix a given BST where two nodes are swapped.  This
// function uses correctBSTUtil() to find out two nodes and swaps the
// nodes to fix the BST
void correctBST( struct node* root )
{
    // Initialize pointers needed for correctBSTUtil()
    struct node *first, *middle, *last, *prev;
    first = middle = last = prev = NULL;

    // Set the poiters to find out two nodes
    correctBSTUtil( root, &first, &middle, &last, &prev );

    // Fix (or correct) the tree
    if( first && last )
        swap( &(first->data), &(last->data) );
    else if( first && middle ) // Adjacent nodes swapped
        swap( &(first->data), &(middle->data) );

    // else nodes have not been swapped, passed tree is really BST.
}

/* Given a binary tree, print its nodes in inorder*/
void printInorder(struct node* node)
{
    if (node == NULL)
        return;

    /* first recur on left child */
    printInorder(node->left);

    /* then print the data of node */
    printf("%d ", node->data);

    /* now recur on right child */
    printInorder(node->right);
}

void insertBST( struct node** root, int data )
{
    if( !*root )
        *root = newNode( data );
    else if( data < (*root)->data )
        insertBST( &( (*root)->left ), data );
    else
        insertBST( &( (*root)->right ), data );
}

struct node* searchBST( struct node* root, int data )
{
    if( !root )
        return NULL;
    if( root->data == data )
        return root;
    else if( data < root->data )
        return searchBST( root->left, data );
    return searchBST( root->right, data );
}


void testCase( int arr[], int size, struct node* root )
{
    int d1, d2, i;

    for( i = 0; i < 10; ++i )
    {
        printf( "\n\nTest case #%d\n", i );

        //Generate two nodes randomly to swap, then the program corrects it. :)
        d1 = arr[ rand() % size ];
        d2 = arr[ rand() % size ];

        struct node* t1 = searchBST( root, d1 );
        struct node* t2 = searchBST( root, d2 );

        if( t1 && t2 ) // both nodes exist
        {
            printf( "After swapping the nodes %d and %d\n",d1, d2  );

            swap( &(t1->data), &(t2->data) );
            printInorder( root );

            printf( "\nAfter correcting the BST\n" );

            correctBST( root );
            printInorder( root );
        }
    }
}

int main()
{
    struct node* root = NULL;

    int arr[] = { 10, 5, 15, 3, 7, 8, 20, 25 };

    int i, size = sizeof( arr ) / sizeof( *arr );

    for( i = 0; i < size; ++i )
        insertBST( &root, arr[i] );

    printf( "Original tree\n");
    printInorder( root );

    testCase( arr, size, root );

    return 0;
}
