fork download
  1. #include<bits/stdc++.h>
  2. #define ll long long
  3. #define ldb long double
  4. #define fi first
  5. #define se second
  6. #define sza(a) (int)a.size()
  7. #define pir pair<int,int>
  8. #define pirll pair<ll,ll>
  9. using namespace std;
  10. const int maxn = 1e5 + 5;
  11. const int modu = 1e9 + 7;
  12.  
  13. inline void add(ll &x,ll y){
  14. x = (x + y) % modu;
  15. }
  16.  
  17. int a[maxn];
  18. ll dp[maxn][2],XOR[maxn];
  19.  
  20. vector <int> vec[maxn];
  21.  
  22. void input(int n){
  23. for (int i = 1 ; i <= n ; i++) cin >> a[i];
  24. for (int i = 1 ; i < n ; i++){
  25. int u,v;
  26. cin >> u >> v;
  27. vec[u].push_back(v);
  28. vec[v].push_back(u);
  29. }
  30. }
  31.  
  32. void prepare_dfs(int u,int par){
  33. XOR[u] = a[u];
  34.  
  35. for (int v : vec[u])
  36. if (v != par){
  37. prepare_dfs(v,u);
  38. XOR[u] ^= XOR[v];
  39. }
  40. }
  41.  
  42. void solve(int u,int par,int x){
  43. for (int v : vec[u])
  44. if (v != par)
  45. solve(v,u,x);
  46.  
  47. ll f[2];
  48. f[0] = f[1] = 0;
  49.  
  50. for (int v : vec[u])
  51. if (v != par){
  52. f[0] = (dp[u][0] * dp[v][0] + dp[u][1] * dp[v][1]) % modu;
  53. f[1] = (dp[u][0] * dp[v][1] + dp[u][1] * dp[v][0]) % modu;
  54.  
  55. add(dp[u][0],f[0]);
  56. add(dp[u][1],f[1]);
  57.  
  58. add(dp[u][0],dp[v][0]);
  59. add(dp[u][1],dp[v][1]);
  60. }
  61.  
  62. f[0] = dp[u][0];
  63. f[1] = dp[u][1];
  64.  
  65. if (u != 1 && XOR[u] == x) add(dp[u][1],f[0]);
  66. if (u != 1 && XOR[u] == 0) add(dp[u][0],f[1]);
  67.  
  68. if (u != 1 && XOR[u] == x) add(dp[u][1],1);
  69. }
  70.  
  71. void solve_0(int u,int par,ll &res){
  72. if (!XOR[u] && u != 1) res = res * 2 % modu;
  73.  
  74. for (int v : vec[u])
  75. if (v != par)
  76. solve_0(v,u,res);
  77. }
  78. int main(){
  79. ios_base::sync_with_stdio(false);
  80. cin.tie(0);cout.tie(0);
  81.  
  82. int n,x;
  83. cin >> n >> x;
  84. input(n);
  85.  
  86. prepare_dfs(1,0);
  87.  
  88. if (x == 0){
  89. ll res1 = 1;
  90. solve_0(1,0,res1);
  91.  
  92. cout << res1;
  93. return 0;
  94. }
  95.  
  96. solve(1,0,x);
  97.  
  98. if (XOR[1] != 0 && XOR[1] != x){
  99. cout << 0;
  100. return 0;
  101. }
  102.  
  103. int st = (XOR[1] == x);
  104.  
  105. cout << (dp[1][1 - st] + st) % modu;
  106.  
  107. return 0;
  108. }
  109.  
Success #stdin #stdout 0.01s 6332KB
stdin
Standard input is empty
stdout
Standard output is empty