fork download
  1. #pragma GCC optimize("O3,unroll-loops")
  2. #include <bits/stdc++.h>
  using namespace std;

  // Type shorthands (not referenced anywhere else in this file).
  using ll = long long;
  using sti = string;
  using ldb = long double;

  // Pair-member shorthands (not referenced anywhere else in this file).
  #define fi first
  #define se second

  // Fast iostream setup. Expands to three complete statements, so the use site
  // in main writes it without a trailing semicolon.
  #define itachi ios_base::sync_with_stdio(0);cin.tie(0);cout.tie(0);

  // Capacity of every per-vertex / per-colour global array.
  #define maxn 500005

  // From here on every `int` is a 64-bit `long long`; this is why the entry
  // point is declared `signed main()`.
  #define int long long
  12.  
  13. int n;
  14. int color[maxn], sz[maxn];
  15. bool banned[maxn];
  16. int par[maxn];
  17. vector<pair<int,int>> adj[maxn];
  18.  
  19. int ans=0;
  20. int cnt[maxn];
  21. int delta_only[maxn],delta_all[maxn];
  22. int delta_only_inside[maxn],delta_all_inside[maxn];
  23. int sum_only = 0;
  24. int sum_only_inside = 0;
  25. vector<int> usedDelta, usedInside;
  26. int current_root = 0;
  27.  
  28. void orient_color(int u, int p) {
  29. for (auto [v, w] : adj[u]) {
  30. if (v == p || banned[v]) continue;
  31. color[v] = w;
  32. orient_color(v, u);
  33. }
  34. }
  35.  
  36. void countChild(int u, int p) {
  37. sz[u] = 1;
  38. bool first = false;
  39. if (u != current_root) first = (++cnt[color[u]] == 1);
  40. for (auto [v,w] : adj[u]) {
  41. if (v == p || banned[v]) continue;
  42. countChild(v, u);
  43. sz[u] += sz[v];
  44. }
  45. if (u != current_root) {
  46. if (first) {
  47. sum_only += sz[u];
  48. if (delta_all[color[u]] == 0) usedDelta.push_back(color[u]);
  49. delta_all[color[u]] += sz[u];
  50. delta_only[color[u]] += sz[u];
  51. }
  52. else{
  53. sum_only -= sz[u];
  54. delta_only[color[u]] -= sz[u];
  55. }
  56. --cnt[color[u]];
  57. }
  58. }
  59.  
  60. int find_cen(int u, int p, int S) {
  61. for (auto [v,w] : adj[u]) {
  62. if (v != p && !banned[v] && sz[v] > S / 2) {
  63. return find_cen(v, u, S);
  64. }
  65. }
  66. return u;
  67. }
  68.  
  69. void cntInside(int u, int p) {
  70. bool first = (++cnt[color[u]] == 1);
  71.  
  72. if (first) {
  73. sum_only_inside += sz[u];
  74. if (delta_all_inside[color[u]] == 0) usedInside.push_back(color[u]);
  75. delta_only_inside[color[u]] += sz[u];
  76. delta_all_inside[color[u]] += sz[u];
  77. }
  78. else{
  79. sum_only_inside -= sz[u];
  80. delta_only_inside[color[u]] -= sz[u];
  81. }
  82.  
  83. for (auto [v,w] : adj[u]) {
  84. if (v == p || banned[v]) continue;
  85. cntInside(v, u);
  86. }
  87.  
  88. --cnt[color[u]];
  89. }
  90.  
  91. void addRootPaths(int u, int p, int distinct) {
  92. bool first = (++cnt[color[u]] == 1);
  93. if (first) distinct++;
  94. else if (cnt[color[u]] == 2) distinct--;
  95.  
  96. ans += 2 * distinct;
  97.  
  98. for (auto [v, w] : adj[u]) {
  99. if (v == p || banned[v]) continue;
  100. addRootPaths(v, u, distinct);
  101. }
  102. --cnt[color[u]];
  103. }
  104.  
  105. void dfs(int u, int p,int edge_color, int only_cur, int all_cur,int only_out,
  106. int distinctlen,int ver_out) {
  107. int C=edge_color;
  108. bool first = (++cnt[C] == 1);
  109.  
  110. if (first) {
  111. distinctlen++;
  112. all_cur += delta_all[C] - delta_all_inside[C];
  113. only_cur += delta_only[C] - delta_only_inside[C];
  114. }
  115. else if (cnt[C] == 2) {
  116. distinctlen--;
  117. all_cur -= delta_all[C] - delta_all_inside[C];
  118. }
  119.  
  120. ans += distinctlen * ver_out - all_cur - only_cur + only_out;
  121.  
  122. for (auto [v,w] : adj[u]) {
  123. if (v == p || banned[v]) continue;
  124. dfs(v, u, w,only_cur,all_cur,only_out,distinctlen,ver_out);
  125. }
  126.  
  127. --cnt[C];
  128. }
  129.  
  130. void solve(int u) {
  131. current_root = 0;
  132. sum_only = 0;
  133. countChild(u, 0);
  134.  
  135. int root = find_cen(u, 0, sz[u]);
  136.  
  137. for (int c : usedDelta) {
  138. delta_all[c]=0;
  139. delta_only[c]=0;
  140. }
  141. usedDelta.clear();
  142. orient_color(root, 0);
  143. current_root = root;
  144. sum_only = 0;
  145. countChild(root, 0);
  146. int all_node = sz[root];
  147. for (auto [v,w] : adj[root]) {
  148. if (banned[v]) continue;
  149. addRootPaths(v, root, 0);
  150. }
  151.  
  152. for (auto [v,w] : adj[root]) {
  153. if (banned[v]) continue;
  154.  
  155. sum_only_inside = 0;
  156. cntInside(v, root);
  157. int ver_out = all_node - sz[v] - 1;
  158. int only_out = sum_only - sum_only_inside;
  159.  
  160. dfs(v, root, w, 0, 0, only_out, 0, ver_out);
  161.  
  162. for (int c : usedInside) {
  163. delta_all_inside[c] = 0;
  164. delta_only_inside[c]=0;
  165. }
  166. usedInside.clear();
  167. }
  168.  
  169. banned[root] = 1;
  170. for (int c : usedDelta) {
  171. delta_all[c] = 0;
  172. delta_only[c]=0;
  173. }
  174. usedDelta.clear();
  175. for (auto [v,w] : adj[root]) {
  176. if (!banned[v]) solve(v);
  177. }
  178. }
  179.  
  180. void pre_dfs(int u,int p){
  181. for(auto [v,w] : adj[u]){
  182. if(v==p) continue;
  183. par[v]=u;
  184. pre_dfs(v,u);
  185. }
  186. }
  187.  
  188. signed main()
  189. {
  190. itachi
  191. cin >> n;
  192. for (int i = 1; i < n; i++) {
  193. int u, v,w;
  194. cin >> u >> v >> w;
  195. adj[u].push_back({v,w});
  196. adj[v].push_back({u,w});
  197. }
  198.  
  199. pre_dfs(1,0);
  200.  
  201. for(int i=1;i<=n;i++){
  202. for(auto [v,w] : adj[i]){
  203. if(par[v]==i) color[v]=w;
  204. }
  205. }
  206. solve(1);
  207. cout << ans / 2;
  208. return 0;
  209. }
  210.  
Success #stdin #stdout 0.01s 22088KB
stdin
Standard input is empty
stdout
Standard output is empty