// Adapted from https://w...content-available-to-author-only...s.org/wavelet-trees-introduction

#include <iostream>
#include <vector>
#include <map>
#include <algorithm>
#include <climits>
using namespace std;

// Wavelet tree over a fixed integer array.
//
// Each node covers the value range [low, high]. For an internal node,
// freq[i] is how many of the node's first i elements (in the node's own
// order) have a value <= mid and were therefore routed to the left child.
// Queries take 1-based, inclusive positions [l, r] into the array as it was
// before construction.
class wavelet_tree { 
public: 
	// Range of values covered by this node: [low, high]
	int low, high; 

	// Left and right children; nullptr for leaves and empty nodes.
	// Owned by this node and released in the destructor.
	wavelet_tree* l = nullptr;
	wavelet_tree* r = nullptr;

	// Prefix counts (see class comment); freq[0] == 0.
	std::vector<int> freq;

	// Builds the tree over the half-open range [from, to), whose values
	// must all lie in [x, y].
	// NOTE: the elements in [from, to) are reordered in place by
	// std::stable_partition; the caller's array order is not preserved.
	wavelet_tree(int* from, int* to, int x, int y) 
	{ 
		low = x, high = y; 

		// Empty range: nothing to store, no children.
		if (from >= to) 
			return; 

		// One leading 0 plus one entry per element.
		freq.reserve(to - from + 1);
		freq.push_back(0);

		// Homogeneous range (e.g. 1 1 1 1 1): this is a leaf. freq simply
		// counts elements; queries never descend below a leaf.
		if (high == low) { 
			for (int* it = from; it != to; ++it)
				freq.push_back(freq.back() + 1);
			return; 
		} 

		// Floor midpoint, computed without int overflow. (low + high) / 2
		// truncates toward zero, so for e.g. [-1, 0] it yields mid == high and
		// the left child would cover the same range again (infinite recursion).
		const int mid = static_cast<int>(low + (static_cast<long long>(high) - low) / 2);

		// Elements <= mid go to the left subtree, the rest to the right.
		auto lessThanMid = [mid](int v) { return v <= mid; };

		for (int* it = from; it != to; ++it)
			freq.push_back(freq.back() + lessThanMid(*it));

		// Stable, so each child sees its elements in the original relative order.
		int* pivot = std::stable_partition(from, to, lessThanMid);

		l = new wavelet_tree(from, pivot, low, mid);
		r = new wavelet_tree(pivot, to, mid + 1, high);
	} 

	// Releases both subtrees (recursively); deleting nullptr is a no-op.
	~wavelet_tree()
	{
		delete l;
		delete r;
	}

	// Children are owned through raw pointers, so a shallow copy would cause
	// a double delete; copying is therefore disabled.
	wavelet_tree(const wavelet_tree&) = delete;
	wavelet_tree& operator=(const wavelet_tree&) = delete;

	// Counts positions i in [l, r] (1-based, inclusive) whose value is <= k.
	// Requires 1 <= l and r <= element count, unless l > r (empty -> 0).
	int kOrLess(int l, int r, int k) 
	{ 
		// Empty range, or every value in this node is greater than k.
		if (l > r || k < low) 
			return 0; 

		// Every value in this node is <= k.
		if (high <= k) 
			return r - l + 1; 

		// Left-routed elements before position l, and up to position r.
		int LtCount = freq[l - 1]; 
		int RtCount = freq[r]; 

		// Map [l, r] to each child's own positions and add both answers.
		return this->l->kOrLess(LtCount + 1, RtCount, k) + 
			this->r->kOrLess(l - LtCount, r - RtCount, k); 
	} 

	// Counts positions i in [l, r] (1-based, inclusive) whose value is >= k.
	// Requires 1 <= l and r <= element count, unless l > r (empty -> 0).
	int kOrMore(int l, int r, int k) 
	{ 
		// Empty range, or every value in this node is less than k.
		if (l > r || k > high) 
			return 0; 

		// Every value in this node is >= k.
		if (low >= k) 
			return r - l + 1; 

		// Left-routed elements before position l, and up to position r.
		int LtCount = freq[l - 1]; 
		int RtCount = freq[r]; 

		// Map [l, r] to each child's own positions and add both answers.
		return this->l->kOrMore(LtCount + 1, RtCount, k) + 
			this->r->kOrMore(l - LtCount, r - RtCount, k); 
	}
}; 

// Driver code: counts distinct values in subarrays with a wavelet tree.
//
// For each 1-based position i, next[i] is the position of the next
// occurrence of arr[i], or size + 1 if there is none. Inside [l, r] a value
// has exactly one position with next[i] > r (its last occurrence there), so
// the number of distinct values in [l, r] is kOrMore(l, r, r + 1).
int main() 
{ 
	//                             1  2  3  4  5  6  7
	const std::vector<int> arr = {1, 2, 3, 2, 4, 3, 1};
	const int size = static_cast<int>(arr.size());

	// A std::vector instead of `int next[size]`: a runtime-sized array is a
	// compiler extension, not standard C++.
	std::vector<int> next(size);
	std::map<int, int> next_idx;
	int high = INT_MIN;

	for (int i = size - 1; i >= 0; i--) {
		auto found = next_idx.find(arr[i]);
		next[i] = (found == next_idx.end()) ? size + 1 : found->second;
		next_idx[arr[i]] = i + 1;
		high = std::max(high, next[i]);
	}

	// Note: construction reorders `next` in place.
	wavelet_tree obj(next.data(), next.data() + size, 1, high);

	// Queries are 1-based and inclusive.
	//
	//  1  2  3  4  5  6  7
	// {1, 2, 3, 2, 4, 3, 1}
	// distinct in [3, 6] = 3
	std::cout << obj.kOrMore(3, 6, 7) << '\n';
	// distinct in [1, 4] = 3
	std::cout << obj.kOrMore(1, 4, 5) << '\n';
	// distinct in [1, 7] = 4
	std::cout << obj.kOrMore(1, 7, 8) << '\n';

	return 0; 
} 
