2 条题解
- 
  0
#include <bits/stdc++.h>
using namespace std;

// Solution 1: prefix/suffix statistics.
//
// For every split i in [0, n/2] the program keeps a[1..i] and a[n/2+i+1..n]
// (i.e. it drops the contiguous block a[i+1..i+n/2]). It then inspects the
// kept elements through their sum, min, max, second max and second min, and
// accumulates a count in `ans`.
//
// NOTE(review): the problem statement is not part of this page, so the
// meaning of each increment below cannot be confirmed from here. The code
// appears to assume n is even (it uses n / 4 and n / 2 interchangeably with
// "half of n") and that the values are pairwise distinct, presumably a
// permutation of 1..n. TODO confirm against the statement.
namespace z {
#define int long long  // every `int` inside this namespace is 64-bit
const int N = 5e5 + 5;
// pmx/pmn: prefix max/min of a[1..i]; smx/smn: suffix max/min of a[i..n];
// pre/suf: prefix/suffix sums; ans: the accumulated answer.
int n, a[N], pmx[N], pmn[N], smx[N], smn[N], ans, pre[N], suf[N];
// pcx/pcn: second-largest / second-smallest value of a[1..i];
// scx/scn: second-largest / second-smallest value of a[i..n].
int pcx[N], pcn[N], scx[N], scn[N];

void main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin >> n;
    // Min tables start at a huge sentinel so index 0 / n + 1 is neutral for min().
    memset(pmn, 0x3f, sizeof(pmn));
    memset(smn, 0x3f, sizeof(smn));
    for (int i = 1; i <= n; i++) {
        cin >> a[i];
        pmx[i] = max(pmx[i - 1], a[i]);
        pmn[i] = min(pmn[i - 1], a[i]);
        pre[i] = pre[i - 1] + a[i];
    }

    // Prefix second maximum: when a new maximum appears, the old maximum
    // becomes the runner-up (0 while fewer than two elements were seen).
    int mx = 0;
    for (int i = 1; i <= n; i++) {
        if (a[i] > mx) pcx[i] = mx, mx = a[i];
        else pcx[i] = a[i] > pcx[i - 1] ? a[i] : pcx[i - 1];
    }
    // Prefix second minimum (2e9 while fewer than two elements were seen).
    int mn = 2e9;
    for (int i = 1; i <= n; i++) {
        if (a[i] < mn) pcn[i] = mn, mn = a[i];
        else pcn[i] = a[i] < pcn[i - 1] ? a[i] : pcn[i - 1];
    }
    // Suffix second maximum, scanning right to left.
    mx = 0;
    for (int i = n; i >= 1; i--) {
        if (a[i] > mx) scx[i] = mx, mx = a[i];
        else scx[i] = a[i] > scx[i + 1] ? a[i] : scx[i + 1];
    }
    // Suffix second minimum, scanning right to left.
    mn = 2e9;
    for (int i = n; i >= 1; i--) {
        if (a[i] < mn) scn[i] = mn, mn = a[i];
        else scn[i] = a[i] < scn[i + 1] ? a[i] : scn[i + 1];
    }
    // Suffix max / min / sum.
    for (int i = n; i >= 1; i--) {
        smx[i] = max(smx[i + 1], a[i]);
        smn[i] = min(smn[i + 1], a[i]);
        suf[i] = suf[i + 1] + a[i];
    }

    int maxn, minn, cmax, cmin, sum;
    for (int i = 0; i <= n / 2; i++) {
        // Kept elements: a[1..i] and a[n/2+i+1..n].
        sum = pre[i] + suf[n / 2 + i + 1];
        if (i != 0 && i != n / 2) {
            // Both a prefix and a suffix are kept: merge their statistics.
            // The second max of the union is the larger of (the smaller of the
            // two maxima) and each side's own second max; symmetrically for min.
            minn = min(pmn[i], smn[n / 2 + i + 1]);
            maxn = max(pmx[i], smx[n / 2 + i + 1]);
            cmax = max({min(pmx[i], smx[n / 2 + i + 1]), pcx[i], scx[n / 2 + i + 1]});
            cmin = min({max(pmn[i], smn[n / 2 + i + 1]), pcn[i], scn[n / 2 + i + 1]});
        } else {
            // Only one side is kept: i == 0 keeps just the suffix,
            // i == n / 2 keeps just the prefix.
            minn = i ? pmn[i] : smn[n / 2 + i + 1];
            maxn = i ? pmx[i] : smx[n / 2 + i + 1];
            cmax = i ? pcx[i] : scx[n / 2 + i + 1];
            cmin = i ? pcn[i] : scn[n / 2 + i + 1];
        }
        // (2*minn + n/2 - 1) * n / 4 is the sum of the n/2 consecutive integers
        // minn .. minn + n/2 - 1 (for even n). With distinct values this holds
        // exactly when the kept values form that consecutive run.
        if ((minn + minn + n / 2 - 1) * n / 4 == sum) {
            ans += (n / 2 - 1) * n / 2;
            if (minn - 1) ans++;        // minn != 1
            if (maxn + 1 <= n) ans++;   // maxn < n
            continue;
        }
        // Otherwise the run is "almost" consecutive. These checks test whether
        // the second max / second min lines up with a run anchored at minn /
        // maxn. NOTE(review): exact combinatorial meaning depends on the statement.
        if (cmax == minn + n / 2 - 1) ans++;
        if (cmax == minn + n / 2 - 2) ans++, ans += minn != 1;
        if (cmin == maxn - n / 2 + 1) ans++;
        if (cmin == maxn - n / 2 + 2) ans++, ans += maxn < n;
    }
    cout << ans << '\n';
}
#undef int
}  // namespace z

int main() {
    z::main();
    return 0;
}
  0
Answer is here!
#include <cstdio>
#include <algorithm>
#include <iostream>
#include <cmath>
#include <cstring>
#include <vector>
#include <set>
using namespace std;
using ll = long long;
const int N = 2e5 + 10;
const int MOD = 998244353;

// Solution 2: sliding window over positions, ordered set over values.
//
// `st` holds the values of the current positional window a[i-m+1..i]
// ("key points") plus sentinels 0 and n+1. Over the value line 1..n:
//   x = number of length-m value intervals that contain no key point,
//   y = number of length-m value intervals that contain exactly one key point.
// Both are maintained incrementally as values enter and leave the window.
// NOTE(review): how x and y map onto the task's answer
// (ans += x*m*(m-1) + y per window) depends on the problem statement,
// which is not shown here.
int n, m, a[N], x, y;
ll ans;
set<int> st;

// Adds k times the number of length-m intervals that fit inside a free gap
// of `len` consecutive non-key values (len - m + 1 of them, if any) to x.
void Calc(int len, int k) {
    if (len >= m) x += k * (len - m + 1);
}

// Adds k times the number of length-m intervals containing key point *it and
// no other key point to y. A valid start s must satisfy l < s, s + m - 1 < r
// and s <= p <= s + m - 1, where l and r are the neighbouring key points.
void Get(set<int>::iterator it, int k) {
    if (*it < 1 || *it > n) return;  // sentinels 0 and n+1 are not key points
    int l = *prev(it), p = *it, r = *next(it);
    if (r - l - 1 < m) return;       // the gap around p is too short for any interval
    int lef = max(l + 1, p - m + 1);
    int rig = min(p, r - m);
    y += k * (rig - lef + 1);
}

// Inserts value k as a key point, updating x and y.
void Add(int k) {
    auto it = st.lower_bound(k);
    int l = *prev(it), r = *it;
    // Remove the contributions that change: the old gap (l, r) and the
    // single-key counts of neighbours l and r. Then add the two new gaps.
    Calc(r - l - 1, -1), Get(prev(it), -1);
    Calc(k - l - 1, 1), Get(it, -1);
    Calc(r - k - 1, 1);
    st.insert(k);
    it = st.find(k);
    // Re-add the single-key counts for l, r and the new key point k.
    Get(prev(it), 1);
    Get(next(it), 1);
    Get(it, 1);
}

// Removes value k from the key points, updating x and y.
void Del(int k) {
    auto it = st.find(k);
    int l = *prev(it), r = *next(it);
    // The gaps (l, k) and (k, r) merge into (l, r). Drop the single-key
    // counts of l, r and k before erasing.
    Calc(r - l - 1, 1), Get(prev(it), -1);
    Calc(k - l - 1, -1), Get(next(it), -1);
    Calc(r - k - 1, -1), Get(it, -1);
    // Save the neighbours first; `it` is invalidated by erase.
    auto tl = prev(it), tr = next(it);
    st.erase(it);
    Get(tl, 1), Get(tr, 1);
}

void Solve() {
    cin >> n, m = n / 2;
    for (int i = 1; i <= n; ++i) cin >> a[i];
    // Reset global state so Solve() stays correct if the multi-test read in
    // main() is re-enabled. This has no effect on the single-test run.
    st.clear();
    x = y = 0;
    ans = 0;
    st.insert(0), st.insert(n + 1);  // sentinels bounding the value range 1..n
    // With no key points the whole range 1..n is a single free gap. Seeding
    // x through Calc keeps it consistent with what Add/Del later subtract;
    // for even n this equals m + 1.
    Calc(n, 1);
    for (int i = 1; i <= m; ++i) Add(a[i]);
    // Window a[i-m+1..i] for i = m..n.
    for (int i = m; i <= n; ++i) {
        ans += 1ll * x * m * (m - 1) + y;
        if (i != n) {
            Add(a[i + 1]);
            Del(a[i - m + 1]);
        }
    }
    printf("%lld\n", ans);
}

int main() {
    cin.tie(0)->sync_with_stdio(0);
    int t = 1;
    // cin >> t;
    while (t--) Solve();
    return 0;
}
// Thank you very much!
- 1
 
信息
- ID
 - 77
 - 时间
 - 3000ms
 - 内存
 - 512MiB
 - 难度
 - 6
 - 标签
 - 递交数
 - 281
 - 已通过
 - 79
 - 上传者