fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3. using ll = long long;
  4.  
  // A node of a binary trie: child[b] is the index (into BinaryTrie::trie)
  // of the child reached by bit b, or -1 when that child does not exist.
  struct TrieNode {
      int child[2];
      TrieNode() { child[0] = child[1] = -1; }
  };

  // Binary trie over fixed-width k-bit keys, stored as a flat vector of nodes
  // (index 0 is the root).
  //
  // computeF() returns F(k, S) = sum over every y in [0, 2^k) of
  // min_{s in S} (s XOR y), where S is the set of inserted keys.
  struct BinaryTrie {
      std::vector<TrieNode> trie;
      int k;  // bit length of the keys (= trie depth)

      BinaryTrie(int bitLength) : k(bitLength) {
          trie.emplace_back();  // root
      }

      // Inserts the low k bits of x, most significant bit first. Bits at
      // positions >= k are ignored; re-inserting an existing key adds no nodes.
      void insert(int x) {
          int node = 0;
          for (int bit = k - 1; bit >= 0; --bit) {
              const int b = (x >> bit) & 1;
              if (trie[node].child[b] == -1) {
                  // Record the new index before growing the vector.
                  trie[node].child[b] = static_cast<int>(trie.size());
                  trie.emplace_back();
              }
              node = trie[node].child[b];
          }
      }

      // Sums, over the leaves below `node`, the contribution of the y values
      // whose greedy (min-XOR) descent ends at that leaf.
      //   depth      - bit handled at this node (k-1 at the root, -1 at a leaf)
      //   fixedMask  - sum of 2^j over the branching bits j seen so far
      //   fixedCount - number of those branching bits
      // At a branching node, y's bit selects the child, so that XOR bit is 0.
      // At a single-child node, y's bit is free, and the XOR bit is 1 for half
      // of the y values. A leaf therefore collects 2^freeCount values of y,
      // whose XOR sum is 2^(freeCount-1) * (sum of 2^j over the free bits j).
      long long dfs(int node, int depth, long long fixedMask, int fixedCount) const {
          const int left = trie[node].child[0];
          const int right = trie[node].child[1];

          if (left == -1 && right == -1) {
              const int freeCount = k - fixedCount;
              // Every bit is fixed: the only y reaching this leaf equals the
              // leaf key, so its XOR is 0. Returning early also avoids the
              // undefined shift 1LL << -1 in the formula below.
              if (freeCount == 0) return 0;
              const long long freeMask = ((1LL << k) - 1) - fixedMask;
              return (1LL << (freeCount - 1)) * freeMask;
          }

          const bool branching = left != -1 && right != -1;
          const long long childMask = branching ? (fixedMask | (1LL << depth)) : fixedMask;
          const int childCount = fixedCount + (branching ? 1 : 0);

          long long res = 0;
          if (left != -1) res += dfs(left, depth - 1, childMask, childCount);
          if (right != -1) res += dfs(right, depth - 1, childMask, childCount);
          return res;
      }

      // Returns F(k, S). Meaningful only after at least one key was inserted
      // (an empty set has no minimum).
      long long computeF() const { return dfs(0, k - 1, 0, 0); }
  };
  64.  
  // Dyadic decomposition: splits [0, M) into maximal aligned blocks, left to
  // right. Each entry is {base, k} describing [base, base + 2^k), where base
  // is a multiple of 2^k. Returns an empty vector when M <= 0.
  std::vector<std::pair<int, int>> decompose(int M) {
      std::vector<std::pair<int, int>> blocks;
      int cur = 0;
      while (cur < M) {
          // Largest k with 2^k <= M - cur, taken from the top set bit. Doubling
          // a size counter instead overflowed int once M - cur >= 2^30.
          int k = 31 - __builtin_clz(static_cast<unsigned>(M - cur));
          // The block must start on a multiple of 2^k.
          if (cur != 0) k = std::min(k, __builtin_ctz(static_cast<unsigned>(cur)));
          blocks.emplace_back(cur, k);
          cur += 1 << k;  // cur + 2^k <= M, so this cannot overflow
      }
      return blocks;
  }
  79.  
  80. int main() {
  81. ios::sync_with_stdio(false);
  82. cin.tie(nullptr);
  83.  
  84. int N, M;
  85. cin >> N >> M;
  86. vector<int> A(N);
  87. for (int i = 0; i < N; i++) cin >> A[i];
  88.  
  89. ll answer = 0;
  90. auto blocks = decompose(M);
  91.  
  92. for (auto [base, k] : blocks) {
  93. int L = 1 << k;
  94.  
  95. // Step 1: B = A xor base
  96. int hmin = INT_MAX;
  97. for (int a : A) {
  98. int v = a ^ base;
  99. hmin = min(hmin, v >> k);
  100. }
  101.  
  102. // Step 2: B' 집합
  103. unordered_set<int> BpSet;
  104. BpSet.reserve(N*2);
  105. for (int a : A) {
  106. int v = a ^ base;
  107. if ((v >> k) == hmin) {
  108. BpSet.insert(v & (L - 1));
  109. }
  110. }
  111.  
  112. // Step 3: 상위항
  113. answer += (ll)(hmin << k) * L;
  114.  
  115. // Step 4: 하위항 F(k, B')
  116. if (!BpSet.empty()) {
  117. BinaryTrie trie(k);
  118. for (int val : BpSet) trie.insert(val);
  119. answer += trie.computeF();
  120. }
  121. }
  122.  
  123. cout << answer << "\n";
  124. return 0;
  125. }
Success #stdin #stdout 0s 5332KB
stdin
22
atcoderbeginnercontest
stdout
0