fork download
  1. #include <bits/stdc++.h>
  2. #include <ext/pb_ds/assoc_container.hpp>
  3. #include <ext/pb_ds/tree_policy.hpp>
  4.  
  5. using namespace std;
  6. using namespace __gnu_pbds;
  7.  
  // Shorthand macros (textual substitutions used throughout the file).
  #define fi first
  #define se second
  #define mp make_pair
  #define pb push_back
  #define fbo find_by_order
  #define ook order_of_key

  // Short aliases for frequently used types.
  using ll = long long;
  using ii = std::pair<int, int>;
  using vi = std::vector<int>;
  using ld = long double;

  // Ordered set of ll with order statistics (find_by_order / order_of_key),
  // built on the GNU policy-based red-black tree.
  using pbds = __gnu_pbds::tree<ll, __gnu_pbds::null_type, std::less<ll>,
                                __gnu_pbds::rb_tree_tag,
                                __gnu_pbds::tree_order_statistics_node_update>;

  // Per-node and per-colour global state (arrays are zero-initialised).
  std::vector<int> adj[222222];        // adjacency lists of the tree
  int sub[222222];                     // subtree sizes inside the component walked by prep()
  int color[222222];                   // colour of each node
  std::vector<std::set<int> > cities;  // nodes grouped by colour
  std::set<int> usedcolors;            // colours whose colorcnt[] was touched by prep()
  int colorcnt[222222];                // per-colour node count inside the current component
  int ans = 1000000000;                // smallest colour-closure size found so far
  int banned[222222];                  // banned colors (not referenced in the code shown)
  bool cenvisited[222222];             // nodes already used as centroids
  std::set<int> cursub;                // current subtree elements

  32. void prep(int u, int p=-1)
  33. {
  34. sub[u]=1;
  35. colorcnt[color[u]]++;
  36. usedcolors.insert(color[u]);
  37. cursub.insert(u);
  38. for(int v:adj[u])
  39. {
  40. if(v==p) continue;
  41. if(cenvisited[v]) continue;
  42. prep(v,u);
  43. sub[u]+=sub[v];
  44. }
  45. }
  46.  
  47. int centroid(int u, int p=-1, int r=-1)
  48. {
  49. for(int v:adj[u])
  50. {
  51. if(v==p) continue;
  52. if(cenvisited[v]) continue;
  53. if(sub[v]*2>sub[r])
  54. {
  55. return centroid(v,u,r);
  56. }
  57. }
  58. return u;
  59. }
  60.  
  61. bool valid(int c) //is a color valid in this subtree
  62. {
  63. return (int(cities[c].size())==colorcnt[c]);
  64. }
  65.  
  std::set<int> colorset;   // colours waiting to be processed (ordered, smallest first)
  std::set<int> processed;  // colours already taken out of colorset

  // Queue colour c for processing unless it has already been processed.
  // Re-queueing a colour that is already waiting has no effect.
  void push(int c)
  {
      if (processed.count(c) != 0) return;
      colorset.insert(c);
  }

  77. int h[222222];
  78. int par[222222];
  79.  
  80. void calch(int u, int p=-1)
  81. {
  82. for(int v:adj[u])
  83. {
  84. if(v==p) continue;
  85. if(cenvisited[v]) continue;
  86. par[v]=u;
  87. h[v]=h[u]+1;
  88. calch(v,u);
  89. }
  90. }
  91.  
  92.  
  93. struct DSU
  94. {
  95. int S;
  96.  
  97. struct node
  98. {
  99. int p, rt;
  100. };
  101. vector<node> dsu;
  102.  
  103. DSU(int n)
  104. {
  105. S = n;
  106. for(int i = 0; i < n; i++)
  107. {
  108. node tmp;
  109. tmp.p = i; tmp.rt = i;
  110. dsu.pb(tmp);
  111. }
  112. }
  113. void reset(int n)
  114. {
  115. dsu.clear();
  116. S = n;
  117. for(int i = 0; i < n; i++)
  118. {
  119. node tmp;
  120. tmp.p = i; tmp.rt = i;
  121. dsu.pb(tmp);
  122. }
  123. }
  124.  
  125. void resetnode(int u)
  126. {
  127. dsu[u].p = u;
  128. dsu[u].rt = u;
  129. }
  130.  
  131. int rt(int u)
  132. {
  133. if(dsu[u].p == u) return u;
  134. dsu[u].p = rt(dsu[u].p);
  135. return dsu[u].p;
  136. }
  137.  
  138. void merge(int u, int v)
  139. {
  140. u = rt(u); v = rt(v);
  141. if(u == v) return ;
  142. if(rand()&1) swap(u, v);
  143. dsu[v].p = u;
  144. if(h[dsu[v].rt]<h[dsu[u].rt])
  145. {
  146. dsu[u].rt=dsu[v].rt;
  147. }
  148. }
  149.  
  150. bool sameset(int u, int v)
  151. {
  152. if(rt(u) == rt(v)) return true;
  153. return false;
  154. }
  155.  
  156. int getrt(int u)
  157. {
  158. return dsu[rt(u)].rt;
  159. }
  160. };
  161. DSU dsu(1);
  162.  
  163. void solve(int u)
  164. {
  165. prep(u);
  166. int cent = centroid(u,-1,u);
  167. h[cent]=0; par[cent]=-1;
  168. calch(cent);
  169. //solve for this centroid
  170. colorset.insert(color[cent]);
  171. bool pos=1;
  172. while(!colorset.empty())
  173. {
  174. int c = (*colorset.begin()); processed.insert(c); colorset.erase(c);
  175. if(!valid(c)){pos=0; break;} //invalid color found
  176. //let's push all the stuff of this color
  177. for(int u:cities[c])
  178. {
  179. while(u!=cent)
  180. {
  181. u=dsu.getrt(u);
  182. if(u==cent) break;
  183. push(color[par[u]]);
  184. dsu.merge(par[u],u);
  185. u=par[u];
  186. }
  187. }
  188. }
  189. /*
  190. cerr<<"SOLVE WITH CENTROID = "<<cent<<'\n';
  191. cerr<<"POSSIBLE = "<<pos<<'\n';
  192. if(pos)
  193. {
  194. cerr<<"PROCESSED = ";
  195. for(int x:processed) cerr<<x<<' ';
  196. cerr<<'\n';
  197. }
  198. */
  199. for(int x:cursub)
  200. {
  201. dsu.resetnode(x);
  202. }
  203. if(pos)
  204. {
  205. ans=min(ans,int(processed.size()));
  206. }
  207. cenvisited[cent]=1;
  208. for(int x:usedcolors) colorcnt[x]=0;
  209. usedcolors.clear();
  210. colorset.clear();
  211. processed.clear();
  212. cursub.clear();
  213. for(int v:adj[cent])
  214. {
  215. if(cenvisited[v]) continue;
  216. //recurse
  217. solve(v);
  218. }
  219. }
  220.  
  221. int main()
  222. {
  223. ios_base::sync_with_stdio(0); cin.tie(0);
  224. int n,k; cin>>n>>k;
  225. dsu.reset(n);
  226. for(int i=0;i<n-1;i++)
  227. {
  228. int u,v; cin>>u>>v; u--; v--;
  229. adj[u].pb(v); adj[v].pb(u);
  230. }
  231. cities.resize(k);
  232. for(int i=0;i<n;i++)
  233. {
  234. cin>>color[i]; color[i]--;
  235. cities[color[i]].insert(i);
  236. }
  237. solve(0);
  238. cout<<ans-1<<'\n'; //# of cities - 1
  239. }
  240.  
Internal error #stdin #stdout 0s 0KB
stdin
Standard input is empty
stdout
Standard output is empty