fork download
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3. using ll = long long;
  4.  
  5. int main() {
  6. ios::sync_with_stdio(false);
  7. cin.tie(nullptr);
  8.  
  9. int n;
  10. cin >> n;
  11. vector<ll> a(n);
  12. vector<int> b(n);
  13. for (int i = 0; i < n; i++) cin >> a[i];
  14. for (int i = 0; i < n; i++) cin >> b[i];
  15.  
  16. // 1) Prefix sums of squares
  17. vector<ll> pre(n + 1);
  18. pre[0] = 0;
  19. for (int i = 0; i < n; i++) pre[i + 1] = pre[i] + a[i] * a[i];
  20.  
  21. // 2) Build sparse table for OR
  22. int LOG = __lg(n) + 1;
  23. vector<vector<int>> sp(n, vector<int>(LOG));
  24. for (int i = 0; i < n; i++) sp[i][0] = b[i];
  25. for (int j = 1; j < LOG; j++) {
  26. for (int i = 0; i + (1 << j) <= n; i++) {
  27. sp[i][j] = sp[i][j - 1] | sp[i + (1 << (j - 1))][j - 1];
  28. }
  29. }
  30. auto getOR = [&](int l, int r) {
  31. int j = __lg(r - l + 1);
  32. return sp[l][j] | sp[r - (1 << j) + 1][j];
  33. };
  34.  
  35. // 3) Precompute next positions of each bit
  36. const int B = 21;
  37. vector<array<int, B>> nxt(n);
  38. array<int, B> last;
  39. last.fill(n);
  40. for (int i = n - 1; i >= 0; --i) {
  41. for (int k = 0; k < B; ++k) nxt[i][k] = last[k];
  42. for (int k = 0; k < B; ++k) if ((b[i] >> k) & 1) last[k] = i;
  43. }
  44.  
  45. // 4) Precompute, for each i, the sorted breakpoints where OR changes
  46. vector<vector<int>> changes(n);
  47. for (int i = 0; i < n; ++i) {
  48. int baseOR = b[i];
  49. auto &pts = changes[i];
  50. pts.reserve(B);
  51. for (int k = 0; k < B; ++k) {
  52. if (!((baseOR >> k) & 1) && nxt[i][k] < n) {
  53. pts.push_back(nxt[i][k]);
  54. }
  55. }
  56. sort(pts.begin(), pts.end());
  57. pts.erase(unique(pts.begin(), pts.end()), pts.end());
  58. }
  59.  
  60. // 5) Count good pairs using precomputed breakpoints
  61. ll answer = 0;
  62. for (int i = 0; i < n; ++i) {
  63. int prev = i;
  64. auto &pts = changes[i];
  65. int m = pts.size();
  66. for (int idx = 0; idx <= m; ++idx) {
  67. int end = (idx < m ? pts[idx] - 1 : n - 1);
  68. ll ORval = getOR(i, prev);
  69. ll G = ORval * ORval;
  70.  
  71. int lo = prev, hi = end, best = prev - 1;
  72. while (lo <= hi) {
  73. int mid = (lo + hi) >> 1;
  74. ll F = pre[mid + 1] - pre[i];
  75. if (F < G) {
  76. best = mid;
  77. lo = mid + 1;
  78. } else {
  79. hi = mid - 1;
  80. }
  81. }
  82. if (best >= prev) answer += best - prev + 1;
  83.  
  84. if (idx < m) prev = pts[idx];
  85. }
  86. }
  87.  
  88. cout << answer << '\n';
  89. return 0;
  90. }
  91.  
Success #stdin #stdout 0.01s 5288KB
stdin
5
1 2 4 3 1
1 4 8 7 1
stdout
13