master 对树上的求和非常感兴趣。他生成了一棵有根树,并且希望多次询问这棵树上一段路径上所有节点深度的 $k$ 次方和,而且每次的$k$ 可能是不同的。此处节点深度的定义是这个节点到根的路径上的边数。他把这个问题交给了pupil,但pupil 并不会这么复杂的操作,你能帮他解决吗?
思路
树上路径问题,考虑树剖。
$ k$ 次方和,且 $k$ 值范围较小,不超过 $50$ 。
于是考虑维护 $50$ 个线段树,而且这题没有修改操作,只有查询区间和的操作。
但是开 int 会溢出,开 longlong 又会爆空间。
参考了题解的做法,乘法操作时强制类型转换,开 int ,于是便可以通过本题。
代码
#include<bits/stdc++.h>
#define ll long long
#define N 300010
#define mod 998244353
using namespace std;
inline int read() {
int w=1,x=0;
char ch=0;
while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return w*x;
}
int n,m;
int f[N][55];
int head[N],tot=0;
struct graph{
int v,next;
}edge[N<<1];
void add_edge(int u,int v) {edge[++tot].v=v,edge[tot].next=head[u],head[u]=tot;}
void add(int u,int v) {add_edge(u,v),add_edge(v,u);}
int depth[N],fa[N],top[N],son[N],size[N],id[N],rk[N],cnt=0;
void dfs_first(int u) {
depth[u]=depth[fa[u]]+1;
size[u]=1;
for(int i=head[u];i;i=edge[i].next) {
int v=edge[i].v;
if(v==fa[u]) continue;
fa[v]=u;
dfs_first(v);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
void dfs_second(int u,int t) {
if(!u) return;
top[u]=t;
id[u]=++cnt;
rk[cnt]=u;
dfs_second(son[u],t);
for(int i=head[u];i;i=edge[i].next) {
int v=edge[i].v;
if(v!=fa[u]&&v!=son[u]) dfs_second(v,v);
}
}
class SegmentTree{
private:
int sum[N<<2];
#define s(x) sum[x]
#define ls (p<<1)
#define rs (ls|1)
#define mid (l+r>>1)
void update(int p) {
(s(p)=s(ls)+s(rs))%=mod;
}
public:
void build(int p,int l,int r,int k) {
if(l==r) {
s(p)=f[rk[l]][k];
return ;
}
build(ls,l,mid,k),build(rs,mid+1,r,k);
update(p);
}
int ask(int p,int l,int r,int L,int R) {
if(l>R||r<L) return 0;
if(l>=L&&r<=R) return s(p);
return (ask(ls,l,mid,L,R)+ask(rs,mid+1,r,L,R))%mod;
}
}SMT[55];
int ask(int x,int y,int k) {
int tx=top[x],ty=top[y],ans=0;
while(tx!=ty) {
if(depth[tx]>depth[ty]) {
(ans+=SMT[k].ask(1,1,n,id[tx],id[x]))%=mod;
x=fa[tx],tx=top[x];
}
else {
(ans+=SMT[k].ask(1,1,n,id[ty],id[y]))%=mod;
y=fa[ty],ty=top[y];
}
}
if(id[x]<id[y]) (ans+=SMT[k].ask(1,1,n,id[x],id[y]))%=mod;
else (ans+=SMT[k].ask(1,1,n,id[y],id[x]))%=mod;
return ans;
}
void init() {
depth[0]=-1;
dfs_first(1);
dfs_second(1,1);
for(int i=1;i<=n;i++) f[i][1]=depth[i];
for(int i=1;i<=n;i++) {
for(int j=2;j<=50;j++) f[i][j]=(ll)f[i][j-1]*(ll)depth[i]%mod;
}
for(int i=1;i<=50;i++) SMT[i].build(1,1,n,i);
}
int main() {
n=read();
for(int i=1;i<n;i++) {
int u=read(),v=read();
add(u,v);
}
init();
m=read();
for(int i=1;i<=m;i++) {
int u=read(),v=read(),k=read();
printf("%d\n",ask(u,v,k));
}
system("pause");
return 0;
}