(题面来自Luogu)
题目描述
奶牛们又一次试图创建一家创业公司,还是没有从过去的经验中吸取教训--牛是可怕的管理者!
为了方便,把奶牛从 1⋯N(1≤N≤100,000) 编号,把公司组织成一棵树,1 号奶牛作为总裁(这棵树的根节点)。除了总裁以外的每头奶牛都有一个单独的上司(它在树上的 “双亲结点”)。所有的第 i 头牛都有一个不同的能力指数 p(i),描述了她对其工作的擅长程度。如果奶牛 i 是奶牛 j 的祖先节点(例如,上司的上司的上司),那么我们我们把奶牛 j 叫做 i 的下属。
不幸地是,奶牛们发现经常发生一个上司比她的一些下属能力低的情况,在这种情况下,上司应当考虑晋升她的一些下属。你的任务是帮助奶牛弄清楚这是什么时候发生的。简而言之,对于公司的中的每一头奶牛 i,请计算其下属 j 的数量满足 p(j)>p(i)。
输入格式
输入的第一行包括一个整数 N。
接下来的 N 行包括奶牛们的能力指数 p(1)⋯p(N). 保证所有数互不相同,在区间 1⋯⋯1e9 之间。
接下来的 N−1 行描述了奶牛 2⋯⋯N 的上司(父节点)的编号。再次提醒,1 号奶牛作为总裁,没有上司。
输出格式
输出包括 N 行。输出的第 i 行应当给出有多少奶牛 i 的下属比奶牛 i 能力高。
同样是大规模统计子树信息的问题,这题并不能拿启发式合并来做,因为子树内的要求信息和根节点的信息要求不同,与根本身的信息有关。换句话说,无法把重儿子的信息直接合并给根。
线段树合并的一般对象是两棵动态开点的权值线段树。权值线段树动态开点可以只见出建出没有被插入过的位置,时空复杂度O(nloginf),可以节省很大的空间。查询时避开空节点即可。
代码:
- void modify(int &nd, int l, int r, int x) {
- if (!nd) nd = ++tot;
- if (l == r) {
- ++cnt[nd];
- return;
- }
- if (x <= mid) modify(lc[nd], l, mid, x);
- else modify(rc[nd], mid + 1, r, x);
- update(nd);
- }
动态开点是不需要离散化序列的,因为没有被访问到的值不会被建出来。不过在题目条件允许的情况下,离散化可以把线段树本身的空间和时间的复杂度都降到O(nlogn),而离散化本身的复杂度又很小,可以视题目尽量加上。
线段树合并的过程很简单,它利用了动态开点线段树部分节点为空的性质,在递归的过程中直接暴力合并两棵线段树对应节点的信息,遇到空节点直接返回。合并函数写法有两种,前者通过新建线段树保留旧线段树的信息,后者直接利用原有节点,不需要额外开辟空间。静态子树统计的问题中每个节点返回后不会被再次访问,一般采用省空间的写法。
写法1:
- int merge(int &u, int v) {
- if (!v) return u;
- if (!u) return v;
- int nd = ++tot;
- cnt[nd] = cnt[lc[nd]] + cnt[rc[nd]];//把v的信息并到u上
- lc[nd] = merge(lc[u], lc[v]);
- rc[nd] = merge(rc[u], rc[v]);
- return nd;
- }
写法2:
- void merge(int &u, int v) {
- if (!u || !v) {//存在一个空节点
- u += v;//直接把u指向非空的那个点
- return;
- }
- seg[u].cnt += seg[v].cnt;//把v的信息并到u上
- merge(lc[u], lc[v]);
- merge(rc[u], rc[v]));
- return;
- }
时间复杂度证明:假设最初存在的线段树共有O(nlogn)级别的节点数。每一次合并操作至少会减少一个原有节点,所以总复杂度的上界就是O(nlogn)的。
线段树合并常数比树上启发式合并大一些,能承受的范围大概在1e6左右。不过由于每个节点合并起来得到的线段树是分开统计的,它更普适于类似的子树统计问题。
其实该题中的查询子树权值信息和合并操作,用树状数组就可以维护。更神仙的做法是暴上主席树……但是作为线段树合并的模板还是要认真打的……
考虑对每个叶子节点维护一棵权值线段树,然后递归地合并一个节点u的每一个子节点中的线段树,在这棵树上查询比能力值score[u]大的节点数就等价于查询[score[u] + 1, inf]的区间和。然后把score[u]插入该线段树中即可返回。由于只需要比较大小关系,可以做离散化来优化时间复杂度。
代码:
- #include <iostream>
- #include <cstdio>
- #include <cstring>
- #include <algorithm>
- #define BUG puts("$$$")
- #define maxn 100010
- template <typename T>
- void read(T &x) {
- x = 0;
- char ch = getchar();
- // int f = 1;
- while (!isdigit(ch)) {
- // if (ch == '-') f = -1;
- ch = getchar();
- }
- while (isdigit(ch)) {
- x = x * 10 + (ch ^ 48);
- ch = getchar();
- }
- // x *= f;
- }
- using namespace std;
- int n, N;
- int head[maxn], top;
- struct E {
- int to, nxt;
- } edge[maxn << 1];
- inline void insert(int u, int v) {
- edge[++top] = (E) {v, head[u]};
- head[u] = top;
- }
- int score[maxn];
- namespace Segment_tree {
- #define mid ((l + r) >> 1)
- int tot = 0;
- struct node {
- int cnt, lc, rc;
- node(): cnt(0), lc(0), rc(0) {}
- } seg[maxn << 4];
- #define lc(nd) seg[nd].lc
- #define rc(nd) seg[nd].rc
- void update(int nd) {
- seg[nd].cnt = seg[lc(nd)].cnt + seg[rc(nd)].cnt;
- }
- void modify(int &nd, int l, int r, int x) {
- if (!nd) nd = ++tot;
- if (l == r) {
- ++seg[nd].cnt;
- return;
- }
- if (x <= mid) modify(lc(nd), l, mid, x);
- else modify(rc(nd), mid + 1, r, x);
- update(nd);
- }
- int query(int nd, int l, int r, int ql, int qr) {
- if (!nd) return 0;
- if (l >= ql && r <= qr)
- return seg[nd].cnt;
- if (l > qr || r < ql)
- return 0;
- return query(lc(nd), l, mid, ql, qr) + query(rc(nd), mid + 1, r, ql, qr);
- }
- void merge(int &u, int v) {
- if (!u || !v) {
- u += v;
- return;
- }
- seg[u].cnt += seg[v].cnt;
- merge(lc(u), lc(v));
- merge(rc(u), rc(v));
- return;
- }
- int ans[maxn], rt[maxn];
- void solve(int u) {//核心代码
- for (int i = head[u]; i; i = edge[i].nxt) {
- int v = edge[i].to;
- solve(v);
- merge(rt[u], rt[v]);
- }
- ans[u] = query(rt[u], 1, N, score[u] + 1, N);
- modify(rt[u], 1, N, score[u]);
- return;
- }
- } using namespace Segment_tree;
- int st[maxn];
- int contra(int *a) {//离散化
- memcpy(st, a, sizeof(st));
- sort(st + 1, st + 1 + n);
- int len = unique(st + 1, st + 1 + n) - st - 1;
- for (int i = 1; i <= n; ++i)
- a[i] = lower_bound(st + 1, st + 1 + len, a[i]) - st;
- return len;
- }
- int main() {
- read(n);
- int u, v;
- for (int i = 1; i <= n; ++i)
- read(score[i]);
- N = contra(score);
- for (u = 2; u <= n; ++u) {
- read(v);
- insert(v, u);
- }
- solve(1);
- for (int i = 1; i <= n; ++i)
- printf("%d\n", ans[i]);
- return 0;
- }
© 著作权归作者所有
发表评论