//===----------------------------------------------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//

#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Frontend/FrontendActions.h"
#include "clang/Lex/Lexer.h"
#include "clang/Tooling/CompilationDatabase.h"
#include "clang/Tooling/Refactoring.h"
#include "clang/Tooling/Tooling.h"
#include "llvm/ADT/OwningPtr.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Support/system_error.h"

#include <iostream>

using namespace std;

using namespace clang;
using namespace clang::ast_matchers;
using namespace llvm;
using clang::tooling::newFrontendActionFactory;
using clang::tooling::Replacement;
using clang::tooling::CompilationDatabase;

// FIXME: Pull out helper methods in here into more fitting places.

// Returns the text that makes up 'node' in the source.
// Returns an empty string if the text cannot be found.
template <typename T>
static std::string getText(const SourceManager &SourceManager, const T &Node)
{
    SourceLocation StartSpellingLocatino = SourceManager.getSpellingLoc(Node.getLocStart());
    SourceLocation EndSpellingLocation = SourceManager.getSpellingLoc(Node.getLocEnd());
    if(!StartSpellingLocatino.isValid() || !EndSpellingLocation.isValid())
    {
        return std::string();
    }
    bool Invalid = true;
    const char *Text =
        SourceManager.getCharacterData(StartSpellingLocatino, &Invalid);
    if(Invalid)
    {
        return std::string();
    }
    std::pair<FileID, unsigned> Start = SourceManager.getDecomposedLoc(StartSpellingLocatino);
    std::pair<FileID, unsigned> End = SourceManager.getDecomposedLoc(Lexer::getLocForEndOfToken(
                                           EndSpellingLocation, 0, SourceManager, LangOptions()));
    if(Start.first != End.first)
    {
        // Start and end are in different files.
        return std::string();
    }
    if(End.second < Start.second)
    {
        // Shuffling text with macros may cause this.
        return std::string();
    }
    return std::string(Text, End.second - Start.second);
}

namespace
{

void printCallPredicates(const CallExpr *Call, const string &indent = "    ")
{
    string new_indent = indent + "    ";
    cout << indent << "Call" << endl << indent << "(" << endl << new_indent;
    cout << "ArgumentCountIs(" << Call->getNumArgs() << ")";
    if(Call->getDirectCallee())
    {
        cout << ",Callee(Function(HasName(\"" << Call->getDirectCallee()->getQualifiedNameAsString() << "\")))";
        for(unsigned i = 0; i != Call->getNumArgs(); ++i)
        {
            if(const CallExpr *callExpr = dyn_cast<CallExpr>(Call->getArg(i)))
            {
                cout << ",HasArgument(" << i << "," << endl;
                printCallPredicates(callExpr, new_indent);
                cout << ")";
            }
        }
    }
    cout << endl << indent << ")";
}

// Match callback that takes the CallExpr bound to "call" and prints two
// things: the call's source text, then a generated matcher expression (from
// printCallPredicates) describing the call's shape. A blank line follows.
class PrintMatchers : public ast_matchers::MatchFinder::MatchCallback
{
public:
    // NOTE(review): assumes the registered matcher binds a CallExpr under
    // "call"; a null result from getStmtAs is dereferenced without a check.
    virtual void run(const ast_matchers::MatchFinder::MatchResult &Result)
    {
        const CallExpr *MatchedCall = Result.Nodes.getStmtAs<CallExpr>("call");
        const std::string Source = getText(*Result.SourceManager, *MatchedCall);

        cout << "Code: " << Source << endl
             << "Matcher:" << endl;
        printCallPredicates(MatchedCall);
        cout << endl;
        cout << endl;
    }
};


// Match callback that echoes the source text of the CallExpr bound to
// "call", prefixed by a marker and followed by a blank line.
class PrintCall : public ast_matchers::MatchFinder::MatchCallback
{
public:
    // NOTE(review): assumes the registered matcher binds a CallExpr under
    // "call"; a null result from getStmtAs is dereferenced without a check.
    virtual void run(const ast_matchers::MatchFinder::MatchResult &Result)
    {
        const CallExpr *MatchedCall = Result.Nodes.getStmtAs<CallExpr>("call");
        const std::string Source = getText(*Result.SourceManager, *MatchedCall);

        cout << "--------> Matched Call:\t" << Source << endl;
        cout << endl;
    }
};

} // end namespace

// Positional command-line arguments, filled in declaration order.
// BuildPath receives the first positional argument: the directory that main
// loads the compilation database from. SourcePaths receives every remaining
// positional argument: the source files to process. At least one is required
// (cl::OneOrMore).
// Declared static so these file-local options get internal linkage.
static cl::opt<std::string> BuildPath(
    cl::Positional,
    cl::desc("<build-path>"));

static cl::list<std::string> SourcePaths(
    cl::Positional,
    cl::desc("<source0> [... <sourceN>]"),
    cl::OneOrMore);

int main(int argc, char **argv)
{
    cl::ParseCommandLineOptions(argc, argv);
    std::string ErrorMessage;
    llvm::OwningPtr<CompilationDatabase> Compilations(CompilationDatabase::loadFromDirectory(BuildPath, ErrorMessage));
    if(!Compilations)
    {
        llvm::report_fatal_error(ErrorMessage);
    }
    tooling::RefactoringTool Tool(*Compilations, SourcePaths);
    ast_matchers::MatchFinder Finder;

    // Generate and print matchers for calls
    PrintMatchers print_matchers;
    Finder.addMatcher(Id("call", Call(True())), &print_matchers);

    PrintCall matcher_stream_outputs;
    // Autogenerated matcher for: std::cout << 1.0 << "1" << std::endl;
    Finder.addMatcher(Id("call",
        Call
        (
            ArgumentCountIs(2),Callee(Function(HasName("std::basic_ostream<char, std::char_traits<char> >::operator<<"))),HasArgument(0,
            Call
            (
                ArgumentCountIs(2),Callee(Function(HasName("std::operator<<"))),HasArgument(0,
                Call
                (
                    ArgumentCountIs(2),Callee(Function(HasName("std::basic_ostream<char, std::char_traits<char> >::operator<<")))
                ))
            ))
        )
    ), &matcher_stream_outputs);

    PrintCall matcher_c_str;
    // Autogenerated matcher for: std::string("1").c_str();
    Finder.addMatcher(Id("call",
        Call
        (
            ArgumentCountIs(0),Callee(Function(HasName("std::basic_string<char, std::char_traits<char>, std::allocator<char> >::c_str")))
        )
    ), &matcher_c_str);

    /*PrintCall callback;
    Finder.addMatcher(
        Id("call", Call(True())),
        &callback);*/
    return Tool.run(newFrontendActionFactory(&Finder));
}

