【题目描述】 现在有一棵二叉树,所有非叶子节点都有两个孩子。在每个叶子节点上有一个权值(有$n$个叶子节点,满足这些权值为$1\dots n$的一个排列)。可以任意交换每个非叶子节点的左右孩子。
要求进行一系列交换,使得最终所有叶子节点的权值按照遍历序写出来,逆序对个数最少。
【输入格式】 第一行n下面每行,一个数x
如果$x=0$,表示这个节点非叶子节点,递归地向下读入其左孩子和右孩子的信息,
如果$x\ne 0$,表示这个节点是叶子节点,权值为$x$
【输出格式】 一行,最少逆序对个数。
题解 此题输入稍微有点毒瘤啊。。。但是反正是按照dfs序给的 就边dfs边读入好了
如果我们把所有的叶子节点从左到右编号为$1\sim n$,那么某个节点$x$的子树中必定含有编号连续的一段叶子节点$[l,r]$,也就是说叶子节点$l$到叶子节点$r$都在$x$的子树里
方便起见我们再把叶子节点$i$的权值定义为$v_i$
现在考虑一个非叶子节点$x$它的左儿子是$a$,右儿子是$b$我们记$a$子树里含有$[l_a, r_a]$的叶子节点,$b$子树里有$[l_b, r_b]$的叶子节点
我们让$f(x)$等于 满足$l_a\le i\le r_a,\ l_b\le j\le r_b,\ v_i>v_j$的所有逆序对$(i,j)$的数量 那么如果没有交换操作的话 答案就是所有的$f(i)$之和
这个我也不知道怎么解释。。。似乎挺显然的 因为左右儿子内部的逆序对之前已经统计完了嘛
那么如果我们交换了$a,b$的位置$f(x)$会有什么变化呢
$f(x)$就会变成满足$l_a\le i\le r_a,\ l_b\le j\le r_b,\ v_i<v_j$的$(i,j)$的数量
因为左右交换了 所以原来的顺序对全部变成了逆序对 逆序对全部变成了顺序对
怎么求逆序对?
因为此题给出的二叉树不一定是完全二叉树 所以不能用归并排序 树状数组则不方便我们快速统计上面要求的那个东西,也不好合并 所以我们用权值线段树
我们要求的是这个:满足$l_a\le i\le r_a,\ l_b\le j\le r_b,\ v_i>v_j$的所有$(i,j)$的数量 这个可以在合并左儿子和右儿子的线段树时顺便统计出来
我们依然把左儿子叫做$a$,右儿子叫做$b$
具体来说 对于权值线段树的一个节点代表的区间$[l,r]$令$cnt1$等于$a$的线段树在$[mid+1,r]$区间的元素个数(就是说左儿子的子树里的叶子节点$i$有多少个$mid+1\le v_i\le r$)$cnt2$等于$b$的线段树在$[l,mid]$区间的元素个数
那么$cnt1$的节点和$cnt2$的节点两两匹配都一定满足$l_a\le i\le r_a,\ l_b\le j\le r_b,\ v_i>v_j$所以$f(x)$加上$cnt1*cnt2$
对于权值线段树上每个节点都算一次
注意 这样统计一定是不重不漏的 很好理解 但是我并不知道怎么解释。。。
反过来也同理
因为你要么交换 要么不交换 所以答案就加上两种$f(x)$中较小的那一个
这篇题解真难写
代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 #include <bits/stdc++.h> using namespace std ;typedef long long ll;int n, m, son[1000005 ][2 ], out[1000005 ];struct segtree { int lc, rc, cnt; } tr[4000005 ]; int tot;int rt[1000005 ];ll ans; void update (int &ind, int l, int r, int p) { if (!ind) ind = ++tot; tr[ind].cnt++; if (l == r) return ; int mid = (l + r) >> 1 ; if (p <= mid) update(tr[ind].lc, l, mid, p); else update(tr[ind].rc, mid+1 , r, p); } int query (int ind, int l, int r, int x, int y) { if (x > y || !ind) return 0 ; if (x <= l && r <= y) return tr[ind].cnt; int mid = (l + r) >> 1 , ret = 0 ; if (x <= mid) ret += query(tr[ind].lc, l, mid, x, y); if (mid < y) ret += query(tr[ind].rc, mid+1 , r, x, y); return ret; } int merge (int x, int y, ll &mn, ll &mx) { if (!x) return y; if (!y) return x; mn += 1l l * tr[tr[x].lc].cnt * tr[tr[y].rc].cnt; mx += 1l l * tr[tr[y].lc].cnt * tr[tr[x].rc].cnt; tr[x].cnt += tr[y].cnt; tr[x].lc = merge(tr[x].lc, tr[y].lc, mn, mx); tr[x].rc = merge(tr[x].rc, tr[y].rc, mn, mx); return x; } int dfs () { int ind = ++m; int val; scanf ("%d" , &val); if (val) { update(rt[ind], 1 , n, val); } else { int lc = dfs(), rc = dfs(); ll mn = 0 , mx = 0 ; rt[ind] = merge(rt[lc], rt[rc], mn, mx); ans += min(mn, mx); } return ind; } int main () { scanf ("%d" , &n); dfs(); printf ("%lld\n" , ans); return 0 ; }