#include <algorithm>
#include <iostream>
#include <map>
#include <string>
#include <vector>
#include <set>

class Node
{
public:
    // True when some phrase inserted into the trie ends exactly at this node.
    bool endOfSentence = false;
    // Every weight whose phrase path passes through (or ends at) this node.
    std::set<int> weights;
    // Children keyed by the next word; std::map keeps them in ascending key order.
    // NOTE(review): std::map with a still-incomplete value type (Node) is not
    // guaranteed by the standard, though mainstream standard libraries accept it.
    std::map<std::string, Node> children;

    Node() = default;

    // Looks up the child reached through `word`; yields nullptr if there is none.
    const Node* get(const std::string& word) const
    {
        const auto found = children.find(word);
        return found != children.end() ? &found->second : nullptr;
    }

    // Yields the first child (in key order) whose weight set holds `weight`,
    // or children.end() when no child carries that weight.
    auto find_by_weight(int weight) const
        -> std::map<std::string, Node>::const_iterator
    {
        const auto carries_weight =
            [weight](const std::pair<const std::string, Node>& entry) {
                return entry.second.weights.count(weight) != 0;
            };
        return std::find_if(children.begin(), children.end(), carries_weight);
    }

};


class Trie
{
    Node root;
public:

    // Inserts `phrase` word by word. Every node on its path, including the
    // root, is tagged with `weight`, and the final node is marked as a phrase end.
    // Weights are expected to be unique per phrase: print() replays a phrase by
    // following the child that carries the weight. If one weight is reused for
    // two phrases that diverge, print() can reach only one of them.
    void add(int weight, const std::vector<std::string>& phrase)
    {
        Node* node = &root;
        for (const auto& word : phrase) {
            node->weights.insert(weight);
            node = &node->children[word];  // creates the child on first use
        }
        node->weights.insert(weight);
        node->endOfSentence = true;
    }

    // Returns true only if `phrase` was added as a complete phrase. A proper
    // prefix of an added phrase is not reported as contained.
    bool contains(const std::vector<std::string>& phrase) const
    {
        const Node* node = &root;
        for (const auto& word : phrase) {
            node = node->get(word);
            if (node == nullptr) {
                return false;
            }
        }
        return node->endOfSentence;
    }

    // Writes the phrase tagged with `weight` to `os` as space-separated words,
    // followed by a newline. Writes an empty line if no child of the root
    // carries `weight`.
    // `os` defaults to std::cout, so existing callers behave as before.
    void print(int weight, std::ostream& os = std::cout) const
    {
        const Node* node = &root;
        const char* sep = "";
        // `node` always points into the trie; the loop exits only via `break`.
        for (;;) {
            const auto it = node->find_by_weight(weight);
            if (it == node->children.end()) {
                break;  // no child continues this phrase, so it ends here
            }
            os << sep << it->first;
            sep = " ";
            node = &it->second;
        }
        os << '\n';  // newline without forcing a flush per phrase
    }

    // Writes every stored phrase to `os`, one per line, in ascending weight
    // order (root.weights is a std::set).
    void print_all(std::ostream& os = std::cout) const
    {
        for (const int weight : root.weights) {
            print(weight, os);
        }
    }
};

int main(int argc, char* argv[]) {
    // The program takes no command-line options.
    static_cast<void>(argc);
    static_cast<void>(argv);

    const std::vector<std::vector<std::string>> sentences = {
        {"My", "name", "is", "John"},
        {"My", "house", "is", "small"},
        {"Hello", "world"},
        {"Hello", "world", "!"}
    };

    // Store every sentence, using its position in `sentences` as its weight.
    Trie trie;
    int weight = 0;
    for (const auto& sentence : sentences) {
        trie.add(weight++, sentence);
    }

    // "My house" is only a prefix of a stored sentence, so it reports 0.
    // The other two queries are complete stored sentences and report 1.
    const std::vector<std::vector<std::string>> queries = {
        {"My", "name", "is", "John"},
        {"My", "house"},
        {"Hello", "world"}
    };

    for (const auto& query : queries) {
        const bool found = trie.contains(query);
        std::cout << found << std::endl;  // bool streams as 1/0 (no boolalpha)
    }

    // Replay every stored sentence, one per line, in weight order.
    trie.print_all();
}
