3 条题解

  • 5
    @ 2024-7-8 11:59:33

    Solution

    非常好的题目!和正解差一个神秘的 FWT\rm FWT


    考虑树形 DP。如果你做了上一场梦熊周赛的 T3,你会发现他们惊人的相似:每次闪灯,对每个点记录状态 0/1/20/1/2 分别表示子树内没有闪灯的点、子树内有闪灯的点且当前节点闪灯、子树内有闪灯的点且子树外节点不能有闪灯。

    由于有多次闪灯,可以考虑将不同的闪灯情况状态压缩。

    如何转移?首先注意到,对于一次闪灯,如果 uu 存在两棵子树都是状态 11,那么 uu 节点一定闪灯;如果 uu 只有一棵子树状态为 11,那么 uu 必须以 pup_u 的概率选中才能闪灯。因此你还要新增状态 33(只在转移的过程中使用)表示子树内是否有 22 棵以上是状态 11

    因此我们维护 Dpu,{0,1,2,3}kDp_{u,\{0,1,2,3\}^k}。初始有 Dpu,0=1Dp_{u,0}=1。新加入节点 vv 时和 dpv,{0,1,2}kdp_{v,\{0,1,2\}^k} 作如下卷积:

    对于每一位独立考虑,有(下面的都是“uu 的状态,vv 的状态,结果状态”的顺序):

    • (0,0)0(0,0) \to 0
    • (1,0)1(1,0) \to 1
    • (2,0)2(2,0) \to 2
    • (3,0)3(3,0) \to 3
    • (0,1)1(0,1) \to 1
    • (1,1)3(1,1) \to 3
    • (3,1)3(3,1) \to 3
    • (0,2)2(0,2) \to 2

    处理 Dpu,{0,1,2,3}kDp_{u,\{0,1,2,3\}^k} 后,考虑 uu 节点的闪灯情况。每一维,如果闪灯,030 \to 3131 \to 3333 \to 3212 \to -1(不合法);如果不闪灯,000 \to 0111 \to 1222 \to 2333 \to 3。这些转移加到 dpu,{0,1,2,3}kdp_{u,\{0,1,2,3\}^k} 上。

    然后将 au,{0,1}ka_{u,\{0,1\}^k} 的优美值乘到 dpu,{0,1,2,3}kdp_{u,\{0,1,2,3\}^k} 中。

    这一步结束之后,可以将 33 状态转化为 1122,得到 dpu,{0,1,2}kdp_{u,\{0,1,2\}^k}

    这样得到了 O(n(8k+k4k))O(n (8^k+k 4^k)) 的做法,实现的好可以获得 7575 分。


    着手优化卷积。

    如果我们只考虑结果是 0/1/20/1/2 的转移,那么只需要 O(5k)O(5^k) 就可以完成单次卷积。

    那么结果是 33 的部分,就可以用结果是 0/1/30/1/3 的部分减去结果是 0/10/1 的来计算。

    因此考虑做一种类似 FWT\rm FWT 的操作:先将 0/10/1 状态钦定一部分变成 33,然后只处理 0/1/20/1/2 之间的转移和 (3,3)3(3,3) \to 3 的转移,再做 DFWT\rm DFWT33 减去 0/10/1 状态。

    这样得到了 O(n(6k+k4k))O(n (6^k+k 4^k)) 的做法,足以通过本题。

    #include<bits/stdc++.h>
    #define ll long long
    #define ffor(i,a,b) for(int i=(a);i<=(b);i++)
    #define roff(i,a,b) for(int i=(a);i>=(b);i--)
    using namespace std;
    const int MAXN=100+10,MAX4=65536+10,MOD=998244353,MAXK=2e7+10;
    int n,k,ans,dp[MAXN][MAX4],gain[MAX4],tot;
    ll p[MAXN][10],a[MAXN][300];
    pair<pair<int,int>,int> trans[MAXK];
    void dfss(int dep,int x,int y,int z) {
    	if(dep==k) return trans[++tot]={{x,y},z},void();	
    	dfss(dep+1,(x<<2)|0,(y<<2)|0,(z<<2)|0);
    	dfss(dep+1,(x<<2)|1,(y<<2)|0,(z<<2)|1);
    	dfss(dep+1,(x<<2)|2,(y<<2)|0,(z<<2)|2);
    	dfss(dep+1,(x<<2)|0,(y<<2)|1,(z<<2)|1);
    	dfss(dep+1,(x<<2)|0,(y<<2)|2,(z<<2)|2);
    	dfss(dep+1,(x<<2)|3,(y<<2)|3,(z<<2)|3);
    	return ;
    }
    int tmp[MAX4];
    vector<int> G[MAXN];
    void fwt(int *f,int op) {
    	if(op==1) {
    		ffor(i,1,k) {
    			ffor(j,0,(1<<k+k)-1) {
    				int v=(j>>(i+i-2))&3;
    				if(v==0) f[j+(1<<i+i-2)+(1<<i+i-1)]=(f[j+(1<<i+i-2)+(1<<i+i-1)]+f[j])%MOD;
    				else if(v==1) f[j+(1<<i+i-1)]=(f[j+(1<<i+i-1)]+f[j])%MOD;
    			}
    		}
    	}
    	else {
    		ffor(i,1,k) {
    			roff(j,(1<<k+k)-1,0) {
    				int v=(j>>(i+i-2))&3;
    				if(v==3) f[j]=((f[j]-f[j-(1<<i+i-2)-(1<<i+i-1)])%MOD-f[j-(1<<i+i-1)])%MOD;
    			}
    		}
    	}
    	return ;
    }
    void dfs(int u,int f) {
    	dp[u][0]=1;
    	for(auto v:G[u]) if(v!=f) {
    		dfs(v,u);
    		fwt(dp[u],1),fwt(dp[v],1);
    		memset(tmp,0,sizeof(tmp));
    		ffor(i,1,tot) {
    			auto pr=trans[i];
    			int x=pr.first.first,y=pr.first.second,to=pr.second;
    			if(dp[u][x]) tmp[to]=(tmp[to]+1ll*dp[u][x]*dp[v][y])%MOD;	
    		}
    		fwt(tmp,-1);
    		memcpy(dp[u],tmp,sizeof(tmp));
    	}
    	ffor(i,1,k) {
    		memset(tmp,0,sizeof(tmp));
    		ffor(j,0,(1<<k+k)-1) {
    			int v=(j>>(i+i-2))&3;
    			if(v==0) {
    				int nw=j;
    				tmp[nw]=(tmp[nw]+dp[u][j]*(1-p[u][i]))%MOD;
    				nw=j+(1<<(i+i-1))+(1<<(i+i-2));
    				tmp[nw]=(tmp[nw]+dp[u][j]*p[u][i])%MOD;
    			}
    			else if(v==1) {
    				int nw=j;
    				tmp[nw]=(tmp[nw]+dp[u][j]*(1-p[u][i]))%MOD;
    				nw=j+(1<<(i+i-1));
    				tmp[nw]=(tmp[nw]+dp[u][j]*p[u][i])%MOD;
    			}
    			else if(v==2) {
    				int nw=j;
    				tmp[nw]=(tmp[nw]+dp[u][j]*(1-p[u][i]))%MOD;
    			}
    			else {
    				int nw=j;
    				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
    			}
    		}
    		memcpy(dp[u],tmp,sizeof(tmp));
    	}
    	ffor(i,0,(1<<k+k)-1) dp[u][i]=dp[u][i]*a[u][gain[i]]%MOD;
    	ffor(i,1,k) {
    		memset(tmp,0,sizeof(tmp));
    		ffor(j,0,(1<<k+k)-1) {
    			int v=(j>>(i+i-2))&3;
    			if(v==0) {
    				int nw=j;
    				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
    			}
    			else if(v==1) {
    				int nw=j;
    				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
    			}
    			else if(v==2) {
    				int nw=j;
    				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
    			}
    			else {
    				int nw=j-(1<<(i+i-1));
    				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
    				nw=j-(1<<(i+i-2));
    				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
    			}
    		}
    		memcpy(dp[u],tmp,sizeof(tmp));
    	}
    	return ;
    }
    signed main() {
    	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
    	cin>>n>>k;
    	dfss(0,0,0,0);
    	ffor(i,1,n-1) {
    		int u,v;
    		cin>>u>>v;
    		G[u].push_back(v),G[v].push_back(u);	
    	}
    	ffor(i,1,k) ffor(j,1,n) cin>>p[j][i];
    	ffor(i,1,n) ffor(s,0,(1<<k)-1) cin>>a[i][s];
    	ffor(i,0,(1<<(k+k))-1) {
    		ffor(j,0,k-1) {
    			int psl=(i>>(j+j))&3;
    			if(psl==1||psl==3) gain[i]|=(1<<j);
    		}
    	}
    	dfs(1,0);
    	ffor(i,0,(1<<(k+k))-1) {
    		int flg=0;
    		ffor(j,1,k) {
    			int v=(i>>(j+j-2))&3;
    			if(v!=0&&v!=2) flg=1;
    		}
    		if(flg==1) continue ;
    		ans=(ans+dp[1][i])%MOD;
    	}
    	cout<<(ans%MOD+MOD)%MOD;
    	return 0;
    }
    

    信息

    ID
    11
    时间
    2000ms
    内存
    512MiB
    难度
    9
    标签
    递交数
    56
    已通过
    10
    上传者