【bzoj4543】Hotel加强版

  • 长链剖分优化树形dp真是一颗赛艇
  • 获得成就:第一次使用指针

首先我们设:

$f[i][j]$ 表示以 $i$ 为根的子树中与 $i$ 的距离为 $j$ 的点的个数


$g[i][j]$ 表示以 $i$ 为根的子树内有 $g[i][j]$ 对点深度相同,设 $t$ 为 $lca(x,y)$ ,且满足 $dis(x,t)=dis(y,t)=d$ 且 $dis(i,t)=d-j$


则状态转移方程(必须按顺序转移):

(1)$f[x][0]=1$


(2)$ans+=g[x][0]$


(3)$ans+=f[x][i-1]*g[y][i]+g[x][i+1]*f[y][i]$


(4)$g[x][i-1]+=g[y][i]$


(5)$g[x][i+1]+=f[x][i+1]*f[y][i]$


(6)$f[x][i+1]+=f[y][i]$


然后我们发现状态数是 $O(n^2)$ 级别的,怎么办?

下面介绍一种套路:长链剖分优化树形 $dp$

我们发现状态转移方程中的(4)和(6)在第一次转移时是可以 $O(1)$ 实现的

那么我们将树进行长链剖分,对于重儿子,我们直接进行 $O(1)$ 转移,对于轻儿子,只能暴力转移了

可以证明这样做的时间复杂度是 $O(n)$ 的

至于空间复杂度,我们使用动态内存分配即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include<bits/stdc++.h>
#define MAXN 100010
#define FILE "read"
using namespace std;
typedef long long ll;
struct node{int y,next;}e[MAXN<<1];
int n,len,Link[MAXN],deep[MAXN],bot[MAXN];
ll ans,*p,memp[MAXN*5],*f[MAXN],*g[MAXN];
char buf[1<<15],*fs,*ft;
inline char getc(){return (fs==ft&&(ft=(fs=buf)+fread(buf,1,1<<15,stdin),fs==ft))?0:*fs++;}
inline int read(){
int x=0,f=1; char ch=getc();
while(!isdigit(ch)) {if(ch=='-') f=-1; ch=getc();}
while(isdigit(ch)) {x=x*10+ch-'0'; ch=getc();}
return x*f;
}
void insert(int x,int y){e[++len].next=Link[x];Link[x]=len;e[len].y=y;}
void dfs(int x,int fa){
bot[x]=x;
for(int i=Link[x];i;i=e[i].next)if(e[i].y!=fa){//长链剖分
deep[e[i].y]=deep[x]+1; dfs(e[i].y,x);
if(deep[bot[e[i].y]]>deep[bot[x]]) bot[x]=bot[e[i].y];
}
for(int i=Link[x];i;i=e[i].next)if(e[i].y!=fa){//动态内存分配
if(bot[x]==bot[e[i].y]&&x!=1) continue;
p+=deep[bot[e[i].y]]-deep[x]+1;
f[bot[e[i].y]]=p; g[bot[e[i].y]]=(p+=1);
p+=(deep[bot[e[i].y]]-deep[x])<<1|1;
}
}
void dp(int x,int fa){
for(int i=Link[x];i;i=e[i].next)if(e[i].y!=fa){
dp(e[i].y,x);
if(bot[x]==bot[e[i].y])f[x]=f[e[i].y]-1,g[x]=g[e[i].y]+1;
}
f[x][0]=1; ans+=g[x][0];
for(int i=Link[x];i;i=e[i].next)if(e[i].y!=fa){
if(bot[x]==bot[e[i].y]) continue;
for(int j=0;j<=deep[bot[e[i].y]]-deep[x];++j)
ans+=f[x][j-1]*g[e[i].y][j]+g[x][j+1]*f[e[i].y][j];
for(int j=0;j<=deep[bot[e[i].y]]-deep[x];++j){
g[x][j-1]+=g[e[i].y][j];
g[x][j+1]+=f[e[i].y][j]*f[x][j+1];
f[x][j+1]+=f[e[i].y][j];
}
}
}
int main(){
freopen(FILE".in","r",stdin);
freopen(FILE".out","w",stdout);
n=read();
for(int i=1;i<n;++i){
int x=read(),y=read();
insert(x,y); insert(y,x);
}
p=memp+1; deep[1]=1;
dfs(1,0); dp(1,0);
printf("%lld\n",ans);
return 0;
}
文章目录
,