fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3. typedef long long lint;
  4. typedef pair<int, int> pi;
  5.  
  6. int n, k;
  7. vector<vector<int> > v;
  8. vector<lint> d;
  9.  
  10. vector<lint> solve(int c){
  11. lint tot = 0;
  12. for(int i=0; i<n; i++) tot += v[i][0];
  13.  
  14. auto cmp = [&](const vector<pi> &p, const vector<pi> &q){
  15. lint r1 = 0, r2 = 0;
  16. for(auto &i : p){
  17. r1 += v[i.first][i.second] - v[i.first][0];
  18. }
  19. for(auto &i : q){
  20. r2 += v[i.first][i.second] - v[i.first][0];
  21. }
  22. return r1 > r2;
  23. };
  24. priority_queue<vector<pi>, vector<vector<pi>>, decltype(cmp)> pq(cmp);
  25. set<vector<pi>> vis;
  26. vector<lint> dap;
  27.  
  28. vector<pi> w;
  29. for(int j=0; j<c; j++) w.push_back(pi(j, 1));
  30. pq.push(w);
  31.  
  32. while(!pq.empty() && dap.size() < k){
  33. auto cur = pq.top();
  34. pq.pop();
  35. if(vis.find(cur) != vis.end()) continue;
  36. vis.insert(cur);
  37. lint ret = tot;
  38. for(auto &i : cur){
  39. ret += v[i.first][i.second] - v[i.first][0];
  40. }
  41. if(d.size() + dap.size() >= k && ret >= d.back()) continue;
  42. dap.push_back(ret);
  43. for(int i=0; i<c; i++){
  44. auto nxt = cur;
  45. if(nxt[i].second + 1 < v[nxt[i].first].size()){
  46. nxt[i].second++;
  47. pq.push(nxt);
  48. nxt[i].second--;
  49. }
  50. if(nxt[i].second == 1 && i < c-1 && nxt[i].first + 1 < nxt[i+1].first){
  51. nxt[i].first++;
  52. pq.push(nxt);
  53. nxt[i].first--;
  54. }
  55. if(nxt[i].second == 1 && i == c-1 && nxt[i].first + 1 < n){
  56. nxt[i].first++;
  57. pq.push(nxt);
  58. nxt[i].first--;
  59. }
  60. }
  61. }
  62. return dap;
  63. }
  64.  
  65. int main(){
  66. freopen("roboherd.in", "r", stdin);
  67. freopen("roboherd.out", "w", stdout);
  68. cin >> n >> k;
  69. lint ret = 0;
  70. d.push_back(0);
  71. for(int i=0; i<n; i++){
  72. int x;
  73. scanf("%d",&x);
  74. vector<int> w(x);
  75. for(int j=0; j<x; j++){
  76. scanf("%d",&w[j]);
  77. }
  78. if(x == 1){
  79. ret += 1ll * w[0] * k;
  80. }
  81. else{
  82. sort(w.begin(), w.end());
  83. v.push_back(w);
  84. d[0] += w[0];
  85. }
  86. }
  87. sort(v.begin(), v.end(), [&](const vector<int> &a, const vector<int> &b){
  88. return a[1] - a[0] < b[1] - b[0];
  89. });
  90. n = v.size();
  91. for(int i=1; i<=16 && i<=n; i++){
  92. vector<lint> w = solve(i);
  93. for(auto &j : w) d.push_back(j);
  94. sort(d.begin(), d.end());
  95. while(d.size() > k) d.pop_back();
  96. }
  97. for(int i=0; i<k; i++) ret += d[i];
  98. cout << ret;
  99. }
  100.  
Success #stdin #stdout 0s 3480KB
stdin
Standard input is empty
stdout
Standard output is empty