【题目描述】
有一个$n$行$m$列的表格,行从$0$到$n-1$编号,列从$0$到$m-1$编号。
每个格子都储存着能量。最初,第$i$行第$j$列的格子储存着$(i\ \text{xor}\ j)$点能量。所以,整个表格储存的总能量是:
$\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{m-1}(i\ \text{xor}\ j)$
随着时间的推移,格子中的能量会渐渐减少。一个时间单位后,每个格子中的能量都会减少$1$。显然,一个格子的能量减少到$0$之后就不会再减少了。
也就是说,$k$个时间单位后,整个表格储存的总能量是:
$\sum\limits_{i=0}^{n-1}\sum\limits_{j=0}^{m-1}max((i\ xor\ j)-k,0)$
给出一个表格,求$k$个时间单位后它储存的总能量。
由于总能量可能较大,输出时对$p$取模。
题解
先把原式拆开 原式=所有满足(i^j)>k 的(i^j)之和 - 所有满足(i^j)>k 的数对(i,j)的数量 * k
可以使用数位DP来求解有多少对(i,j)满足(i^j)>k 以及 它们的(i^j)之和是多少
具体做法就是记忆化搜索 记录$limn, limm, limk$分别表示当前枚举二进制前$i$位,是否卡满$n,m$上界以及$k$下界
这个东西我也不知道怎么讲清楚 不是数位DP的基本套路吗
我用了个pair来存DP答案 first存的是有多少对满足条件 second存的是满足条件的(i^j)之和
注意取模
【代码】
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
| #include <bits/stdc++.h> #define mp make_pair using namespace std; typedef long long ll; typedef pair<ll, ll> pii;
inline ll read() { ll x = 0, f = 1; char ch = getchar(); for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1; for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ '0'); return x * f; }
ll t, n, m, k, mod, mx; pii dp[105][2][2][2], ans; bool vis[105][2][2][2];
inline ll getl(ll x) { ll ret = 0; while (x) ret++, x >>= 1; return ret; }
pii dfs(ll d, bool limn, bool limm, bool limk) { if (d > mx) return mp(1, 0); if (vis[d][limn][limm][limk]) return dp[d][limn][limm][limk]; ll N = (n>>(mx-d)) & 1, M = (m>>(mx-d)) & 1, K = (k>>(mx-d)) & 1; for (ll i = 0; i <= (limn ? N : 1); i++) { for (ll j = 0; j <= (limm ? M : 1); j++) { if (limk && (i^j) < K) continue; pii res = dfs(d+1, limn&&(i==N), limm&&(j==M), limk&&((i^j)==K)); dp[d][limn][limm][limk].first = (dp[d][limn][limm][limk].first + res.first) % mod; dp[d][limn][limm][limk].second = (((dp[d][limn][limm][limk].second + res.second) % mod) + (1ll << (mx - d)) * (i^j) % mod * res.first % mod) % mod; } } vis[d][limn][limm][limk] = true; return dp[d][limn][limm][limk]; }
int main() { t = read(); while (t--) { n = read(); m = read(); k = read(); mod = read(); n--, m--; memset(dp, 0, sizeof(dp)); memset(vis, 0, sizeof(vis)); mx = max(max(getl(n), getl(m)), getl(k)); ans = dfs(1, 1, 1, 1); printf("%lld\n", (ans.second - k % mod * ans.first % mod + mod) % mod); } return 0; }
|