# [题解] CF1864E Guess Game [期望][Trie]

# 题目大意传送门\footnotesize^{传送门}

给一个长度为nn 的数组,每次从中随机选取两个数a,ba,b,进行游戏:
Alice 得到aaaba|b 的值,Bob 得到bbaba|b 的值,从 Alice 开始,他们轮流猜aabb 的大小关系。
如果不能确定就说 idk ,否则直接会说出a<ba\lt b,a>ba\gt ba=ba=b
问游戏进行轮数的期望

# 题解

# 一次游戏的求法

先考虑确定a,ba,b 的情况怎么求游戏轮数。
对于aba|b 中二进制位上的00,我们直接去除,因为它对a,ba,b 大小关系的判断没有影响。
aa 开始,假设现在ab=1111,a=0111a|b=1111,a=0111,那么aa 直接可以确定a<ba\lt b;否则,aa 的最高位为11aa 只能确定aabb 的最高位相等。
之后轮到bb,假设现在ab=1111,b=1011a|b=1111,b=101101110111,那么bb 直接可以确定b<ab\lt a;否则,到现在只能确定aa 的第一位为11bb 的前两位为11
发现其实每进行一轮,需要判断aa 或者bb 的、还没有确定的、最高的两位是不是有00,如果有00,那么说明手中的数更小;否则两位均为11,仍然不能确定。
那么根据两个数的大小关系,并且求得a,ba,b (去除两者均为 0 的位之后) 最高不同的位置ii,即可直接求出答案:a<ba\lt b 时,答案是2i+1(i&1)2*i+1-(i\&1)a>ba>b 时,答案是2i+(i&1)2*i+(i\&1)a=ba=b 时,答案是1的个数+11的个数+1
建议在纸上画一画,每轮(第一轮除外)其实只判断自己手里最高的两位,很容易发现规律。
Alt text

# n*n 次游戏

现在考虑两两取数,答案的总和。
我们关注的是二进制位上第一个不同数的位置,同时还要知道 a|b 的 1 的个数。前者使用 Trie 树是经典做法,而后者只需要在遍历 Trie 树时记录当前走过的 1 的个数。
在 Trie 上的每个节点计算两个贡献:

  • 一个是取出两个数不相等的情况,考虑对a>ba>ba<ba<b 的情况进行合并
    左子树大小lsls,右子树大小rsrs,当前11 的个数为cnt1cnt_1
    那么贡献lsrs(2cnt1+1)ls*rs*(2*cnt_1+1)
    贡献为节点左右子树的节点个数的乘积 乘上2cnt1+12*cnt_1+1
  • 另一个是取出两个数相等的情况
    在当前节点结束的数的个数记为cntcnt,当前11 的个数为cnt1cnt_1
    那么贡献cntcnt(cnt1+1)cnt*cnt*(cnt_1+1)

# 代码

#include<cstdio>
int read(){
	int out(0),c(getchar());
	for(;c<'0' || c>'9';c=getchar());
	for(;c<='9' && c>='0';c=getchar())
		out=(out<<1)+(out<<3)+(c^48);
	return out;
}
const int MAXN=2e5+10;
const int P=998244353;
int T,N,a[MAXN];
struct Tr{
	Tr* ch[2];
	int sum,cnt;
	Tr(){ch[0]=ch[1]=0;sum=cnt=0;}
}*rt;
long long ans;
void dfs(Tr *p,const int &c1){
	if(!p) return;
	if(p->ch[0] && p->ch[1]){
		int ls=p->ch[0]->sum,rs=p->ch[1]->sum;
		(ans+=(1ll*ls*rs)%P*(2ll*(c1+1)+1)%P)%=P;
	}
	(ans+=1ll*p->cnt*p->cnt*(c1+1)%P)%P;
	dfs(p->ch[0],c1);
	dfs(p->ch[1],c1+1);
}
void del(Tr *p){
	if(!p) return;
	del(p->ch[0]);
	del(p->ch[1]);
	delete p;
}
int qpow(int a,int b){
	if(!a) return 0;
	int res=1;
	for(;b;b>>=1,a=(long long)a*a%P)
		if(b&1) res=(long long)res*a%P;
	return res;
}
int main(){
	T=read();
	while(T--){
		N=read();
		ans=0;
		rt=new Tr;
		for(int i=1;i<=N;++i){
			a[i]=read();
			Tr *t=rt;
			for(int k=29;k>=0;--k){
				bool d=(a[i]>>k)&1;
				if(!t->ch[d])
					t->ch[d]=new Tr;
				t=t->ch[d];
				++t->sum;
			}
			++t->cnt;
		}
		dfs(rt,0);
		printf("%lld\n",1ll*ans*qpow(1ll*N*N%P,P-2)%P);
		del(rt);
	}
}
更新于 阅读次数