fork download
  1. #pragma GCC optimize("O3,unroll-loops")
  2. #include <bits/stdc++.h>
  3. #define ll long long
  4. #define sti string
  5. #define itachi ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);
  6. #define maxn 500005
  7. #define fi first
  8. #define se second
  9. #define ldb long double
  10. using namespace std;
  11.  
  12. int n;
  13.  
  14. int32_t color[maxn], sz[maxn];
  15. bool banned[maxn];
  16. int32_t cnt[maxn];
  17. int32_t delta_only[maxn], delta_all[maxn];
  18. int32_t delta_only_inside[maxn], delta_all_inside[maxn];
  19. vector<pair<int32_t, int32_t>> adj[maxn];
  20.  
  21. long long ans=0;
  22. int sum_only = 0;
  23. int sum_only_inside = 0;
  24. vector<int32_t> usedDelta, usedInside;
  25. int current_root = 0;
  26.  
  27. void orient_color(int u, int p) {
  28. for (auto [v, w] : adj[u]) {
  29. if (v == p || banned[v]) continue;
  30. color[v] = w;
  31. orient_color(v, u);
  32. }
  33. }
  34.  
  35. void get_sz(int u, int p) {
  36. sz[u] = 1;
  37. for (auto [v, w] : adj[u]) {
  38. if (v != p && !banned[v]) {
  39. get_sz(v, u);
  40. sz[u] += sz[v];
  41. }
  42. }
  43. }
  44.  
  45. void countChild(int u, int p) {
  46. sz[u] = 1;
  47. bool first = false;
  48. if (u != current_root) first = (++cnt[color[u]] == 1);
  49. for (auto [v,w] : adj[u]) {
  50. if (v == p || banned[v]) continue;
  51. countChild(v, u);
  52. sz[u] += sz[v];
  53. }
  54. if (u != current_root) {
  55. if (first) {
  56. sum_only += sz[u];
  57. if (delta_all[color[u]] == 0) usedDelta.push_back(color[u]);
  58. delta_all[color[u]] += sz[u];
  59. delta_only[color[u]] += sz[u];
  60. }
  61. else if (cnt[color[u]] == 2) {
  62. sum_only -= sz[u];
  63. delta_only[color[u]] -= sz[u];
  64. }
  65. --cnt[color[u]];
  66. }
  67. }
  68.  
  69. int find_cen(int u, int p, int S) {
  70. for (auto [v,w] : adj[u]) {
  71. if (v != p && !banned[v] && sz[v] > S / 2) {
  72. return find_cen(v, u, S);
  73. }
  74. }
  75. return u;
  76. }
  77.  
  78. void cntInside(int u, int p) {
  79. bool first = (++cnt[color[u]] == 1);
  80.  
  81. if (first) {
  82. sum_only_inside += sz[u];
  83. if (delta_all_inside[color[u]] == 0) usedInside.push_back(color[u]);
  84. delta_only_inside[color[u]] += sz[u];
  85. delta_all_inside[color[u]] += sz[u];
  86. }
  87. else if (cnt[color[u]] == 2) {
  88. sum_only_inside -= sz[u];
  89. delta_only_inside[color[u]] -= sz[u];
  90. }
  91.  
  92. for (auto [v,w] : adj[u]) {
  93. if (v == p || banned[v]) continue;
  94. cntInside(v, u);
  95. }
  96.  
  97. --cnt[color[u]];
  98. }
  99.  
  100. void addRootPaths(int u, int p, int distinct) {
  101. bool first = (++cnt[color[u]] == 1);
  102. if (first) distinct++;
  103. else if (cnt[color[u]] == 2) distinct--;
  104.  
  105. ans += 2 * distinct;
  106.  
  107. for (auto [v, w] : adj[u]) {
  108. if (v == p || banned[v]) continue;
  109. addRootPaths(v, u, distinct);
  110. }
  111. --cnt[color[u]];
  112. }
  113.  
  114. void dfs(int u, int p, int edge_color, int only_cur, int all_cur, int only_out, int distinctlen, int ver_out) {
  115. int C = edge_color;
  116. bool first = (++cnt[C] == 1);
  117.  
  118. if (first) {
  119. distinctlen++;
  120. all_cur += delta_all[C] - delta_all_inside[C];
  121. only_cur += delta_only[C] - delta_only_inside[C];
  122. }
  123. else if (cnt[C] == 2) {
  124. distinctlen--;
  125. all_cur -= delta_all[C] - delta_all_inside[C];
  126. }
  127.  
  128. ans += distinctlen * ver_out - all_cur - only_cur + only_out;
  129.  
  130. for (auto [v,w] : adj[u]) {
  131. if (v == p || banned[v]) continue;
  132. dfs(v, u, w, only_cur, all_cur, only_out, distinctlen, ver_out);
  133. }
  134.  
  135. --cnt[C];
  136. }
  137.  
  138. void solve(int u) {
  139. get_sz(u, 0);
  140.  
  141. int root = find_cen(u, 0, sz[u]);
  142. orient_color(root, 0);
  143.  
  144. current_root = root;
  145. sum_only = 0;
  146. countChild(root, 0);
  147. int all_node = sz[root];
  148.  
  149. for (auto [v,w] : adj[root]) {
  150. if (banned[v]) continue;
  151. addRootPaths(v, root, 0);
  152. }
  153.  
  154. for (auto [v,w] : adj[root]) {
  155. if (banned[v]) continue;
  156.  
  157. sum_only_inside = 0;
  158. cntInside(v, root);
  159. int ver_out = all_node - sz[v] - 1;
  160. int only_out = sum_only - sum_only_inside;
  161.  
  162. dfs(v, root, w, 0, 0, only_out, 0, ver_out);
  163.  
  164. for (int c : usedInside) {
  165. delta_all_inside[c] = 0;
  166. delta_only_inside[c]=0;
  167. }
  168. usedInside.clear();
  169. }
  170.  
  171. banned[root] = 1;
  172. for (int c : usedDelta) {
  173. delta_all[c] = 0;
  174. delta_only[c]=0;
  175. }
  176. usedDelta.clear();
  177. for (auto [v,w] : adj[root]) {
  178. if (!banned[v]) solve(v);
  179. }
  180. }
  181.  
  182. signed main()
  183. {
  184. itachi
  185. cin>>n;
  186. for (int i = 1; i < n; i++) {
  187. int u, v, w;
  188. cin >> u >> v >> w;
  189. adj[u].push_back({v,w});
  190. adj[v].push_back({u,w});
  191. }
  192.  
  193. solve(1);
  194. cout << ans / 2;
  195. return 0;
  196. }
Success #stdin #stdout 0.01s 18624KB
stdin
Standard input is empty
stdout
Standard output is empty