fork(3) download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3. #define gc getchar_unlocked
  4. #define fo(i,n) for(i=0;i<n;i++)
  5. #define Fo(i,k,n) for(i=k;k<n?i<n:i>n;k<n?i+=1:i-=1)
  6. #define ll long long
  7. #define si(x) scanf("%d",&x)
  8. #define sl(x) scanf("%lld",&x)
  9. #define ss(s) scanf("%s",s)
  10. #define pi(x) printf("%d\n",x)
  11. #define pl(x) printf("%lld\n",x)
  12. #define ps(s) printf("%s\n",s)
  13. #define pb push_back
  14. #define mp make_pair
  15. #define F first
  16. #define S second
  17. #define all(x) x.begin(), x.end()
  18. #define clr(x) memset(x, 0, sizeof(x))
  19. #define sortall(x) sort(all(x))
  20. #define tr(it, a) for(auto it = a.begin(); it != a.end(); it++)
  21. #define PI 3.1415926535897932384626
  22. typedef pair<int, int> pii;
  23. typedef pair<ll, ll> pl;
  24. typedef vector<int> vi;
  25. typedef vector<ll> vl;
  26. typedef vector<pii> vpii;
  27. typedef vector<pl> vpl;
  28. typedef vector<vi> vvi;
  29. typedef vector<vl> vvl;
  30. int mpow(int base, int exp);
  31. void ipgraph(int m);
  32. void dfs(int u, int par);
  33. const int mod = 1000000007;
  34. const int N = 2e5+1, M = N;
  35. //=======================
  36.  
  37. vi g[N];
  38. int a[N], p[N], v[N];
  39. int sum[N], id[N];
  40. map<int,int> cnt[N];
  41. int ans = 0;
  42.  
  43. int main()
  44. {
  45. ios_base::sync_with_stdio(false);
  46. cin.tie(NULL);
  47. int i,n,k,j;
  48. cin >> n;
  49. Fo(i, 1, n+1){
  50. id[i] = i;
  51. cin >> p[i] >> v[i];
  52. if(p[i]>0)
  53. g[p[i]].pb(i);
  54. }
  55. dfs(1, 0);
  56. cout << ans << endl;
  57. return 0;
  58. }
  59.  
  60. void dfs(int u, int par){
  61. for(int v:g[u]){
  62. if (v == par) continue;
  63. dfs(v, u);
  64. if(cnt[id[v]].size() > cnt[id[u]].size()) swap(id[u], id[v]);
  65. }
  66.  
  67. int &cur = sum[id[u]];
  68. for(int v:g[u]){
  69. if (v == par) continue;
  70. for(auto it: cnt[id[v]]){
  71. int col = it.F, no = it.S;
  72. int &val = cnt[id[u]][col];
  73. cur -= val;
  74. if(cur < 0) cur += mod;
  75. val = ((val+1) *1LL* (no+1))%mod;
  76. val -= 1;
  77. if(val < 0) val += mod;
  78. cur += val;
  79. if(cur >= mod) cur -= mod;
  80. }
  81. }
  82. ll add = cur;
  83. cnt[id[u]][v[u]] += 1+add;
  84. if(cnt[id[u]][v[u]] >= mod) cnt[id[u]][v[u]] -= mod;
  85. ans += 1+add;
  86. if(ans >= mod) ans -= mod;
  87. cur += 1+add;
  88. if(cur >= mod) cur -= mod;
  89. }
Success #stdin #stdout 0s 34032KB
stdin
6
-1 3
1 2
1 2
1 2
1 5
1 5
stdout
16