import java.util.List;
import java.util.LinkedList;
import java.util.Arrays;
import java.util.PriorityQueue;

/**
 * A taxicab number is one that can be represented as a sum of two cubes
 * in two different ways: a^3 + b^3 = c^3 + d^3 = T, where the pairs
 * (a, b) and (c, d) differ not only in the order of their elements.
 *
 * The basic idea is to enumerate all pairs (i, j) for 1 &lt;= i, j &lt;= N
 * and find the sums that occur more than once. An easy way to do this is
 * to store every pair in an array, sort the array by the sum of cubes and
 * traverse it looking for adjacent pairs with equal sums.
 *
 * But we can do a lot better in terms of extra memory used. We can put
 * only the diagonal pairs (1, 1), (2, 2), ..., (N, N) into a PQ at first
 * and then, while removing pairs from it, add the pairs that have to be
 * visited later on the fly.
 *
 * (1, 1) (1, 2) ... (1, N)
 * (2, 1) (2, 2) ... (2, N)
 * ...
 * (N, 1) (N, 2) ... (N, N)
 *
 * Each cell (i, j) holds the value i^3 + j^3. Every row and every column
 * of this matrix is sorted in increasing order. Since (i, j) and (j, i)
 * hold the same value, it is enough to visit the cells on or below the
 * main diagonal: (1, 1), (2, 1), (2, 2), (3, 1), (3, 2), and so on.
 *
 * So after removing the minimum pair from the PQ we add the cell directly
 * below it in the same column. The PQ then holds at most one candidate per
 * column, i.e. at most N elements. As a result, we visit every cell on or
 * below the main diagonal in increasing order of its value.
 */
class TaxicabNumbers {
    /**
     * Finds every taxicab number that is less than {@code n} by building
     * all pairs (i, j) with 1 &lt;= j &lt;= i and i^3 &lt; n, sorting them
     * by their sum of cubes and scanning for adjacent equal sums.
     *
     * <p>Runs in O(N^2 lg N) time using O(N^2) extra space, where N is
     * roughly the cube root of {@code n}.
     *
     * <p>For a number with more than two representations only the first
     * two (ordered by the larger base) are reported.
     *
     * @param n exclusive upper bound for the reported taxicab numbers
     * @return the taxicab numbers below {@code n}, in increasing order
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static List<TaxicabNumber> taxicabNumbersSortVersion(int n) {
        requirePositive(n);
        List<TaxicabNumber> taxicabNumbers = new LinkedList<>();

        int maxBase = largestBaseBelow(n);

        // Only cells on or below the main diagonal (j <= i) are needed,
        // because (i, j) and (j, i) have the same sum:
        // 1 + 2 + ... + maxBase = maxBase * (maxBase + 1) / 2 cells.
        IntPair[] pairs = new IntPair[maxBase * (maxBase + 1) / 2];
        int next = 0;
        for (int i = 1; i <= maxBase; i++) {
            for (int j = 1; j <= i; j++) {
                pairs[next++] = new IntPair(i, j);
            }
        }
        Arrays.sort(pairs);

        // the number of consecutive pairs sharing the current sum
        int runningNumber = 0;
        IntPair prev = null;
        for (IntPair curr : pairs) {
            if (curr.sum >= n) {
                // the array is sorted, so every remaining sum is too large
                break;
            }
            if (prev != null && prev.sum == curr.sum) {
                runningNumber++;
                if (runningNumber == 2) {
                    taxicabNumbers.add(new TaxicabNumber(prev, curr));
                }
            } else {
                runningNumber = 1;
            }
            prev = curr;
        }
        return taxicabNumbers;
    }

    /**
     * Finds every taxicab number that is less than {@code n} by walking the
     * cells on or below the main diagonal of the i^3 + j^3 matrix in
     * increasing order of their value, using a priority queue that holds at
     * most one candidate per column.
     *
     * <p>Runs in O(N^2 lg N) time using O(N) extra space for the queue (not
     * counting the returned list), where N is roughly the cube root of
     * {@code n}.
     *
     * <p>For a number with more than two representations only the first
     * two (ordered by the larger base) are reported.
     *
     * @param n exclusive upper bound for the reported taxicab numbers
     * @return the taxicab numbers below {@code n}, in increasing order
     * @throws IllegalArgumentException if {@code n <= 0}
     */
    public static List<TaxicabNumber> taxicabNumbersHeapVersion(int n) {
        requirePositive(n);
        List<TaxicabNumber> taxicabNumbers = new LinkedList<>();

        int maxBase = largestBaseBelow(n);

        // seed the queue with the diagonal: the smallest cell of each column
        PriorityQueue<IntPair> pairs = new PriorityQueue<>();
        for (int i = 1; i <= maxBase; i++) {
            pairs.offer(new IntPair(i, i));
        }

        // the number of consecutive pairs sharing the current sum
        int runningNumber = 0;
        IntPair prev = null;
        while (!pairs.isEmpty()) {
            IntPair curr = pairs.poll();
            if (curr.sum >= n) {
                // cells are polled in increasing order of sum, so nothing
                // still waiting in the queue can be below n either
                break;
            }
            if (prev != null && prev.sum == curr.sum) {
                runningNumber++;
                if (runningNumber == 2) {
                    taxicabNumbers.add(new TaxicabNumber(prev, curr));
                }
            } else {
                runningNumber = 1;
            }
            // replace the polled cell with the next one down its column
            if (curr.i < maxBase) {
                pairs.offer(new IntPair(curr.i + 1, curr.j));
            }
            prev = curr;
        }
        return taxicabNumbers;
    }

    /** Rejects non-positive limits; shared by both search methods. */
    private static void requirePositive(int n) {
        if (n <= 0) {
            throw new IllegalArgumentException(String.format(
                "Expected: n > 0. Got: %d.", n));
        }
    }

    /**
     * Returns the largest k &gt;= 0 such that k^3 &lt; n. Any pair of
     * positive integers (i, j) with i^3 + j^3 &lt; n has i, j &lt;= k,
     * so k bounds the bases worth enumerating.
     */
    private static int largestBaseBelow(int n) {
        int k = 0;
        while (pow(k + 1, 3) < n) {
            k++;
        }
        return k;
    }

    /** Computes x^n by repeated multiplication; {@code n} must be &gt;= 0. */
    private static long pow(long x, int n) {
        assert n >= 0;

        long pow = 1;
        for (int i = 0; i < n; i++) {
            pow *= x;
        }
        return pow;
    }

    /**
     * A matrix cell (i, j) together with its cached value i^3 + j^3.
     * Both components must be positive.
     */
    private static class IntPair implements Comparable<IntPair> {
        private final long i;
        private final long j;
        private final long sum;

        public IntPair(long i, long j) {
            assert i > 0 && j > 0;

            this.i = i;
            this.j = j;
            sum = i * i * i + j * j * j;
        }

        /**
         * (a, b) &lt; (c, d) iff a^3 + b^3 &lt; c^3 + d^3, or, when the sums
         * are equal, iff a &lt; c (the pair with the smaller first
         * component comes first).
         */
        @Override
        public int compareTo(IntPair other) {
            int sumComparison = Long.compare(this.sum, other.sum);
            if (sumComparison == 0) {
                return Long.compare(this.i, other.i);
            } else {
                return sumComparison;
            }
        }

        @Override
        public String toString() {
            return String.format("%d^3 + %d^3", i, j);
        }
    }

    /** Two distinct representations of the same sum of two cubes. */
    private static class TaxicabNumber {
        private final IntPair representation1;
        private final IntPair representation2;

        public TaxicabNumber(IntPair representation1,
                IntPair representation2) {
            assert representation1.sum == representation2.sum;

            this.representation1 = representation1;
            this.representation2 = representation2;
        }

        @Override
        public String toString() {
            return String.format("%s = %s = %d", representation1,
                representation2, representation1.sum);
        }
    }

    /** Prints the taxicab numbers below 65,000 using both algorithms. */
    public static void main(String[] args) {
        System.out.println(taxicabNumbersSortVersion(65_000));
        System.out.println(taxicabNumbersHeapVersion(65_000));
    }
}
