【LOJ#6156】A*B Problem

  • 本文为博主原创,未经许可不得转载

我们先计算下面的式子:

$Ans_k=\sum_{i=1}^{n}\sum_{j=1}^{n}a_i a_j \equiv k(mod m)$

$Ans_k=\sum_{i=1}^{n}\sum_{j=1}^{n} cnt_i cnt_j [ij\equiv k(mod m)]$

我们取离散对数变成循环卷积的形式

$Ans_{g^k}=\sum_{i=1}^{n}\sum_{j=1}^{n} cnt_{g^i} cnt_{g^j} [g^{i+j}\equiv g^k (mod m)]$

$Ans_{g^k}=\sum_{i+j\equiv k(mod \phi(m))}cnt_{g^i} cnt_{g^j} $

然后减去不合法的部分就行了

零的答案要单独算,根据乘除性也是很好处理的

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<bits/stdc++.h>
#define FILE "read"
#define MAXN 300010
typedef long long ll;
const double pi=acos(-1);
int n,m,G,cnt[MAXN],scnt[MAXN],R[MAXN],b[MAXN];ll ans[MAXN];
std::vector<int>fac;
char buf[1<<15],*fs,*ft;
inline char getc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++;}
inline int read(){
int x=0,f=1; char ch=getc();
while(ch>'9'||ch<'0') {if(ch=='-') f=-1; ch=getc();}
while(ch>='0'&&ch<='9') {x=x*10+ch-'0'; ch=getc();}
return x*f;
}
struct complex{
double r,v;
complex(double a=0,double b=0):r(a),v(b){}
inline complex operator+(const complex &b){return complex(r+b.r,v+b.v);}
inline complex operator-(const complex &b){return complex(r-b.r,v-b.v);}
inline complex operator*(const complex &b){return complex(r*b.r-v*b.v,r*b.v+v*b.r);}
}a[MAXN],w[MAXN];
inline void swap(complex &a,complex &b){complex t(a);a=b;b=t;}
void FFT(complex *a,int L,int f){
for(int i=0;i<L;++i) if(i<R[i]) swap(a[i],a[R[i]]);
for(int len=2;len<=L;len<<=1){
int l=len>>1; //complex wn(cos(pi/l),f*sin(pi/l));
for(int i=1;i<l;i++) w[i]=w[i-1]*wn;
for(int st=0;st<L;st+=len) for(int k=0;k<l;++k){
complex x=a[st+k],y=w[k]*a[st+k+l];
a[st+k]=x+y; a[st+k+l]=x-y;
}
}
if(f==-1) for(int i=0;i<L;++i) a[i].r/=L;
}
int pow(int a,int b,int mod){
int ret=1;
while(b){
if(b&1) ret=1LL*ret*a%mod;
b>>=1; a=1LL*a*a%mod;
}return ret;
}
int getG(int m){
int phi=m-1; fac.clear();
for(int i=2;i<=sqrt(phi);++i)if(phi%i==0){
fac.push_back(i);
fac.push_back(phi/i);
}
for(int i=2;i<m;++i){
int flag=1;
for(int j=0;j<fac.size();++j)
if(pow(i,phi/fac[j],m)==1) {flag=0;break;}
if(flag) return i;
}
}
void solve(){
n=read(); m=read(); G=getG(m);
int tot=0,c=0; w[0]=1;
for(int i=0;i<=m;++i) cnt[i]=scnt[i]=0;
for(int i=0;i<=m;++i) b[i]=-1;
for(int i=1;i<=n;++i){
int x=read(); ++cnt[x%m];
++scnt[1LL*x*x%m]; c+=(x%m==0);
}
for(int i=1;b[i]==-1;i=1LL*i*G%m) b[i]=tot++;//离散对数
int L=1,H=0; while(L<m+m) L<<=1,++H;
for(int i=0;i<L;++i) R[i]=(R[i>>1]>>1)|((i&1)<<(H-1));
for(int i=0;i<L;++i) a[i]=0;
for(int i=1;i<m;++i) a[b[i]]=a[b[i]]+cnt[i];
FFT(a,L,1);
for(int i=0;i<L;++i) a[i]=a[i]*a[i];
FFT(a,L,-1);
ans[0]=1LL*n*(n-1)/2-1LL*(n-c)*(n-c-1)/2;
for(int i=1;i<m;++i) ans[i]=(a[b[i]].r+a[b[i]+m-1].r-scnt[i])/2+0.5;
for(int i=0;i<m;++i) printf("%lld\n",ans[i]);
}
int main(){
freopen(FILE".in","r",stdin);
freopen(FILE".out","w",stdout);
int T=read();
while(T--) solve();
return 0;
}
文章目录
,