- 给出一棵 $n$ 个节点的树,根节点为 $1$。每个节点上有一种颜色 $c_i$。$m$ 次操作。操作有两种:
1 u c
:将以 $u$ 为根的子树上的所有节点的颜色改为 $c$。2 u
:询问以 $u$ 为根的子树上的所有节点的颜色数量。
- $1\le n,m\le 4\times 10^5$,$1\le c_i,c\le 60$。
思路
看到是对子树进行操作,所以很自然地想到了 dfs
序。
这个操作我想到了每个点维护一个 set
和一个延迟标记。对于操作一,对对应 dfs
序上的点的 set
清空,并插入所需要修改的颜色 $c$ ,然后修改延迟标记;对于操作二,把对应区间的 set
取并集,然后输出并集的大小即可。
如此,过了 47 个点,然后就 TLE 了。
#include<bits/stdc++.h>
#define N 400010
using namespace std;
inline int read() {
int w=1,x=0;
char ch=0;
while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return w*x;
}
int n,m,c[N];
int head[N],tot=0;
struct graph{
int v,next;
}edge[N<<1];
void add_edge(int u,int v) {edge[++tot].v=v,edge[tot].next=head[u],head[u]=tot;}
void add(int u,int v) {add_edge(u,v),add_edge(v,u);}
int dfn[N<<1],js=0,lid[N],rid[N];
void dfs(int u,int fa) {
dfn[++js]=u;
lid[u]=js;
for(int i=head[u];i;i=edge[i].next) {
int v=edge[i].v;
if(v==fa) continue;
dfs(v,u);
}
dfn[++js]=u;
rid[u]=js;
}
class SegmentTree{
private:
int lazy[N<<2];
set<int> color[N<<2];
#define l(x) lazy[x]
#define c(x) color[x]
#define ls (p<<1)
#define rs (ls|1)
#define mid (l+r>>1)
void update(int p) {
c(p).clear();
c(p).insert(c(ls).begin(),c(ls).end());
c(p).insert(c(rs).begin(),c(rs).end());
}
void reset(int p,int x) {
c(p).clear();
c(p).insert(x);
l(p)=x;
}
void spread(int p) {
if(!l(p)) return ;
reset(ls,l(p));
reset(rs,l(p));
l(p)=0;
}
public:
void build(int p,int l,int r) {
if(l==r) {
c(p).insert(c[dfn[l]]);
return ;
}
build(ls,l,mid),build(rs,mid+1,r);
update(p);
}
void Modify(int p,int l,int r,int L,int R,int x) {
if(l>R||r<L) return ;
if(l>=L&&r<=R) {
reset(p,x);
return ;
}
spread(p);
Modify(ls,l,mid,L,R,x),Modify(rs,mid+1,r,L,R,x);
update(p);
}
set<int> ask(int p,int l,int r,int L,int R) {
set<int> ans;
if(l>R||r<L) return ans;
if(l>=L&&r<=R) return c(p);
spread(p);
set<int> lc=ask(ls,l,mid,L,R),rc=ask(rs,mid+1,r,L,R);
ans.insert(lc.begin(),lc.end()),ans.insert(rc.begin(),rc.end());
return ans;
}
}SMT;
int main() {
n=read(),m=read();
for(int i=1;i<=n;i++) c[i]=read();
for(int i=1;i<n;i++) {
int u=read(),v=read();
add(u,v);
}
dfs(1,0);
SMT.build(1,1,js);
for(int i=1;i<=m;i++) {
int op,u,c;
op=read(),u=read();
if(op==1) {
c=read();
SMT.Modify(1,1,js,lid[u],rid[u],c);
}
else printf("%d\n",SMT.ask(1,1,js,lid[u],rid[u]).size());
}
system("pause");
return 0;
}
后来注意到 $c$ 的范围不超过 $60$,于是想到可以二进制状压,这样取并集等操作可以更高效,直接位运算即可,修改后就过了。
代码
#include<bits/stdc++.h>
#define ll long long
#define N 400010
using namespace std;
inline ll read() {
ll w=1,x=0;
char ch=0;
while(ch<'0'||ch>'9') {if(ch=='-') w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return w*x;
}
int n,m;
ll c[N];
int head[N],tot=0;
struct graph{
int v,next;
}edge[N<<1];
void add_edge(int u,int v) {edge[++tot].v=v,edge[tot].next=head[u],head[u]=tot;}
void add(int u,int v) {add_edge(u,v),add_edge(v,u);}
int dfn[N<<1],js=0,lid[N],rid[N];
void dfs(int u,int fa) {
dfn[++js]=u;
lid[u]=js;
for(int i=head[u];i;i=edge[i].next) {
int v=edge[i].v;
if(v==fa) continue;
dfs(v,u);
}
dfn[++js]=u;
rid[u]=js;
}
class SegmentTree{
private:
ll lazy[N<<3],s[N<<3];
#define l(x) lazy[x]
#define s(x) s[x]
#define ls (p<<1)
#define rs (ls|1)
#define mid (l+r>>1)
void update(int p) {
s(p)=s(ls)|s(rs);
}
void spread(int p) {
if(!l(p)) return ;
s(ls)=1ll<<l(p),l(ls)=l(p);
s(rs)=1ll<<l(p),l(rs)=l(p);
l(p)=0;
}
public:
void build(int p,int l,int r) {
if(l==r) {
s(p)=1ll<<c[dfn[l]];
return ;
}
build(ls,l,mid),build(rs,mid+1,r);
update(p);
}
void Modify(int p,int l,int r,int L,int R,ll x) {
if(l>R||r<L) return ;
if(l>=L&&r<=R) {
s(p)=1ll<<x;
l(p)=x;
return ;
}
spread(p);
Modify(ls,l,mid,L,R,x),Modify(rs,mid+1,r,L,R,x);
update(p);
}
ll ask(int p,int l,int r,int L,int R) {
if(l>R||r<L) return 0;
if(l>=L&&r<=R) return s(p);
spread(p);
return ask(ls,l,mid,L,R)|ask(rs,mid+1,r,L,R);
}
}SMT;
int count(ll x) {
int ans=0;
while(x) {
ans+=x&1ll;
x>>=1ll;
}
return ans;
}
int main() {
n=read(),m=read();
for(int i=1;i<=n;i++) c[i]=read();
for(int i=1;i<n;i++) {
int u=read(),v=read();
add(u,v);
}
dfs(1,0);
SMT.build(1,1,js);
for(int i=1;i<=m;i++) {
int op,u;
ll c;
op=read(),u=read();
if(op==1) {
c=read();
SMT.Modify(1,1,js,lid[u],rid[u],c);
}
else printf("%d\n",count(SMT.ask(1,1,js,lid[u],rid[u])));
}
system("pause");
return 0;
}