3 条题解

  • 6
    @ 2024-7-8 17:36:48

    好神仙的题目。赛时胡了一个状态和转移都和官解不同的做法,得到了 O(n10m)O(n10^m) 的优秀复杂度。卡了一场常卡进了 7575 分。这个做法和官解关系不大,并且很难进行最后的优化部分,所以在此不再赘述。

    首先考虑 k=1k=1 的情况。考虑记录一些状态能够描述子树内的选择方案,00 表示整个子树没有被覆盖过,11 表示子树内部有点被覆盖过并且子树外的点还能被覆盖,22 表示子树内部有点被覆盖过并且子树外的点不能被覆盖了。考虑转移,需要把转移描述为只和 u,vu,v 有关的形式才能较为简单的扩展到 k1k\neq 1 的情况。发现对于 121\rightarrow 2 的转移,很难描述为 u,vu,v 的形式,因为需要出现两个子树为 11 或者根节点被选择才能转移到 22。所以考虑记录辅助状态 33 表示出现过至少 2211 的方案。那么转移有以下 88 种:

    (0,0)0(0,0)\rightarrow 0 (0,1)1  (1,0)1(0,1)\rightarrow 1 \ \ (1,0)\rightarrow 1 (0,2)2  (2,0)2(0,2)\rightarrow 2 \ \ (2,0)\rightarrow 2 (3,0)3  (3,1)3(3,0)\rightarrow 3 \ \ (3,1)\rightarrow 3 (1,1)3(1,1)\rightarrow 3

    上面没有出现过的转移为不合法或者不存在对应状态。这么转移之后再考虑和根节点是否选择合并的转移,那么有:

    $$(0,0)\rightarrow 0 \ \ (1,0)\rightarrow 1 \ \ (2,0)\rightarrow 2 \ \ (3,0)\rightarrow 3 \ \ $$$$(0,1)\rightarrow 3 \ \ (1,1)\rightarrow3 \ \ (3,1)\rightarrow 3 \ \ $$

    转移的同时计入 p,ap,a 两个数组的贡献。最后将 33 状态放到 1,21,2 两种状态即可。因为 33 状态对应的状态可以封口也可以不封口。复杂度 O(n)O(n)

    考虑对于 k1k\neq 1 的情况,每一位暴力枚举上面的 88 种转移,第一部分的转移复杂度是 O(8k)O(8^k) 的。对于复合根节点情况的部分,暴力枚举根节点状态显然不优,可以类似 FMT 的对每一位依次进行变换,也就是逐位枚举根节点状态并处理这一位变换后的位置。复杂度为 O(k4k)O(k4^k)。对于 33 状态的下放可以用类似的做法也做到 O(k4k)O(k4^k)。复杂度 O(n(8k+k4k))O(n(8^k+k4^k)),视常数可以获得 458545\sim 85 分。

    考虑优化,目前的瓶颈在于 O(8k)O(8^k) 的部分。一个很神秘的做法是考虑到如果没有辅助状态 33,那么转移只有 O(5k)O(5^k)。所以考虑枚举儿子的一些位置的状态钦定为 33,由于对于 33 的转移是和 0/10/1 复合之后仍然为 33,所以为 33 的位可以让它的值为对应位为 0/10/1 的和。类似 OR 卷积的 FWT,经过一次正变换之后为 33 的位置真实值可以为 0011。然后对变换之后的部分进行 O(5k)O(5^k) 的转移,但是多了 33 的状态,由于经过了变换,只需要加入 (3,3)3(3,3)\rightarrow 3 的转移。这部分转移的复杂度是 O(6k)O(6^k) 的。对于转移之后 33 的位置,他们是从 (0/1,0/1)(0/1,0/1) 转移过来的,所以真实值可能是 0/1/30/1/3,所以要进行一次类似 OR 卷积的 IFWT 让他变成真实值为 33 的值。FWT 和 IFWT 的复杂度是 O(k4k)O(k4^k),所以总的复杂度就是 O(n(6k+k4k))O(n(6^k+k4^k)),可以通过。

    #include<bits/stdc++.h>
    using namespace std;
    struct edge{int v,nxt;}e[205];
    int n,m,u,v,cnt,h[105],w[105][256],p[105][8],dp[105][1<<16],num,tmp[1<<16];
    void add(int u,int v){e[++cnt]={v,h[u]};h[u]=cnt;}
    const int mod=998244353;
    void Add(int &x,int y){x=(x+y>=mod?x+y-mod:x+y);}
    struct node{int x,y,z;}go[2000005];
    void init(int k,int x,int y,int z)
    {
    	if(k==m){go[++num]={x,y,z};return;}
    	init(k+1,x,y,z);
    	init(k+1,x,y|(1<<(k<<1)),z|(1<<(k<<1)));
    	init(k+1,x|(1<<(k<<1)),y,z|(1<<(k<<1)));
    	init(k+1,x|(2<<(k<<1)),y,z|(2<<(k<<1)));
    	init(k+1,x,y|(2<<(k<<1)),z|(2<<(k<<1)));
    	init(k+1,x|(3<<(k<<1)),y|(3<<(k<<1)),z|(3<<(k<<1)));
    }
    void fwt(int *a)
    {
    	for(int i=0;i<m;i++)
    	{
    		for(int s=0;s<(1<<(m<<1));s++)
    		{
    			int c=(s>>(i<<1))&3;
    			if(c==0)Add(a[s+(3<<(i<<1))],a[s]);
    			else if(c==1)Add(a[s+(2<<(i<<1))],a[s]); 
    		}
    	}
    }
    void ifwt(int *a)
    {
    	for(int i=0;i<m;i++)
    	{
    		for(int s=0;s<(1<<(m<<1));s++)
    		{
    			int c=(s>>(i<<1))&3;
    			if(c==3)Add(a[s],mod-a[s-(3<<(i<<1))]),Add(a[s],mod-a[s-(2<<(i<<1))]);
    		}
    	}
    }
    void dfs(int u,int fa)
    {
    	dp[u][0]=1;
    	for(int i=h[u];i;i=e[i].nxt)
    	{
    		int v=e[i].v;
    		if(v==fa)continue;
    		dfs(v,u);
    		for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
    		fwt(tmp);fwt(dp[v]);
    		for(int s=1;s<=num;s++)Add(dp[u][go[s].z],1ll*tmp[go[s].x]*dp[v][go[s].y]%mod);
    		ifwt(dp[u]);
    	}
    	for(int i=0;i<m;i++)
    	{
    		for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
    		for(int s=0;s<(1<<(m<<1));s++)
    		{
    			int c=(s>>(i<<1))&3;
    			if(c==0)
    			{
    				Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
    				Add(dp[u][s|(3<<(i<<1))],1ll*tmp[s]*p[u][i]%mod);
    			}
    			else if(c==1)
    			{
    				Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
    				Add(dp[u][s|(2<<(i<<1))],1ll*tmp[s]*p[u][i]%mod);
    			}
    			else if(c==2)
    			{
    				Add(dp[u][s],1ll*tmp[s]*(mod+1-p[u][i])%mod);
    			}
    			else 
    			{
    				Add(dp[u][s],tmp[s]);
    			}
    		}
    	}
    	for(int s=0;s<(1<<(m<<1));s++)
    	{
    		int ns=0;
    		for(int i=0;i<m;i++)if((s>>(i<<1))&1)ns|=(1<<i);
    		dp[u][s]=1ll*dp[u][s]*w[u][ns]%mod;
    	}
    	for(int i=0;i<m;i++)
    	{
    		for(int s=0;s<(1<<(m<<1));s++)tmp[s]=dp[u][s],dp[u][s]=0;
    		for(int s=0;s<(1<<(m<<1));s++)
    		{
    			if(((s>>(i<<1))&3)==3)
    			{
    				Add(dp[u][s-(1<<(i<<1))],tmp[s]);
    				Add(dp[u][s-(2<<(i<<1))],tmp[s]); 
    			}
    			else Add(dp[u][s],tmp[s]);
    		}
    	}
    }
    int main()
    {
    	//freopen("e.in","r",stdin);
    	cin.tie(0)->sync_with_stdio(0);
    	cin>>n>>m;
    	for(int i=1;i<n;i++)
    	{
    		cin>>u>>v;
    		add(u,v);add(v,u);
    	}
    	for(int i=0;i<m;i++)for(int j=1;j<=n;j++)cin>>p[j][i];
    	for(int i=1;i<=n;i++)
    	{
    		for(int s=0;s<(1<<m);s++)cin>>w[i][s];
    	}
    	init(0,0,0,0);dfs(1,0);
    	int ans=0;
    	for(int s=0;s<(1<<(m<<1));s++)
    	{
    		int flag=1;
    		for(int i=0;i<m;i++)flag&=(((s>>(i<<1))&3)!=1);
    		if(flag)Add(ans,dp[1][s]);
    	}
    	cout<<ans;
    	return 0;
    }
    
    • 5
      @ 2024-7-8 11:59:33

      Solution

      非常好的题目!和正解差一个神秘的 FWT\rm FWT


      考虑树形 DP。如果你做了上一场梦熊周赛的 T3,你会发现他们惊人的相似:每次闪灯,对每个点记录状态 0/1/20/1/2 分别表示子树内没有闪灯的点、子树内有闪灯的点且当前节点闪灯、子树内有闪灯的点且子树外节点不能有闪灯。

      由于有多次闪灯,可以考虑将不同的闪灯情况状态压缩。

      如何转移?首先注意到,对于一次闪灯,如果 uu 存在两棵子树都是状态 11,那么 uu 节点一定闪灯;如果 uu 只有一棵子树状态为 11,那么 uu 必须以 pup_u 的概率选中才能闪灯。因此你还要新增状态 33(只在转移的过程中使用)表示子树内是否有 22 棵以上是状态 11

      因此我们维护 Dpu,{0,1,2,3}kDp_{u,\{0,1,2,3\}^k}。初始有 Dpu,0=1Dp_{u,0}=1。新加入节点 vv 时和 dpv,{0,1,2}kdp_{v,\{0,1,2\}^k} 作如下卷积:

      对于每一位独立考虑,有(下面的都是“uu 的状态,vv 的状态,结果状态”的顺序):

      • (0,0)0(0,0) \to 0
      • (1,0)1(1,0) \to 1
      • (2,0)2(2,0) \to 2
      • (3,0)3(3,0) \to 3
      • (0,1)1(0,1) \to 1
      • (1,1)3(1,1) \to 3
      • (3,1)3(3,1) \to 3
      • (0,2)2(0,2) \to 2

      处理 Dpu,{0,1,2,3}kDp_{u,\{0,1,2,3\}^k} 后,考虑 uu 节点的闪灯情况。每一维,如果闪灯,030 \to 3131 \to 3333 \to 3212 \to -1(不合法);如果不闪灯,000 \to 0111 \to 1222 \to 2333 \to 3。这些转移加到 dpu,{0,1,2,3}kdp_{u,\{0,1,2,3\}^k} 上。

      然后将 au,{0,1}ka_{u,\{0,1\}^k} 的优美值乘到 dpu,{0,1,2,3}kdp_{u,\{0,1,2,3\}^k} 中。

      这一步结束之后,可以将 33 状态转化为 1122,得到 dpu,{0,1,2}kdp_{u,\{0,1,2\}^k}

      这样得到了 O(n(8k+k4k))O(n (8^k+k 4^k)) 的做法,实现的好可以获得 7575 分。


      着手优化卷积。

      如果我们只考虑结果是 0/1/20/1/2 的转移,那么只需要 O(5k)O(5^k) 就可以完成单次卷积。

      那么结果是 33 的部分,就可以用结果是 0/1/30/1/3 的部分减去结果是 0/10/1 的来计算。

      因此考虑做一种类似 FWT\rm FWT 的操作:先将 0/10/1 状态钦定一部分变成 33,然后只处理 0/1/20/1/2 之间的转移和 (3,3)3(3,3) \to 3 的转移,再做 DFWT\rm DFWT33 减去 0/10/1 状态。

      这样得到了 O(n(6k+k4k))O(n (6^k+k 4^k)) 的做法,足以通过本题。

      #include<bits/stdc++.h>
      #define ll long long
      #define ffor(i,a,b) for(int i=(a);i<=(b);i++)
      #define roff(i,a,b) for(int i=(a);i>=(b);i--)
      using namespace std;
      const int MAXN=100+10,MAX4=65536+10,MOD=998244353,MAXK=2e7+10;
      int n,k,ans,dp[MAXN][MAX4],gain[MAX4],tot;
      ll p[MAXN][10],a[MAXN][300];
      pair<pair<int,int>,int> trans[MAXK];
      void dfss(int dep,int x,int y,int z) {
      	if(dep==k) return trans[++tot]={{x,y},z},void();	
      	dfss(dep+1,(x<<2)|0,(y<<2)|0,(z<<2)|0);
      	dfss(dep+1,(x<<2)|1,(y<<2)|0,(z<<2)|1);
      	dfss(dep+1,(x<<2)|2,(y<<2)|0,(z<<2)|2);
      	dfss(dep+1,(x<<2)|0,(y<<2)|1,(z<<2)|1);
      	dfss(dep+1,(x<<2)|0,(y<<2)|2,(z<<2)|2);
      	dfss(dep+1,(x<<2)|3,(y<<2)|3,(z<<2)|3);
      	return ;
      }
      int tmp[MAX4];
      vector<int> G[MAXN];
      void fwt(int *f,int op) {
      	if(op==1) {
      		ffor(i,1,k) {
      			ffor(j,0,(1<<k+k)-1) {
      				int v=(j>>(i+i-2))&3;
      				if(v==0) f[j+(1<<i+i-2)+(1<<i+i-1)]=(f[j+(1<<i+i-2)+(1<<i+i-1)]+f[j])%MOD;
      				else if(v==1) f[j+(1<<i+i-1)]=(f[j+(1<<i+i-1)]+f[j])%MOD;
      			}
      		}
      	}
      	else {
      		ffor(i,1,k) {
      			roff(j,(1<<k+k)-1,0) {
      				int v=(j>>(i+i-2))&3;
      				if(v==3) f[j]=((f[j]-f[j-(1<<i+i-2)-(1<<i+i-1)])%MOD-f[j-(1<<i+i-1)])%MOD;
      			}
      		}
      	}
      	return ;
      }
      void dfs(int u,int f) {
      	dp[u][0]=1;
      	for(auto v:G[u]) if(v!=f) {
      		dfs(v,u);
      		fwt(dp[u],1),fwt(dp[v],1);
      		memset(tmp,0,sizeof(tmp));
      		ffor(i,1,tot) {
      			auto pr=trans[i];
      			int x=pr.first.first,y=pr.first.second,to=pr.second;
      			if(dp[u][x]) tmp[to]=(tmp[to]+1ll*dp[u][x]*dp[v][y])%MOD;	
      		}
      		fwt(tmp,-1);
      		memcpy(dp[u],tmp,sizeof(tmp));
      	}
      	ffor(i,1,k) {
      		memset(tmp,0,sizeof(tmp));
      		ffor(j,0,(1<<k+k)-1) {
      			int v=(j>>(i+i-2))&3;
      			if(v==0) {
      				int nw=j;
      				tmp[nw]=(tmp[nw]+dp[u][j]*(1-p[u][i]))%MOD;
      				nw=j+(1<<(i+i-1))+(1<<(i+i-2));
      				tmp[nw]=(tmp[nw]+dp[u][j]*p[u][i])%MOD;
      			}
      			else if(v==1) {
      				int nw=j;
      				tmp[nw]=(tmp[nw]+dp[u][j]*(1-p[u][i]))%MOD;
      				nw=j+(1<<(i+i-1));
      				tmp[nw]=(tmp[nw]+dp[u][j]*p[u][i])%MOD;
      			}
      			else if(v==2) {
      				int nw=j;
      				tmp[nw]=(tmp[nw]+dp[u][j]*(1-p[u][i]))%MOD;
      			}
      			else {
      				int nw=j;
      				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
      			}
      		}
      		memcpy(dp[u],tmp,sizeof(tmp));
      	}
      	ffor(i,0,(1<<k+k)-1) dp[u][i]=dp[u][i]*a[u][gain[i]]%MOD;
      	ffor(i,1,k) {
      		memset(tmp,0,sizeof(tmp));
      		ffor(j,0,(1<<k+k)-1) {
      			int v=(j>>(i+i-2))&3;
      			if(v==0) {
      				int nw=j;
      				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
      			}
      			else if(v==1) {
      				int nw=j;
      				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
      			}
      			else if(v==2) {
      				int nw=j;
      				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
      			}
      			else {
      				int nw=j-(1<<(i+i-1));
      				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
      				nw=j-(1<<(i+i-2));
      				tmp[nw]=(tmp[nw]+dp[u][j])%MOD;
      			}
      		}
      		memcpy(dp[u],tmp,sizeof(tmp));
      	}
      	return ;
      }
      signed main() {
      	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
      	cin>>n>>k;
      	dfss(0,0,0,0);
      	ffor(i,1,n-1) {
      		int u,v;
      		cin>>u>>v;
      		G[u].push_back(v),G[v].push_back(u);	
      	}
      	ffor(i,1,k) ffor(j,1,n) cin>>p[j][i];
      	ffor(i,1,n) ffor(s,0,(1<<k)-1) cin>>a[i][s];
      	ffor(i,0,(1<<(k+k))-1) {
      		ffor(j,0,k-1) {
      			int psl=(i>>(j+j))&3;
      			if(psl==1||psl==3) gain[i]|=(1<<j);
      		}
      	}
      	dfs(1,0);
      	ffor(i,0,(1<<(k+k))-1) {
      		int flg=0;
      		ffor(j,1,k) {
      			int v=(i>>(j+j-2))&3;
      			if(v!=0&&v!=2) flg=1;
      		}
      		if(flg==1) continue ;
      		ans=(ans+dp[1][i])%MOD;
      	}
      	cout<<(ans%MOD+MOD)%MOD;
      	return 0;
      }
      
      • -6
        @ 2024-7-26 18:06:01

        114514.cn

        • 1

        信息

        ID
        11
        时间
        2000ms
        内存
        512MiB
        难度
        9
        标签
        递交数
        56
        已通过
        10
        上传者