fork download
  1. #include<bits/stdc++.h>
  2. #define str string
  3. #define ll long long
  4. #define db double
  5. #define pii pair<int, int>
  6. #define piii pair<int, pii>
  7. #define piiii pair<pii, pii>
  8. #define se second
  9. #define fi first
  10. #define vi vector<int>
  11. #define vii vector<vector<int>>
  12. #define mpii map<int, int>
  13. #define umpii unordered_map<int, int>
  14. #define si set<int>
  15. #define usa unordered_set<int>
  16. #define mulsi multiset<int>
  17. using namespace std;
  18.  
  19. const int mod=1e9+7;
  20. const int maxn=1e5+1;
  21.  
  22. int n, m, cnt;
  23. int c1[maxn], in[maxn], out[maxn], tree[maxn << 3];
  24. int par[20][maxn], h[maxn];
  25. int ans[maxn];
  26. vi c2[maxn], a[maxn];
  27. vector<piii> query[maxn];
  28.  
  29. void dfs1(int u, int p){
  30. in[u] = ++cnt;
  31. for(int v : a[u])
  32. if(v != p)
  33. dfs1(v, u);
  34. out[u] = ++cnt;
  35. c2[c1[u]].push_back(in[u]);
  36. c2[c1[u]].push_back(out[u]);
  37. }
  38.  
  39. void update(int id, int l, int r, int pos, int val){
  40. if(r < pos || pos < l)
  41. return;
  42. if(l == r){
  43. tree[id] = val;
  44. return;
  45. }
  46. int m = (l + r) >> 1;
  47. update(id * 2, l, m, pos, val);
  48. update(id * 2 + 1, m + 1, r, pos, val);
  49. tree[id] = tree[id * 2] + tree[id * 2 + 1];
  50. }
  51.  
  52. int get(int id, int l, int r, int u, int v){
  53. if(r < u || v < l)
  54. return 0;
  55. if(u <= l && r <= v)
  56. return tree[id];
  57. int m = (l + r) >> 1;
  58. return get(id * 2, l, m, u, v) + get(id * 2 + 1, m + 1, r, u, v);
  59. }
  60.  
  61. void dfs2(int u){
  62. for(int v : a[u])
  63. if(par[0][u] != v){
  64. par[0][v] = u; h[v] = h[u] + 1;
  65. for(int i = 1; i < 20; ++i)
  66. par[i][v] = par[i - 1][par[i - 1][v]];
  67. dfs2(v);
  68. }
  69. }
  70.  
  71. void get(int &u, int k){
  72. for(int i = 0; i < 20; ++i)
  73. if((k >> i) & 1)
  74. u = par[i][u];
  75. }
  76.  
  77. int lca(int u, int v){
  78. if(h[u] > h[v])
  79. swap(u, v);
  80. get(v, h[v] - h[u]);
  81. if(u == v) return u;
  82. for(int i = 19; ~i; --i)
  83. if(par[i][u] != par[i][v]){
  84. u = par[i][u];
  85. v = par[i][v];
  86. }
  87. return par[0][v];
  88. }
  89.  
  90. int main(){
  91. ios_base::sync_with_stdio(0);
  92. cin.tie(0); cout.tie(0);
  93. cin >> n >> m;
  94. for(int i = 1; i <= n; ++i)
  95. cin >> c1[i];
  96. for(int i = 1; i < n; ++i){
  97. int u, v; cin >> u >> v;
  98. a[u].push_back(v);
  99. a[v].push_back(u);
  100. }
  101. dfs1(1, -1); dfs2(1);
  102. for(int i = 0; i < m; ++i){
  103. int u, v, w; cin >> u >> v >> w;
  104. query[w].push_back({i, {u, v}});
  105. }
  106. for(int i = 1; i <= n; ++i)
  107. if(query[i].size()){
  108. for(int x: c2[i])
  109. update(1, 1, cnt, x, 1);
  110. for(piii q: query[i]){
  111. int pos = q.fi, u = q.se.fi, v = q.se.se;
  112. if(c1[u] == i || c1[v] == i){
  113. ans[pos] = 1;
  114. continue;
  115. }
  116. int uv = lca(u, v);
  117. int cnt1 = get(1, 1, cnt, in[uv], out[uv]);
  118. cnt1 -= get(1, 1, cnt, in[v], out[v]);
  119. cnt1 -= get(1, 1, cnt, in[u], out[u]);
  120. ans[pos] = (cnt1 > 0);
  121. }
  122. for(int x: c2[i])
  123. update(1, 1, cnt, x, 0);
  124. }
  125. for(int i = 0; i < m; ++i) cout << ans[i];
  126. return 0;
  127. }
  128.  
Success #stdin #stdout 0.01s 17376KB
stdin
5 5
1 1 2 1 2
1 2
2 3
2 4
1 5
1 4 1
1 4 2
1 3 2
1 3 1
5 5 1
stdout
10110