#include <iostream>
#include <string>
#include <vector>

using namespace std;

// Forward declaration: the node classes below name Visitor& in their
// accept() declarations before the Visitor class itself is defined.
class Visitor;

/// Root of the visitable node hierarchy.
///
/// Every concrete node type derives from Node and supplies its own accept()
/// so that a Visitor is handed the node under that node's own static type.
class Node {
public:
  /// Virtual so a derived node can be destroyed through a Node pointer.
  virtual ~Node() = default;

  /// Hands this node to the given visitor; each derived node type overrides it.
  virtual void accept(Visitor&);

  /// Pointers to child nodes. Nothing in the visible code fills or walks this
  /// list, and ownership of the pointees is not established here.
  std::vector<Node*> children;
};

class ANode : public Node
{
public:
  virtual void accept(Visitor&);

};
class BNode : public Node
{
public:
  virtual void accept(Visitor&);
};
class CNode : public Node
{
public:
  virtual void accept(Visitor&);
};

//-- Second inheritance level: a node type derived from BNode

class SubBNode: public BNode
{
public:
  virtual void accept(Visitor&);
};

//--

/// Base visitor offering one visit() overload per node type.
///
/// Derived visitors override the overloads they are interested in. The
/// overload that runs is chosen by the static type of the argument, which
/// each node's accept() supplies by passing `*this`.
class Visitor 
{
  public:
  /// Visitor is a polymorphic base class, so its destructor is virtual:
  /// a derived visitor can then be destroyed safely through a Visitor pointer.
  virtual ~Visitor() = default;

  virtual void visit(Node& n);
  virtual void visit(ANode& n);
  virtual void visit(BNode& n);
  virtual void visit(CNode& n);
  virtual void visit(SubBNode& n);
};


//---------------------------------------------

/// Logs its own signature, then passes this object to the visitor typed as
/// Node, which selects the visit(Node&) overload.
void Node::accept(Visitor& v)
{
  std::cout << __PRETTY_FUNCTION__ << std::endl;
  Node& self = *this;  // static type here decides which visit() overload runs
  v.visit(self);
}
/// Logs its own signature, then passes this object to the visitor typed as
/// ANode, which selects the visit(ANode&) overload.
void ANode::accept(Visitor& v)
{
  std::cout << __PRETTY_FUNCTION__ << std::endl;
  ANode& self = *this;  // static type here decides which visit() overload runs
  v.visit(self);
}
/// Logs its own signature, then passes this object to the visitor typed as
/// BNode, which selects the visit(BNode&) overload.
void BNode::accept(Visitor& v)
{
  std::cout << __PRETTY_FUNCTION__ << std::endl;
  BNode& self = *this;  // static type here decides which visit() overload runs
  v.visit(self);
}
/// Logs its own signature, then passes this object to the visitor typed as
/// CNode, which selects the visit(CNode&) overload.
void CNode::accept(Visitor& v)
{
  std::cout << __PRETTY_FUNCTION__ << std::endl;
  CNode& self = *this;  // static type here decides which visit() overload runs
  v.visit(self);
}
/// Logs its own signature, then passes this object to the visitor typed as
/// SubBNode, which selects the visit(SubBNode&) overload.
void SubBNode::accept(Visitor& v)
{
  std::cout << __PRETTY_FUNCTION__ << std::endl;
  SubBNode& self = *this;  // static type here decides which visit() overload runs
  v.visit(self);
}
// -----
/// Default handling for a plain Node: prints this function's signature
/// followed by a "DEFAULT" tag. The node itself is not inspected.
void Visitor::visit(Node& n)
{
  (void)n;  // intentionally unused
  std::cout << __PRETTY_FUNCTION__ << "\t\tDEFAULT" << std::endl;
}
/// Default handling for an ANode: prints this function's signature
/// followed by a "DEFAULT" tag. The node itself is not inspected.
void Visitor::visit(ANode& n)
{
  (void)n;  // intentionally unused
  std::cout << __PRETTY_FUNCTION__ << "\t\tDEFAULT" << std::endl;
}
/// Default handling for a BNode: prints this function's signature
/// followed by a "DEFAULT" tag. The node itself is not inspected.
void Visitor::visit(BNode& n)
{
  (void)n;  // intentionally unused
  std::cout << __PRETTY_FUNCTION__ << "\t\tDEFAULT" << std::endl;
}
/// Default handling for a CNode: prints this function's signature
/// followed by a "DEFAULT" tag. The node itself is not inspected.
void Visitor::visit(CNode& n)
{
  (void)n;  // intentionally unused
  std::cout << __PRETTY_FUNCTION__ << "\t\tDEFAULT" << std::endl;
}
/// Default handling for a SubBNode: prints this function's signature
/// followed by a "DEFAULT" tag. The node itself is not inspected.
void Visitor::visit(SubBNode& n)
{
  (void)n;  // intentionally unused
  std::cout << __PRETTY_FUNCTION__ << "\t\tDEFAULT" << std::endl;
}

template <typename F>
class FunctorVisitor : public Visitor
{
public:
    explicit FunctorVisitor(F& f) : f(f) {}

    virtual void visit(Node& n) override { f(n);}
    virtual void visit(ANode& n) override { f(n);}
    virtual void visit(BNode& n) override { f(n);}
    virtual void visit(CNode& n) override { f(n);}
    virtual void visit(SubBNode& n) override { f(n);}
private:
    F& f;
};


/// Callable that counts how often it is invoked with a BNode.
///
/// It has two call operators: a logging-only fallback taking any Node, and a
/// counting one taking BNode. Overload resolution picks the BNode operator
/// for BNode and for types derived from it (such as SubBNode).
class CountVisitor
{
public:
    /// Number of calls made to the BNode operator so far.
    int count = 0;

    /// Fallback for nodes that are not BNodes: logs the call, changes nothing.
    void operator() (const Node& n) const {
        (void)n;  // intentionally unused
        std::cout << __PRETTY_FUNCTION__ << "\t\tDefault" << std::endl;
    }

    /// BNode case: bumps the counter, then logs the call.
    void operator() (const BNode& n) {
        (void)n;  // intentionally unused
        ++count;
        std::cout << __PRETTY_FUNCTION__ << "\t\tSPECIAL" << std::endl;
    }

    /// Prints the current count to standard output.
    void print() const {
        std::cout << "CountVisitor Found Bs: " << count << std::endl;
    }
};

// ====================================================

int main() {
    cout << "======FLAT TEST======" << endl;
  
    Node n;
    ANode a;
    BNode b1;
    CNode c;
    BNode b2;
    SubBNode subB;

    vector<Node*> nodes = { &n, &a, &b1, &c, &b2, &subB };

    cout << "--DEFAULT--" << endl;
    Visitor v1;
    for( Node* n : nodes ){
        n->accept(v1);
    }

    cout << "--COUNT--" << endl;
    CountVisitor cv1;
    FunctorVisitor<CountVisitor> d(cv1);
    for (Node* n : nodes){
        n->accept(d);
    }
    cv1.print();
}
