#include <iostream>
#include <stdio.h>
#include <stdlib.h>
using namespace std;
// Binary tree node: an integer payload plus links to the left and right
// children. A missing child is represented by a NULL link.
struct node
{
    int data;           // value stored at this node
    node *left, *right; // child subtrees (NULL when absent)
};
// Utility function to create a tree node holding `k` with no children.
// The node is allocated with malloc(), so it must be released with free().
// If the allocation fails, an error is printed to stderr and the program
// exits. Returning NULL instead would just lead every caller straight into
// a null-pointer dereference.
struct node *newNode(int k)
{
    struct node *n = static_cast<struct node *>(malloc(sizeof(struct node)));
    if (n == NULL)
    {
        fprintf(stderr, "newNode: out of memory\n");
        exit(EXIT_FAILURE);
    }
    n->data = k;
    n->left = NULL;
    n->right = NULL;
    return n;
}
// Post-order helper for countOptimized.
//
// root:    subtree to examine (may be NULL).
// counter: incremented once for every non-empty unival subtree found,
//          i.e. every subtree in which all nodes hold the same value.
// Returns true if the subtree rooted at `root` is unival. An empty subtree
// is treated as unival (so it never disqualifies its parent) but is not
// counted.
bool countOptimized_Util(struct node *root, int *counter)
{
    if (!root)
        return true;

    // Visit both children unconditionally (no short-circuit) so that unival
    // subtrees on the right are still counted when the left side is not
    // unival.
    bool l = countOptimized_Util(root->left, counter);
    bool r = countOptimized_Util(root->right, counter);
    if (!l || !r)
        return false;

    // Every child that exists must carry root's value; an absent child
    // imposes no constraint. Each child is checked independently: a node
    // with one matching child and one mismatching child is NOT unival.
    if (root->left && root->left->data != root->data)
        return false;
    if (root->right && root->right->data != root->data)
        return false;

    (*counter)++;
    return true;
}
// Counts the non-empty unival subtrees (subtrees whose nodes all hold the
// same value) in the tree rooted at `root`. Returns 0 for an empty tree.
int countOptimized(struct node *root)
{
    int unival_count = 0;
    // Only the helper's side effect on unival_count matters here; its
    // "whole tree is unival" result is deliberately ignored.
    (void)countOptimized_Util(root, &unival_count);
    return unival_count;
}
// Driver function
int main()
{
struct node *root=NULL;
root=newNode(1);
root->left=newNode(2);
root->left->left=newNode(2);
root->left->right=newNode(2);
root->left->left->left=newNode(5);
root->left->left->right=newNode(5);
root->right=newNode(3);
root->right->left=newNode(3);
root->right->right=newNode(3);
root->right->left->left=newNode(4);
root->right->left->right=newNode(4);
root->right->right->left=newNode(3);
root->right->right->right=newNode(3);
cout<<countOptimized(root);
return 0;
}