首先\(a,b,c\)肯定在一条链上。当\(b\)为\(a\)的祖先时,\(a\)的子树中所有与它不同的点都可以作为点\(c\),当\(a\)为\(b\)的祖先时,\(b\)的子树中所有与它不同的点都可以作为答案
前者直接\(sz[a]*min(k,dep[a])\)即可,关键是后者,如果把\(size\)作为节点的值,我们需要知道这棵树的子树中所有与它距离不超过\(k\)的节点的权值之和。据说可以用长链剖分离线,不过这里也可以用线段树合并在线实现
我会说我以前根本没写过线段树合并结果完全不知道错在哪里么
// luogu-judger-enable-o2//minamoto#include#define ll long longusing namespace std;const int N=6e5+5;#define getc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)char buf[1<<21],*p1=buf,*p2=buf;int read(){ int res,f=1;char ch; while((ch=getc())>'9'||ch<'0')(ch=='-')&&(f=-1); for(res=ch-'0';(ch=getc())>='0'&&ch<='9';res=res*10+ch-'0'); return res*f;}char sr[1<<21],z[20];int C=-1,Z=0;inline void Ot(){fwrite(sr,1,C+1,stdout),C=-1;}void print(ll x){ if(C>1<<20)Ot();if(x<0)sr[++C]='-',x=-x; while(z[++Z]=x%10+48,x/=10); while(sr[++C]=z[Z],--Z);sr[++C]='\n';}int head[N],Next[N<<1],ver[N<<1],tot;inline void add(int u,int v){ver[++tot]=v,Next[tot]=head[u],head[u]=tot;}int L[N<<5],R[N<<5],rt[N],cnt,sz[N],dep[N],n,m;ll s[N<<5];void upd(int &p,int l,int r,int x,int v){ if(!p)p=++cnt;s[p]+=v;if(l==r)return; int mid=(l+r)>>1; x<=mid?upd(L[p],l,mid,x,v):upd(R[p],mid+1,r,x,v);}int merge(int x,int y,int l,int r){ if(!x||!y)return x|y; int mid=(l+r)>>1,u=++cnt;s[u]=s[x]+s[y]; L[u]=merge(L[x],L[y],l,mid); R[u]=merge(R[x],R[y],mid+1,r); return u;}ll query(int p,int l,int r,int ql,int qr){ if(!p)return 0;if(ql<=l&&qr>=r)return s[p]; int mid=(l+r)>>1;ll res=0; if(ql<=mid)res+=query(L[p],l,mid,ql,qr); if(qr>mid)res+=query(R[p],mid+1,r,ql,qr); return res;}void dfs(int u,int fa){ dep[u]=dep[fa]+1,sz[u]=1; for(int i=head[u];i;i=Next[i])if(ver[i]!=fa) dfs(ver[i],u),sz[u]+=sz[ver[i]]; upd(rt[u],1,n,dep[u],sz[u]-1); if(fa)rt[fa]=merge(rt[fa],rt[u],1,n);}int main(){// freopen("testdata.in","r",stdin); n=read(),m=read(); for(int i=1,u,v;i