fork(1) download
  1. #import resource
  2. import sys
  3. #resource.setrlimit(resource.RLIMIT_STACK, [0x100000000, resource.RLIM_INFINITY])
  4. sys.setrecursionlimit(10**6)
  5. mod=(10**9)+7
  6. #fact=[1]
  7. #import collections
  8. #for i in range(1,1000001):
  9. # fact.append((fact[-1]*i)%mod)
  10. #ifact=[0]*1000001
  11. #ifact[1000000]=pow(fact[1000000],mod-2,mod)
  12. #for i in range(1000000,0,-1):
  13. # ifact[i-1]=(i*ifact[i])%mod
  14. from sys import stdin, stdout
  15. from bisect import bisect_left as bl
  16. from bisect import bisect_right as br
  17. import itertools
  18. import math
  19. import heapq
  20. from random import randint as rn
  21. from Queue import Queue as Q
  22. def modinv(n,p):
  23. return pow(n,p-2,p)
  24. def ncr(n,r,p):
  25. t=((fact[n])*((ifact[r]*ifact[n-r])%p))%p
  26. return t
  27. def ain():
  28. return map(int,sin().split())
  29. def sin():
  30. return stdin.readline().strip()
  31. def GCD(x, y):
  32. while(y):
  33. x, y = y, x % y
  34. return x
  35. def isprime(x):
  36. p=int(math.sqrt(x))+1
  37. if(x==1):
  38. return 0
  39. for i in range(2,p):
  40. if(x%p==0):
  41. return 0
  42. return 1
  43. """**************************************************************************"""
# pv is the parent of v
# xo is the XOR of the values of all selected ancestors
  46. def dfs(v,pv,xo):
  47. if(dp[v][xo]!=-1):
  48. return dp[v][xo]
  49. s1=0
  50. s2=0
  51. for i in ad[v]:
  52. if(i!=pv):
  53. s1+=dfs(i,v,xo)
  54. if(xo^value[v-1]==0):
  55. s2+=dfs(i,v,0)
  56. else:
  57. s2+=dfs_skip(i,v,xo^value[v-1],xo^value[v-1])
  58. s2+=value[v-1]
  59. dp[v][xo]=max(s1,s2)
  60. return dp[v][xo]
  61. #for skipping J levels
  62. def dfs_skip(v,pv,xo,J):
  63. s=0
  64. if(J==1):
  65. for i in ad[v]:
  66. if(i!=pv):
  67. s+=dfs(i,v,xo)
  68. else:
  69. for i in ad[v]:
  70. if(i!=pv):
  71. s+=dfs_skip(i,v,xo,J-1)
  72. return s
  73. n=input()
  74. value=ain()
  75. ad=[[] for i in range(n+1)]
  76. for i in range(n-1):
  77. x,y=ain()
  78. ad[x].append(y)
  79. ad[y].append(x)
  80. dp=[[-1 for i in range(1024)] for i in range(1024)]
  81. q=dfs(1,0,0)
  82. print q
Success #stdin #stdout 0.05s 33376KB
stdin
4
1 1 1 1
1 2
1 3
1 4
stdout
3