题解 BZOJ4401

$Description$

把一棵树分成几块,使得每个块中的点数都相同,问有多少种块的大小能满足该条件

$Solution$

考虑一个性质$:$对于一个块的大小$s,$当且仅当有$\frac{n}{s}$个节点的$size($即子树大小$)$为$s$ 的倍数。

证明$:$

假设$size$为$s$的倍数的节点为$a$节点

先证明一棵树中$a$节点的个数不会超过$\frac{n}{s}$个。采取反证法,假设一棵树中$a$节点有$\frac{n}{s}+1$个,那么我们可以先构造出一颗有$\frac{n}{s}+1($大于$\frac{n}{s}+1$个节点的情况可以以此类推$)$个节点的树,假设这些点都是$a$节点,然后通过加入$n-(\frac{n}{s}+1)$个点$($可以加在边上$)$来满足条件,不断找叶子节点,在它下面连一颗大小为$s-1$的树,然后我们可以删去这个节点和它的子树,因为它对上面的祖先$\%s$的余数已经没有影响了。我们发现,对于每一个$a$节点,我们都需要$s-1$个新增节点。那么总的新增节点数就是$(\frac{n}{s}+1)\times (s-1)=n-s-(\frac{n}{s}+1),$再加上原先的$a$节点数即总节点数为$n+s>n$,所以一棵树中$a$节点的个数不会超过$\frac{n}{s}$个。

再证明为什么有$\frac{n}{s}$个$a$节点就可以构造,我们只需要每次找到一个$size==s$的$a$节点然后将它和它的子树分成一个块然后在树上删除这个块即可。每次至少能找出一个$size==s$的$a$节点。因为如果每个$a$节点的$size$都至少为$2\times s,$那么假设$a$节点的$size$最小为$2\times s($大于$2\times s$的可以以此类推$),$那么删去这个$size==2\times s$的节点和它的子树,那么剩下的树的$a$节点数不会超过$\frac{n-2\times s}{s}$个,再加上这个删去的节点,$a$节点的数量只有$\frac{n}{s}-1$个所以每次至少能找出一个$size==s$的$a$节点,也就可以顺利构造了。

小于$\frac{n}{s}$个$a$节点用上面的构造方法模拟一下就知道显然不可行$($没有$size==s$的$a$节点你是没办法继续分块的$)$。

然后就很无脑了,先$dfs$求出每个点的$size,$然后全扔桶里面,然后从$1\sim n$枚举块的大小,$O(\frac{n}{i})$计算$a$节点的个数。

复杂度就是调和级数$O(nlogn)$

$Code$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <bits/stdc++.h>
#define ll long long
#define inf 0x3f3f3f3f
#define re register
#define N 1002302
using namespace std;
struct edge{
int to,next;
}e[N<<1];
inline int read(){
int x=0,w=0;char ch=getchar();
while (!isdigit(ch))w|=ch=='-',ch=getchar();
while (isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return w?-x:x;
}
int cnt,head[N],size[N],n,ma[N],ans;
inline void add(int u,int v){
e[++cnt].to=v;
e[cnt].next=head[u];
head[u]=cnt;
}
void dfs(int u,int fa){
size[u]=1;
for (int i=head[u];i;i=e[i].next){
int v=e[i].to;
if (v==fa)continue;
dfs(v,u);
size[u]+=size[v];
}
}
signed main(){
n=read();
for (int i=1;i<n;++i){
int u=read(),v=read();
add(u,v);add(v,u);
}
dfs(1,0);
for (int i=1;i<=n;++i)++ma[size[i]];
for (int i=1;i<=n;++i)
if (n%i==0){
int res=0;
for (int j=i;j<=n;j+=i)
res+=ma[j];
ans+=(res==n/i);
}
printf("%d\n",ans);
return 0;
}