import java.sql.Connection;
import java.sql.DriverManager;
import java.util.Collections;
import java.util.Map;
import java.util.HashMap;
import java.util.List;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicInteger;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;

import org.aspectj.lang.Signature;

/**
 * Automatically returns (closes) JDBC connections obtained from
 * {@code DriverManager.getConnection(..)}.
 *
 * <p>For every borrowed connection the aspect records the borrowing thread and the
 * method-execution depth at which it was borrowed. Assignments of a tracked connection
 * to instance or static fields are reference-counted.
 *
 * <p>A connection is closed automatically in two cases:
 * <ul>
 *   <li>When, in the borrowing thread, a method exits to a depth shallower than the
 *       borrow depth and no field reference remains. A connection held only in local
 *       variables (or handed back as a return value) is therefore closed when the
 *       borrowing method returns.</li>
 *   <li>When the last field reference is dropped after the borrowing thread has died.</li>
 * </ul>
 *
 * <p>An explicit {@code close()} by application code stops tracking.
 */
public aspect LoanAndReturnAspect {
    /** Method-execution nesting depth of the current thread, maintained by the checkPoint advice. */
    private final ThreadLocal<Integer> depth = new ThreadLocal<Integer>() {
        @Override
        protected Integer initialValue() {
            return 0;
        }
    };
    /** Connections borrowed by the current thread, mapped to the depth at which each was borrowed. */
    private final ThreadLocal<Map<Connection, Integer>> possessedConnection =
        new ThreadLocal<Map<Connection, Integer>>() {
            @Override
            protected Map<Connection, Integer> initialValue() {
                return new HashMap<Connection, Integer>();
            }
        };
    /** Field-reference count per tracked connection; presence of a key means "still tracked". */
    private final Map<Connection, AtomicInteger> refCounts =
        Collections.synchronizedMap(new HashMap<Connection, AtomicInteger>());
    /** Thread that borrowed each tracked connection. */
    private final Map<Connection, Thread> departments =
        Collections.synchronizedMap(new HashMap<Connection, Thread>());

    // Every connection obtained from DriverManager is tracked; the aspect's own calls are excluded.
    pointcut loanConnection() :
        call(Connection DriverManager.getConnection(..)) && !within(LoanAndReturnAspect);
    // Writes to non-static Connection fields; theObj is the object whose field is written.
    pointcut retainConnection(Connection conn, Object theObj) :
        set(!static Connection+ *.*) && target(theObj) && args(conn) && !within(LoanAndReturnAspect);
    // Writes to static Connection fields.
    pointcut retainConnectionStatic(Connection conn) :
        set(static Connection+ *.*) && args(conn) && !within(LoanAndReturnAspect);
    // Every method execution is a check point: on exit, connections borrowed deeper may be returned.
    pointcut checkPoint() :
        execution(* *.*(..)) && !within(LoanAndReturnAspect);
    // Explicit close() calls made by application code.
    pointcut releaseConnection(Connection conn) :
        call(void Connection+.close()) && target(conn) && !within(LoanAndReturnAspect);

    // Record who (thread) borrowed the connection and where (stack depth).
    after() returning(Connection conn) : loanConnection() {
        System.out.printf("Loan a connection: %s%n", conn);
        // Register in refCounts before possessedConnection: tryToReturnConnection treats a
        // possessed connection that is absent from refCounts as stale and discards it.
        refCounts.put(conn, new AtomicInteger(0));
        departments.put(conn, Thread.currentThread());
        possessedConnection.get().put(conn, depth.get());
    }

    // Track nesting depth, log entry/exit, and on exit return connections that have gone out of scope.
    Object around() : checkPoint() {
        Signature sig = thisJoinPoint.getSignature();
        logMethod(sig, true);
        try {
            depth.set(depth.get() + 1);
            return proceed();
        }
        finally {
            depth.set(depth.get() - 1);
            tryToReturnConnection();
            logMethod(sig, false);
        }
    }

    // The application closed the connection itself, so stop tracking it. The cflowbelow guard
    // skips close() calls issued while another Connection.close() is executing, so only the
    // outermost close is handled.
    after(Connection conn) : releaseConnection(conn) && !cflowbelow(execution(void Connection+.close())) {
        System.out.printf("[Manually close] %s%n", conn);
        refCounts.remove(conn);
        Thread department = departments.remove(conn);
        // Another thread's possessedConnection map is not reachable from here. That thread
        // discards its stale entry itself in tryToReturnConnection.
        if (department == Thread.currentThread())
            possessedConnection.get().remove(conn);
    }

    before(Connection newConn, Object theObj) : retainConnection(newConn, theObj) {
        Signature sig = thisJoinPoint.getSignature();
        Connection oldConn = (Connection) getFieldValue(theObj, sig.getDeclaringType(), sig.getName());
        transferReference(oldConn, newConn);
    }

    before(Connection newConn) : retainConnectionStatic(newConn) {
        Signature sig = thisJoinPoint.getSignature();
        Connection oldConn = (Connection) getStaticFieldValue(sig.getDeclaringType(), sig.getName());
        transferReference(oldConn, newConn);
    }

    /**
     * Moves one field reference from {@code oldConn} to {@code newConn} (either may be null
     * or untracked). The new reference is taken BEFORE the old one is dropped. Re-assigning a
     * field to the connection it already holds therefore never lets the count reach zero,
     * which could otherwise auto-close a connection that is still referenced.
     */
    private void transferReference(Connection oldConn, Connection newConn) {
        if (isTracking(newConn))
            retain(newConn);
        if (isTracking(oldConn))
            release(oldConn);
    }

    /** Returns true if {@code conn} is currently tracked (null-safe). */
    private boolean isTracking(Connection conn) {
        return refCounts.containsKey(conn);
    }

    /** Adds one field reference to {@code conn}; no-op if it was untracked concurrently. */
    private void retain(Connection conn) {
        AtomicInteger counter = refCounts.get(conn);
        if (counter == null)
            return;
        int refCount = counter.incrementAndGet();
        System.out.printf("[Retain] %s[%d]%n", conn, refCount);
    }

    /**
     * Drops one field reference to {@code conn}. If none remain and the borrowing thread has
     * died (so its check points can no longer return the connection), it is closed here.
     */
    private void release(Connection conn) {
        AtomicInteger counter = refCounts.get(conn);
        if (counter == null)
            return;  // untracked concurrently (closed manually or returned elsewhere)
        int refCount = counter.decrementAndGet();
        System.out.printf("[Release] %s[%d]%n", conn, refCount);
        Thread department = departments.get(conn);
        if (refCount < 1 && department != null && !department.isAlive()) {
            returnConnection(conn, false);
        }
    }

    /**
     * Called on every method exit in the current thread. Closes connections this thread
     * borrowed at a deeper level than the current depth that no field references anymore.
     * Entries for connections that are no longer tracked (e.g. closed from another thread)
     * are discarded, so they do not accumulate in the thread-local map.
     */
    private void tryToReturnConnection() {
        int stackDepth = depth.get();
        Map<Connection, Integer> connTable = possessedConnection.get();
        List<Connection> stale = new LinkedList<Connection>();
        List<Connection> willReturn = new LinkedList<Connection>();
        for (Map.Entry<Connection, Integer> entry : connTable.entrySet()) {
            Connection conn = entry.getKey();
            // Read the counter once: another thread may untrack the connection at any moment.
            AtomicInteger refCount = refCounts.get(conn);
            if (refCount == null) {
                stale.add(conn);
            }
            else if (entry.getValue() > stackDepth && refCount.get() < 1) {
                willReturn.add(conn);
            }
        }
        connTable.keySet().removeAll(stale);

        // Closed in a second pass so connTable is not modified while being iterated.
        for (Connection conn : willReturn) {
            returnConnection(conn, true);
        }
    }

    /**
     * Closes {@code conn} and stops tracking it.
     *
     * @param inOriginalDepartment true when called from the borrowing thread, whose
     *        possessedConnection entry can then be removed as well
     */
    private void returnConnection(Connection conn, boolean inOriginalDepartment) {
        System.out.printf("[Auto close] %s%n", conn);
        try {
            conn.close();
        }
        catch (java.sql.SQLException e) {
            // Best effort: the connection is dropped from tracking regardless, but report the failure.
            System.out.printf("[Auto close failed] %s: %s%n", conn, e);
        }
        refCounts.remove(conn);
        departments.remove(conn);
        if (inOriginalDepartment) {
            possessedConnection.get().remove(conn);
        }
    }

    /** Prints {@code <thread:sig>} on entry or {@code </thread:sig>} on exit, indented by depth. */
    private void logMethod(Signature sig, boolean isEntering) {
        int stackDepth = depth.get();
        System.out.printf("%s<%s%s:%s>%n",
            computeIndentation(stackDepth),
            isEntering? "" : "/",
            Thread.currentThread().getName(), sig);
    }

    /** Returns two spaces per nesting level. */
    private String computeIndentation(int level) {
        StringBuilder buf = new StringBuilder();
        for (int i = 0; i < level; ++i)
            buf.append("  ");
        return buf.toString();
    }

    /**
     * Reads the current value of a non-static field of {@code obj}. The lookup starts at
     * {@code declaringType} (falling back to the runtime class) and climbs the superclass chain,
     * so inherited non-public fields are found and shadowed fields resolve like Java does.
     *
     * @throws RuntimeException if no matching instance field exists or it cannot be read
     */
    private Object getFieldValue(Object obj, Class<?> declaringType, String name) {
        Class<?> start = (declaringType != null) ? declaringType : obj.getClass();
        Field field = findField(start, name, false);
        if (field == null)
            throw new RuntimeException(String.format("can't find instance field named: %s in class: %s", name, obj.getClass()));
        return readField(field, obj);
    }

    /**
     * Reads the current value of a static field named {@code name}, searching {@code klass}
     * and its superclasses (then public fields, including interface constants).
     *
     * @throws RuntimeException if no matching static field exists or it cannot be read
     */
    private Object getStaticFieldValue(Class<?> klass, String name) {
        Field field = findField(klass, name, true);
        if (field == null)
            throw new RuntimeException(String.format("can't find static field named: %s in class: %s", name, klass));
        return readField(field, null);
    }

    /**
     * Finds the first field called {@code name} whose static-ness equals {@code wantStatic}.
     * Declared fields of {@code start} and each superclass are searched first, then public
     * fields via {@link Class#getField}. Returns null if none matches.
     */
    private Field findField(Class<?> start, String name, boolean wantStatic) {
        for (Class<?> c = start; c != null; c = c.getSuperclass()) {
            try {
                Field field = c.getDeclaredField(name);
                if (Modifier.isStatic(field.getModifiers()) == wantStatic)
                    return field;
            }
            catch (NoSuchFieldException ignored) {
                // Not declared in this class; keep climbing.
            }
        }
        try {
            Field field = start.getField(name);
            if (Modifier.isStatic(field.getModifiers()) == wantStatic)
                return field;
        }
        catch (NoSuchFieldException ignored) {
            // Fall through: not found anywhere.
        }
        return null;
    }

    /** Reads {@code field} from {@code target} (null for static fields), bypassing access checks. */
    private Object readField(Field field, Object target) {
        field.setAccessible(true);
        try {
            return field.get(target);
        }
        catch (IllegalAccessException e) {
            throw new RuntimeException(e);
        }
    }
}
