fork download
  1. #include <bits/stdc++.h>
  2.  
  // ---------------------------------------------------------------------
  // Shorthand macros. Other code in this file relies on them, so they are
  // kept exactly as written.
  // NOTE: `x` and `y` rewrite EVERY identifier spelled x / y in the file
  // (including parameter names) to `first` / `second`.
  // ---------------------------------------------------------------------
  #define ld long double
  #define ll long long
  #define pb push_back
  #define x first
  #define y second
  #define all(x) x.begin(),x.end()
  #define sz(x) (int)(x.size())

  using namespace std;

  // Capacity used to pre-size the per-vertex / per-colour containers below.
  const int MAXN=200000;

  // The two integers main() reads first; main() uses n as the vertex count
  // and k as the colour count.
  int n;
  int k;

  // Sparse table over the Euler tour: lca[j][i] is the shallowest (by lvl)
  // tour entry among positions i .. i+2^j-1. Filled by main(), read by getlca().
  int lca[20][500000];

  // Adjacency lists of the tree, 0-based vertices.
  vector<vector<int>> g(MAXN);

  // unvis[c]: vertices of colour c that have not yet been consumed by main().
  vector<set<int>> unvis(MAXN);

  // Euler tour of the tree as produced by makeeuler().
  vector<int> eul;
  // lvl[v]: depth of v; in[v]: index of v's first appearance in eul;
  // par[v]: parent of v (whatever was passed as `pr` for the DFS root).
  vector<int> lvl(MAXN);
  vector<int> in(MAXN);
  vector<int> par(MAXN);

  // color[v]: 0-based colour of vertex v; colorLCA[c]: LCA of all vertices
  // of colour c. Both are appended to by main().
  vector<int> color;
  vector<int> colorLCA;

  // dsu / sz: union-find over colours (parent links and component sizes).
  // dsu2 / highestV: union-find over vertices; highestV[root] is the member
  // of that component with the smallest lvl.
  vector<int> dsu;
  vector<int> dsu2;
  vector<int> sz;
  vector<int> highestV;

  // doneColors[root of a colour component]: how many colours of that
  // component have had their whole vertex set consumed.
  vector<int> doneColors(MAXN);

  // Per-vertex state used by main(): 0 = untouched, 1 = queued, 2 = processed.
  vector<int> vis(MAXN);
  21.  
  22. int getp(int x)
  23. {
  24. if(x==dsu[x]) return x;
  25. return dsu[x]=getp(dsu[x]);
  26. }
  27.  
  28. void connect(int u, int v)
  29. {
  30. u=getp(u), v=getp(v);
  31. if(u==v) return;
  32. if(sz[v]>sz[u]) swap(u,v);
  33. dsu[v]=u;
  34. sz[u]+=sz[v];
  35. doneColors[u]+=doneColors[v];
  36. }
  37.  
  38. int getp2(int x)
  39. {
  40. if(x==dsu2[x]) return x;
  41. return dsu2[x]=getp2(dsu2[x]);
  42. }
  43. void connect2(int u, int v)
  44. {
  45. u=getp2(u), v=getp2(v);
  46. if(u==v) return;
  47. if(rand()&1) swap(u,v);
  48. dsu2[v]=u;
  49. if(lvl[highestV[u]]>lvl[highestV[v]]) highestV[u]=highestV[v];
  50. }
  51. void makeeuler(int u, int pr, int d)
  52. {
  53. par[u]=pr;
  54. in[u]=sz(eul);
  55. lvl[u]=d;
  56. eul.pb(u);
  57. for(int v:g[u])
  58. {
  59. if(v==pr) continue;
  60. makeeuler(v,u,d+1);
  61. eul.pb(u);
  62. }
  63. }
  64.  
  65. int getlca(int u, int v)
  66. {
  67. int l=in[u], r=in[v];
  68. if(l>r) swap(l,r);
  69. int lg=log2(r-l+1);
  70. return (lvl[lca[lg][l]]<=lvl[lca[lg][r-(1<<lg)+1]]?lca[lg][l]:lca[lg][r-(1<<lg)+1]);
  71. }
  72.  
  73. int main()
  74. {
  75. ios::sync_with_stdio(0); cin.tie(0);
  76. cin>>n>>k;
  77. for(int i=0;i<k;i++) dsu.pb(i), sz.pb(1);
  78. for(int i=0;i<n-1;i++)
  79. {
  80. int u,v; cin>>u>>v;
  81. u--;v--;
  82. g[u].pb(v);
  83. g[v].pb(u);
  84. }
  85. makeeuler(0,-1,0);
  86. for(int i=0;i<sz(eul);i++) lca[0][i]=eul[i];
  87. for(int i=1;i<20;i++)
  88. for(int j=0;j+(1<<i)<=sz(eul);j++)
  89. lca[i][j] = (lvl[lca[i-1][j]]<=lvl[lca[i-1][j+(1<<(i-1))]]?lca[i-1][j]:lca[i-1][j+(1<<(i-1))]);
  90. for(int i=0;i<n;i++)
  91. {
  92. dsu2.pb(i), highestV.pb(i);
  93. int c; cin>>c; c--;
  94. color.pb(c);
  95. unvis[c].insert(i);
  96. }
  97. int ans=1e9;
  98. set<pair<int,int>> q;
  99. for(int i=0;i<k;i++)
  100. {
  101. int lc=*unvis[i].begin();
  102. for(int v:unvis[i])
  103. lc=getlca(lc,v);
  104. colorLCA.pb(lc);
  105. q.insert({-lvl[lc],lc});
  106. }
  107. for(auto p:q)
  108. {
  109. int candidLCA=p.y;
  110. if(candidLCA!=colorLCA[color[candidLCA]]) continue;
  111. queue<int> qu;
  112. for(int v:unvis[color[candidLCA]])
  113. {
  114. qu.push(v);
  115. if(v!=candidLCA) vis[v]=1;
  116. }
  117. unvis[color[candidLCA]].clear();
  118. doneColors[getp(color[candidLCA])]++;
  119. while(!qu.empty())
  120. {
  121. int v=qu.front(); qu.pop();
  122. while(v!=candidLCA)
  123. {
  124. if(vis[v]==2) break;
  125. connect(color[candidLCA],color[v]);
  126. if(lvl[colorLCA[color[v]]]>=lvl[candidLCA])
  127. {
  128. int bef=sz(unvis[color[v]]);
  129. unvis[color[v]].erase(v);
  130. if(sz(unvis[color[v]])==0 && bef==1) doneColors[getp(color[v])]++;
  131. vis[v]=2;
  132. auto it=unvis[color[v]].begin();
  133. while(it!=unvis[color[v]].end())
  134. {
  135. int t=*it;
  136. if(vis[t])
  137. {
  138. it++;
  139. continue;
  140. }
  141. vis[t]=1;
  142. qu.push(t);
  143. it=unvis[color[t]].erase(it);
  144. if(sz(unvis[color[t]])==0) doneColors[getp(color[t])]++;
  145. }
  146. }
  147. else break;
  148. connect2(v,par[v]);
  149. v=highestV[getp2(par[v])];
  150. if(vis[v]==1) break;
  151. }
  152. }
  153. int rep=getp(color[candidLCA]);
  154. if(doneColors[rep]==sz[rep]) ans=min(ans,sz[rep]-1);
  155. }
  156. cout<<ans<<'\n';
  157. return 0;
  158. }
  159.  
Success #stdin #stdout 0.01s 20932KB
stdin
Standard input is empty
stdout
1000000000