fork download
  1. #include <vector>
  2. #include <iostream>
  3. #include <algorithm>
  4. using namespace std;
  5. int N, D, E, A[20]; vector<pair<long long, long long> > v1[3]; vector<long long> seg[3][1111111];
  6. int query(int s, int a, int b, long long x, int k, int l, int r) {
  7. if(r <= a || b <= l) return 0;
  8. if(a <= l && r <= b) return upper_bound(seg[s][k].begin(), seg[s][k].end(), x + D) - lower_bound(seg[s][k].begin(), seg[s][k].end(), x - D);
  9. int lc = query(s, a, b, x, 2 * k, l, (l + r) / 2);
  10. int rc = query(s, a, b, x, 2 * k + 1, (l + r) / 2, r);
  11. return lc + rc;
  12. }
  13. void rec(int pos, long long x1, long long x2, int xe) {
  14. if(pos == N) { v1[xe].push_back(make_pair(x1, x2)); return; }
  15. rec(pos + 1, x1 + A[pos], x2 + A[pos], xe);
  16. rec(pos + 1, x1 - A[pos], x2, xe);
  17. rec(pos + 1, x1, x2 - A[pos], xe);
  18. if(xe < E) rec(pos + 1, x1, x2, xe + 1);
  19. }
  20. long long solve(int pos, long long x1, long long x2, int xe) {
  21. long long ret = 0;
  22. if(pos == N / 2) {
  23. for(int i = 0; i <= E - xe; i++) {
  24. int pl = lower_bound(v1[i].begin(), v1[i].end(), make_pair(-x1 - D, -1LL << 60)) - v1[i].begin();
  25. int pr = lower_bound(v1[i].begin(), v1[i].end(), make_pair(-x1 + D + 1, -1LL << 60)) - v1[i].begin();
  26. ret += query(i, pl, pr, -x2, 1, 0, 524288);
  27. }
  28. return ret;
  29. }
  30. ret += solve(pos + 1, x1 + A[pos], x2 + A[pos], xe);
  31. ret += solve(pos + 1, x1 - A[pos], x2, xe);
  32. ret += solve(pos + 1, x1, x2 - A[pos], xe);
  33. if(xe < E) ret += solve(pos + 1, x1, x2, xe + 1);
  34. return ret;
  35. }
  36. int main() {
  37. scanf("%d%d%d", &N, &D, &E);
  38. for(int i = 0; i < N; i++) scanf("%d", &A[i]);
  39. rec(N / 2, 0, 0, 0);
  40. for(int i = 0; i <= E; i++) sort(v1[i].begin(), v1[i].end());
  41. for(int i = 0; i <= E; i++) {
  42. for(int j = 0; j <= v1[i].size(); j++) {
  43. int k = j + 524288;
  44. while(k > 0) seg[i][k].push_back(v1[i][j].second), k >>= 1;
  45. }
  46. for(int j = 1; j < 524288; j++) {
  47. if(seg[i][j].size()) sort(seg[i][j].begin(), seg[i][j].end());
  48. }
  49. }
  50. printf("%lld\n", solve(0, 0, 0, 0));
  51. }
Success #stdin #stdout 2.46s 42488KB
stdin
20 1000 2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
stdout
100341906651