1 条题解
-
1
这题作为 MX-X 最后一题算是很简单了,但是我怎么做了这么久?
很难不发现 (x,y) 和 (a,b) 做乘法得到的二元组 (ay+bx,ax+by),它两项的和等于 (x+y)×(a+b),后一项减前一项的差等于 (y−x)×(b−a)。这启发我们用 (x+y,y−x) 来表示二元组,这样二元组的乘法就是对应的项相乘,同时两个二元组 (x,y),(a,b) 相同的条件仍然是 ay≡bx(modp)。并且 (x+y,y−x) 在 p 为奇素数时能和 (x,y) 构成双射。当 p=2 时需要特判,若区间出现 (1,0) 且不出现 (0,0),(1,1) 时答案为 2,否则为 1。
因为 b i 需要是正整数,所以 ∏a i
如果有至少一项是 0,那么答案为 1。否则 ∏a i 两项都 ≥1。现在我们可以直接用
y x 来替换二元组 (x,y),这样序列变成了一个正整数序列。看到乘法考虑取离散对数,假设找到了 p 的一个原根 g,设 b i 为最小的正整数使得 g b i
≡a i (modp)。那么乘积变成了和,一个序列的答案即为
gcdb ip-1 。
但是对每个 a i
求离散对数时间复杂度太大了。考虑求阶,设 c i
为最小的正整数使得 aici≡1(modp),那么 bi=ci
p-1
。一个序列的答案变为 lcmc i 。
现在问题变成求一个序列所有子区间的 lcm 的和。考虑枚举子区间右端点,因为序列中所有数都是 p−1 的因数,所以左端点变化时 lcm 只会变化 O(logV) 次。维护每种 lcm 对应的左端点范围即可。
时间复杂度 O( p +nlog 2 p)。
#include <bits/stdc++.h> #define pb emplace_back #define fst first #define scd second #define mkp make_pair #define mems(a, x) memset((a), (x), sizeof(a)) using namespace std; typedef long long ll; typedef __int128 lll; typedef double db; typedef unsigned long long ull; typedef long double ldb; typedef pair<ll, ll> pii; const int maxn = 100100; const ll mod = 1000000007; ll n, P, a[maxn], b[maxn], c[maxn], f[maxn], m; inline ll qpow(ll b, ll p, ll mod) { ll res = 1; while (p) { if (p & 1) { res = (lll)res * b % mod; } b = (lll)b * b % mod; p >>= 1; } return res; } int tot; pii p[99], q[99]; inline ll calc(ll x) { if (x == 1) { return 1; } ll y = P - 1; for (int i = 1; i <= tot; ++i) { for (int j = 1; j <= p[i].scd; ++j) { if (qpow(x, y / p[i].fst, P) == 1) { y /= p[i].fst; } else { break; } } } return y; } inline ll work() { int tot = 1; ll ans = f[1] % mod; p[1] = mkp(1, f[1]); for (int i = 2; i <= m; ++i) { p[++tot] = mkp(i, f[i]); for (int j = 1; j <= tot; ++j) { p[j].scd = f[i] / __gcd(f[i], p[j].scd) * p[j].scd; } int nt = 1; q[1] = p[1]; for (int j = 2; j <= tot; ++j) { if (p[j].scd != p[j - 1].scd) { q[++nt] = p[j]; } } tot = nt; q[tot + 1].fst = i + 1; for (int j = 1; j <= tot; ++j) { p[j] = q[j]; ans = (ans + q[j].scd % mod * (q[j + 1].fst - q[j].fst)) % mod; } } return ans; } void solve() { scanf("%lld%lld", &n, &P); if (P == 2) { int p = 0, q = 0; ll ans = n * (n + 1) / 2; for (int i = 1, x, y; i <= n; ++i) { scanf("%d%d", &x, &y); if (x == 1 && y == 0) { q = i; } else if (x == y) { p = i; } ans = (ans + max(q - p, 0)) % mod; } printf("%lld\n", ans % mod); return; } ll x = P - 1; for (ll i = 2; i * i <= x; ++i) { if (x % i == 0) { ll cnt = 0; while (x % i == 0) { x /= i; ++cnt; } p[++tot] = mkp(i, cnt); } } if (x > 1) { p[++tot] = mkp(x, 1); } for (int i = 1; i <= n; ++i) { ll x, y; scanf("%lld%lld", &x, &y); a[i] = (x + y) % P; b[i] = (y - x + P) % P; if (a[i] || b[i]) { c[i] = calc((lll)a[i] * qpow(b[i], P - 2, P) % P); } } ll ans = 0, sl = n * (n + 1) / 2; for (int i = 1, j = 1; i <= n; i = (++j)) { if (!a[i] || !b[i]) { continue; } while (j < n && a[j + 1] && b[j + 1]) { ++j; } m = 0; for (int k = i; k <= j; ++k) { f[++m] = c[k]; } ll len = j - i + 1; sl -= len * (len + 1) / 2; ans = (ans + work()) % mod; } ans = (ans + sl) % mod; printf("%lld\n", ans); } int main() { int T = 1; // scanf("%d", &T); while (T--) { solve(); } return 0; }
记得点赞
- 1
信息
- ID
- 140
- 时间
- 6000ms
- 内存
- 1024MiB
- 难度
- 10
- 标签
- 递交数
- 10
- 已通过
- 1
- 上传者