AK-dream

【题目描述】

给你一棵 $n$ 个点的树,编号 $1\sim n$。每个点可以是黑色,可以是白色。初始时所有点都是黑色。下面有两种操作:

  • 0 u:询问有多少个节点 $v$ 满足路径 $u$ 到 $v$ 上所有节点(包括 $u$)都拥有相同的颜色。

  • 1 u:翻转 $u$ 的颜色。

【输入/输出格式】

不关心

$n,m\le 10^5$

最近不知道为什么一直在敲数据结构。。。感觉要换换题型了

题解

随便找个点当根吧

如果有两个点$u,v$满足查询操作那个条件 我们就说$u,v$联通

注意到我们只需要维护一个点子树里有多少点和它联通

对于查询操作只需要找到深度最浅的和查询点联通的祖先就可以了

为了方便操作 我们让$cnt[x][0]$表示如果$x$是黑点 那么子树里有多少点和它联通 $cnt[x][1]$表示白点

那么对于修改操作 我们假设是把$x$从黑改成白

我们只需要找到那个深度最浅的和$x$联通的祖先$p$ 然后把$fa[p]\sim fa[x]$这条链上所有点的$cnt[i][0]$减掉$cnt[x][0]$
然后更改$x$的颜色
再找到此时深度最浅的和$x$联通的祖先$p_2$(注意$x$的颜色变了 所以和祖先的联通也已经变了) 把$fa[p_2]\sim fa[x]$这条链上所有点的$cnt[i][1]$加上$cnt[x][1]$

因为$x$变白点之后子树里的黑点就不和外面联通了 而子树里的白点就会和外面联通

(实际上你会发现$p$和$p_2$中有一个肯定就是$x$ 因为$x$的父亲要么是白点要么是黑点 但是无所谓)

白改黑同理

区间修改树剖就可以了 问题在于如何快速找到深度最浅的和$x$联通的祖先?

还是树剖 线段树再维护一下区间内有多少个黑点白点 那么从$fa[x]$开始一直往上跳重链 如果整段都和$x$颜色一样就继续跳 否则一定可以线段树二分找到第一个和$x$颜色不一样的点

但是这里写起来就会比较麻烦。。。这题估计还是LCT简单点

时间复杂度$O(n\log^2 n)$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#include <bits/stdc++.h>
#define N 100005
using namespace std;

inline int read() {
int x = 0, f = 1; char ch = getchar();
for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ '0');
return x * f;
}

int n, m, col[N];
int head[N], pre[N<<1], to[N<<1], sz;
int dfn[N], rnk[N], tme, d[N], siz[N], top[N], son[N], fa[N];

inline void addedge(int u, int v) {
pre[++sz] = head[u]; head[u] = sz; to[sz] = v;
pre[++sz] = head[v]; head[v] = sz; to[sz] = u;
}

void dfs(int x) {
siz[x] = 1;
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == fa[x]) continue;
d[y] = d[x] + 1; fa[y] = x;
dfs(y);
siz[x] += siz[y];
if (!son[x] || siz[son[x]] < siz[y]) son[x] = y;
}
}

void dfs2(int x, int _top) {
top[x] = _top; dfn[x] = ++tme; rnk[tme] = x;
if (son[x]) dfs2(son[x], _top);
for (int i = head[x]; i; i = pre[i]) {
int y = to[i];
if (y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}

struct segtree{
int l, r, cnt[2], tag[2], sum[2]; //0:black 1:white
} tr[N<<2];

#define lson ind<<1
#define rson ind<<1|1

inline void pushup(int ind) {
tr[ind].cnt[0] = tr[lson].cnt[0] + tr[rson].cnt[0];
tr[ind].cnt[1] = tr[lson].cnt[1] + tr[rson].cnt[1];
tr[ind].sum[0] = tr[lson].sum[0] + tr[rson].sum[0];
tr[ind].sum[1] = tr[lson].sum[1] + tr[rson].sum[1];
}

void build(int ind, int l, int r) {
tr[ind].l = l; tr[ind].r = r; tr[ind].tag[0] = tr[ind].tag[1] = 0;
if (l == r) {
tr[ind].cnt[0] = siz[rnk[l]]; tr[ind].cnt[1] = 1;
tr[ind].sum[0] = 1; tr[ind].sum[1] = 0;
return;
}
int mid = (l + r) >> 1;
build(lson, l, mid); build(rson, mid+1, r);
pushup(ind);
}

void pushdown(int ind) {
if (tr[ind].tag[0]) {
int v = tr[ind].tag[0]; tr[ind].tag[0] = 0;
tr[lson].cnt[0] += v; tr[lson].tag[0] += v;
tr[rson].cnt[0] += v; tr[rson].tag[0] += v;
}
if (tr[ind].tag[1]) {
int v = tr[ind].tag[1]; tr[ind].tag[1] = 0;
tr[lson].cnt[1] += v; tr[lson].tag[1] += v;
tr[rson].cnt[1] += v; tr[rson].tag[1] += v;
}
}

void update(int ind, int x, int y, int v, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (x <= l && r <= y) {
tr[ind].cnt[c] += (r - l + 1) * v; tr[ind].tag[c] += v;
return;
}
pushdown(ind);
int mid = (l + r) >> 1;
if (x <= mid) update(lson, x, y, v, c);
if (mid < y) update(rson, x, y, v, c);
pushup(ind);
}

int query(int ind, int pos, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (l == r) return tr[ind].cnt[c];
pushdown(ind);
int mid = (l + r) >> 1;
if (pos <= mid) return query(lson, pos, c);
else return query(rson, pos, c);
}

void change(int ind, int pos, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (l == r) {
tr[ind].sum[c^1] = 0;
tr[ind].sum[c] = 1;
return;
}
int mid = (l + r) >> 1;
if (pos <= mid) change(lson, pos, c);
else change(rson, pos, c);
pushup(ind);
}

int find(int ind, int x, int y, int c) {
int l = tr[ind].l, r = tr[ind].r;
if (l == r) {
if (tr[ind].sum[c]) return l;
else return 0;
}
if (x <= l && r <= y) {
if (!tr[ind].sum[c]) return 0;
}
int mid = (l + r) >> 1;
if (mid >= y) return find(lson, x, y, c);
if (x > mid) return find(rson, x, y, c);
int ret = find(rson, x, y, c);
if (!ret) return find(lson, x, y, c);
else return ret;
}

void Update(int x) {
int c = col[x], tmp[2] = {query(1, dfn[x], 0), query(1, dfn[x], 1)};
col[x] ^= 1;
int xx = fa[x];
while (xx) { //边找边修改
int lst = find(1, dfn[top[xx]], dfn[xx], c^1);
if (lst) {
update(1, lst, dfn[xx], -tmp[c], c);
break;
} else {
update(1, dfn[top[xx]], dfn[xx], -tmp[c], c);
xx = fa[top[xx]];
}
}
xx = fa[x];
while (xx) {
int lst = find(1, dfn[top[xx]], dfn[xx], c);
if (lst) {
update(1, lst, dfn[xx], tmp[c^1], c^1);
break;
} else {
update(1, dfn[top[xx]], dfn[xx], tmp[c^1], c^1);
xx = fa[top[xx]];
}
}
change(1, dfn[x], col[x]);
}

int Query(int x) {
int c = col[x], xx = fa[x], lstson = x;
while (xx) {
int lst = find(1, dfn[top[xx]], dfn[xx], c^1);
if (lst) {
if (lst == dfn[xx]) {
return query(1, dfn[lstson], c);
} else {
return query(1, lst + 1, c);
}
}
lstson = top[xx];
xx = fa[top[xx]];
}
return query(1, dfn[1], c);
}

int main() {
n = read();
for (int i = 1, u, v; i < n; i++) {
u = read(), v = read();
addedge(u, v);
}
dfs(1); dfs2(1, 1);
build(1, 1, n);
m = read();
for (int i = 1, tp, x; i <= m; i++) {
tp = read(), x = read();
if (!tp) {
printf("%d\n", Query(x));
} else {
Update(x);
}
}
return 0;
}

 评论