fork(1) download
  1. /* package whatever; // don't place package name! */
  2.  
  3. import java.util.*;
  4. import java.lang.*;
  5. import java.io.*;
  6. import java.util.Comparator;
  7. import java.util.PriorityQueue;
  8. import java.util.*;
  9.  
  10. /* Name of the class has to be "Main" only if the class is public. */
  11. class Solution {
  12.  
  13. class State {
  14. int i; // index of x[]
  15. int j; // index of y[]
  16. int sum; // x[i] + y[j]
  17.  
  18. public State(int i, int j, int sum) {
  19. this.i = i;
  20. this.j = j;
  21. this.sum = sum;
  22. }
  23. }
  24.  
  25. public int findKthDistinctSum(int[] x, int[] y, int k) {
  26. if (x.length == 0 || y.length == 0) {
  27. throw new IllegalArgumentException("Can't handle zero-length arrays.");
  28. }
  29.  
  30. // use a min heap to poll the next state that has minimum sum
  31. PriorityQueue<State> heap = new PriorityQueue<>(new Comparator<State>() {
  32. public int compare(State a, State b) {
  33. return a.sum - b.sum;
  34. }
  35. });
  36.  
  37. // use a hash set to avoid duplicate sum
  38. Set<Integer> set = new HashSet<>();
  39.  
  40. // step 1. create initial state
  41. int sum = x[0] + y[0];
  42. heap.offer(new State(0, 0, sum));
  43. set.add(sum);
  44.  
  45. // step 2. generate new states based on current state
  46. // until we get the kth smallest sum
  47. while (k-- > 1) {
  48. State s = heap.poll();
  49.  
  50. // new state 1: x[i], y[j + 1]
  51. if (s.j < y.length - 1) {
  52. sum = x[s.i] + y[s.j + 1];
  53.  
  54. if (!set.contains(sum)) {
  55. heap.offer(new State(s.i, s.j + 1, sum));
  56. set.add(sum);
  57. }
  58. }
  59.  
  60. // new state 2: x[i + 1], y[j]
  61. if (s.i < x.length - 1) {
  62. sum = x[s.i + 1] + y[s.j];
  63.  
  64. if (!set.contains(sum)) {
  65. heap.offer(new State(s.i + 1, s.j, sum));
  66. set.add(sum);
  67. }
  68. }
  69. }
  70.  
  71. return heap.poll().sum;
  72. }
  73.  
  74. }
  75. class Ideone
  76. {
  77. public static void main (String[] args) throws java.lang.Exception
  78. {
  79. int A[] = {1, 2, 4, 6};
  80. int B[] = {2, 3, 6};
  81. int k = 4;
  82. Solution sol = new Solution();
  83. System.out.println(sol.findKthDistinctSum(A, B, k));
  84. System.out.println(sol.findKthDistinctSum(B, A, k));
  85. }
  86. }
Success #stdin #stdout 0.1s 320576KB
stdin
Standard input is empty
stdout
7
6