计数DP
https://vjudge.net/contest/237357#problem/H
题意:给出n个数,从这些数中选出两组数S,T,使得S中的所有数的下标都比T中的数的下标小,且S集合的异或和等于T集合的与结果,问有多少种方案?
解法:DP
dp1[i][j]:由0~i的元素异或得到j的种类数。
dp2[i][j]:由i~n-1的元素AND得到j的种类数。
dp2[i][j]:由i~n-1的元素AND得到j的种类数。
dp3[i][j]:由i~n-1的元素,且一定包含a[i],AND得到j的种类数。
求出这些,最后把dp1[i][j]*dp3[i+1][j]求和就能得到答案了。
这里多用了一个数组dp3,而不是直接用dp2,是为了防止重复计数。(注意该细节!)
这里多用了一个数组dp3,而不是直接用dp2,是为了防止重复计数。(注意该细节!)
注意:dp数组记录的是对应区间至少取了一个数的方案数,如果一个数都没有取,则认为该状态不可达(因为S,T中一定要有数),所以计算dp[i][j]时首先要判断上一个状态是否可达。
1 #include <bits/stdc++.h> 2 #include <iostream> 3 #include <algorithm> 4 #include <cstdio> 5 #include <cstring> 6 #include <string> 7 #include <cmath> 8 #include <cstdlib> 9 #include <queue> 10 #include <stack> 11 #include <map> 12 #include <vector> 13 #include <set> 14 #include <bitset> 15 #include <iomanip> 16 #define ms(a, b) memset(a, b, sizeof(a)); 17 using namespace std; 18 typedef long long LL; 19 typedef pair<int, int> pii; 20 const int INF = 0x3f3f3f3f; 21 const int maxn = 1030; 22 const int MAXN = 2e4 + 10; 23 const double eps = 1e-8; 24 const LL mod = 1e9 + 7; 25 int a[maxn]; 26 LL dp1[maxn][maxn], dp2[maxn][maxn], dp3[maxn][maxn]; 27 LL ans; 28 int n; 29 30 void solve() { 31 dp1[0][a[0]] = 1; 32 for(int i = 1; i < n - 1; i++) { 33 dp1[i][a[i]]++; 34 for(int j = 0; j < maxn; j++) { 35 if(dp1[i-1][j]) { 36 dp1[i][j] += dp1[i-1][j]; 37 dp1[i][j] %= mod; 38 dp1[i][j^a[i]] += dp1[i-1][j]; 39 dp1[i][j^a[i]] %= mod; 40 } 41 } 42 } 43 dp2[n-1][a[n-1]] = 1; 44 dp3[n-1][a[n-1]] = 1; 45 for(int i = n - 2; i > 0; i--) { 46 dp2[i][a[i]]++; 47 dp3[i][a[i]]++; 48 for(int j = 0; j < maxn; j++) { 49 if(dp2[i+1][j]) { 50 dp2[i][j] += dp2[i+1][j]; 51 dp2[i][j] %= mod; 52 dp2[i][j&a[i]] += dp2[i+1][j]; 53 dp2[i][j&a[i]] %= mod; 54 dp3[i][j&a[i]] += dp2[i+1][j]; 55 dp3[i][j&a[i]] %= mod; 56 } 57 } 58 } 59 for(int i = 0; i < n - 1; i++) { 60 for(int j = 0; j < maxn; j++) { 61 if(dp1[i][j] && dp3[i+1][j]) { 62 // cout << i << " " << j << " " << dp1[i][j] << " " << dp3[i+1][j] << endl; 63 ans = (ans + dp1[i][j] * dp3[i+1][j] % mod) % mod; 64 } 65 } 66 } 67 printf("%lld\n", ans); 68 } 69 70 int main() 71 { 72 #ifdef local 73 freopen("case.in","r",stdin); 74 // freopen("out.in","w",stdout); 75 #endif 76 int T; 77 scanf("%d", &T); 78 while(T--) { 79 scanf("%d", &n); 80 ms(dp1, 0); 81 ms(dp2, 0); 82 ms(dp3, 0); 83 for(int i = 0; i < n; i++) { 84 scanf("%d", &a[i]); 85 } 86 ans = 0; 87 solve(); 88 } 89 return 0; 90 }
版权声明:本文为Sissi-hss原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。