fork download
#include <algorithm>
#include <cstdio>
#include <vector>
  3.  
  4. #define dep(i,j,k) for (int (i)=(j);(i)>=(k);--(i))
  5. #define rep(i,j,k) for (int (i)=(j);(i)<=(k);++(i))
  6.  
  7. using namespace std;
  8.  
  9. typedef long long ll;
  10.  
  11. const int C=10,L=20,N=(int)1e5;
  12.  
  13. vector<int> leaf;
  14. int n,c,h,mm,lim,cnt=1;
  15. int trie[N*20+10][C+10],fa[N*20+10][L+10];
  16. int a[N+10],son[N+10],degree[N+10],ed[N*2+10],nxt[N*2+10];
  17. int b[N*20+10],d[N*20+10],lg[N*20+10],sa[N*20+10],sum[N*20+10],tsa[N*20+10],rank[N*20+10],trank[N*20+10],height[N*20+10];
  18.  
  19. inline int getint(){
  20. int x=0;
  21. bool flag=false;
  22. char ch=getchar();
  23. while (!(ch>='0' && ch<='9' || ch=='-')) ch=getchar();
  24. if (ch=='-') flag=true,ch=getchar();
  25. while (ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();
  26. return flag?-x:x;
  27. }
  28.  
  29. inline void addedge(int x,int y){
  30. nxt[++mm]=son[x]; son[x]=mm; ed[mm]=y;
  31. nxt[++mm]=son[y]; son[y]=mm; ed[mm]=x;
  32. }
  33.  
  34. void dfs(int x,int pre,int u){
  35. if (!trie[u][a[x]]) trie[u][a[x]]=++cnt,b[cnt]=a[x];
  36. u=trie[u][a[x]];
  37. for (int i=son[x];i;i=nxt[i])
  38. if (ed[i]!=pre) dfs(ed[i],x,u);
  39. }
  40.  
  41. void dfs2(int x){
  42. rep(i,1,L) fa[x][i]=fa[fa[x][i-1]][i-1];
  43. rep(i,1,c) if (trie[x][i]){
  44. d[trie[x][i]]=d[x]+1;
  45. fa[trie[x][i]][0]=x;
  46. dfs2(trie[x][i]);
  47. }
  48. }
  49.  
  50. inline int jump(int x,int y){
  51. for (;y;y-=y&-y) x=fa[x][lg[y&-y]];
  52. return x;
  53. }
  54.  
  55. void radix(int j){
  56. rep(i,0,lim) sum[i]=0;
  57. rep(i,1,cnt) sum[rank[fa[i][j]]]++;
  58. rep(i,1,lim) sum[i]+=sum[i-1];
  59. dep(i,cnt,1) tsa[sum[rank[fa[i][j]]]--]=i;
  60. rep(i,0,lim) sum[i]=0;
  61. rep(i,1,cnt) sum[rank[tsa[i]]]++;
  62. rep(i,1,lim) sum[i]+=sum[i-1];
  63. dep(i,cnt,1) sa[sum[rank[tsa[i]]]--]=tsa[i];
  64. }
  65.  
// Suffix array over the trie. Trie node u stands for the string read from u
// upwards: b[u], b[fa[u][0]], b[fa[fa[u][0]][0]], ... The root (node 1) has
// b[1]==0, which sorts below every real character (1..c). Prefix doubling
// leaves sa[r] = node with rank r and rank[u] = 1-based rank of node u.
void getsa(){
    // Round-0 key: the node's own character.
    rep(i,1,cnt) trank[i]=b[i];
    // Counting sort by that character. Assumes sum[] is still all zero here,
    // which holds for the single call from main (radix() only runs later).
    rep(i,1,cnt) sum[trank[i]]++;
    rep(i,1,c) sum[i]+=sum[i-1];
    dep(i,cnt,1) sa[sum[trank[i]]--]=i;
    // Compress to dense ranks 1..lim; equal characters share a rank.
    rank[sa[1]]=lim=1;
    for (int i=2;i<=cnt;++i){
        if (trank[sa[i]]!=trank[sa[i-1]]) lim++;
        rank[sa[i]]=lim;
    }
    // Round j: order by (rank of first 2^j chars, rank of next 2^j chars).
    // The second key is the rank of the node's 2^j-th ancestor. rank[0] is
    // never written, so it stays 0 and "past the root" sorts first.
    // Stop once 2^j exceeds h (deepest node) or cnt.
    for (int j=0;(1<<j)<=cnt && (1<<j)<=h;++j){
        radix(j);
        // Re-rank in sa order: a new rank starts whenever either key differs.
        trank[sa[1]]=lim=1;
        for (int i=2;i<=cnt;++i){
            if (rank[sa[i]]!=rank[sa[i-1]] || rank[fa[sa[i]][j]]!=rank[fa[sa[i-1]][j]]) lim++;
            trank[sa[i]]=lim;
        }
        rep(i,1,cnt) rank[i]=trank[i];
    }
}
  86.  
// Fills height[rank[x]] for x and, recursively, for every trie node below it.
// height[r] is the length of the common prefix of the upward strings of sa[r]
// and sa[r-1]; root characters are not compared. Returns the value computed
// for x; the root returns 0 and is skipped.
//
// Children are handled first because of the bound behind Kasai's LCP
// algorithm, carried over to the trie: if a child of x shares k>0 characters
// with its sa predecessor, x shares at least k-1 with its own predecessor.
// So matching for x may start y = (largest child value - 1) characters in.
// Recursion depth equals the trie depth.
int getheight(int x){
    int y=0;
    rep(i,1,c) if (trie[x][i]) y=max(y,getheight(trie[x][i]));
    if (y) y--;
    // The root sorts first (b[1]==0), so it has no predecessor in sa.
    if (x==1) return 0;
    // Skip the y characters already known to match on both sides ...
    int u=jump(x,y),v=jump(sa[rank[x]-1],y);
    // ... then extend one character at a time until a mismatch or a walk reaches the root.
    while (u>1 && v>1 && b[u]==b[v]) y++,u=fa[u][0],v=fa[v][0];
    height[rank[x]]=y;
    return y;
}
  97.  
  98. int main(){
  99. n=getint(); c=getint();
  100. rep(i,1,n) a[i]=getint(),a[i]++;
  101. rep(i,2,n){
  102. int x=getint(),y=getint();
  103. addedge(x,y);
  104. degree[x]++;
  105. degree[y]++;
  106. }
  107. rep(i,1,n) if (degree[i]==1) leaf.push_back(i);
  108. for (vector<int>::iterator i=leaf.begin();i!=leaf.end();++i) dfs(*i,0,1);
  109. rep(i,2,cnt) lg[i]=lg[i-1]+(1<<lg[i-1]+1==i);
  110. dfs2(1);
  111. rep(i,1,cnt) h=max(h,d[i]);
  112. getsa();
  113. getheight(1);
  114. ll ans=0;
  115. rep(i,1,cnt) ans+=d[i];
  116. rep(i,1,cnt) ans-=height[i];
  117. printf("%lld\n",ans);
  118. return 0;
  119. }
Success #stdin #stdout 0s 466880KB
stdin
7 3
0 2 1 2 1 0 0
1 2
3 4
3 5
4 6
5 7
2 5
stdout
30