主席树模板

本文最后更新于:3 years ago

模板来源:HKer_YM的博客

1求区间第k大

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN = 2e5 + 5;
const ll MOD = 100000007;
ll a[MAXN], b[MAXN], n, m, q, sz;
int lc[MAXN << 5], rc[MAXN << 5], rt[MAXN << 5], sum[MAXN << 5];
void init() {
sz = 0;
rt[0] = 0;
}
void build(int &rt, int l, int r) {
rt = ++sz;
sum[rt] = 0;
if (l == r) return ;
int mid = l + r >> 1;
build(lc[rt], l, mid);
build(rc[rt], mid + 1, r);
}
int update(int o, int l, int r, int p) {
int oo = ++sz;
lc[oo] = lc[o], rc[oo] = rc[o], sum[oo] = sum[o] + 1;
if (l == r) return oo;
int mid = l + r >> 1;
if (mid >= p) {
lc[oo] = update(lc[oo], l, mid, p);
} else {
rc[oo] = update(rc[oo], mid + 1, r, p);
}
return oo;
}

int query(int u, int v, int l, int r, int k) {
int mid = l + r >> 1;
if (l == r) return l;
// 改变左右即可求区间第k大
// 当前为区间第k小
// int x = sum[lc[v]] - sum[lc[u]];
// if(x >= k) return query(lc[u], lc[v], l, mid, k);
// else return query(rc[u], rc[v], mid + 1, r, k - x);
// 当前为区间第k大
int x = sum[rc[v]] - sum[rc[u]];
if(x >= k) query(rc[u], rc[v], mid + 1, r, k);
else return query(lc[u], lc[v], l, mid, k - x);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr), cout.tie(nullptr);
cin >> n >> m;
init();
for(int i = 1; i <= n; i++) {
cin >> a[i];
b[i] = a[i];
}
// 主席树有序节点
sort(b + 1, b + 1 + n);
// 不同元素个数
q = unique(b + 1, b + 1 + n) - b - 1;
// 构建0为根的一颗线段树
build(rt[0], 1, q);
for(int i = 1; i <= n; i++) {
// 将元素一个一个插入进去
int tmp = lower_bound(b + 1, b + 1 + q, a[i]) - b;
rt[i] = update(rt[i - 1], 1, q, tmp);
}
while(m--) {
int l, r, k;
cin >> l >> r >> k;
// 查询区间第k小
cout << b[query(rt[l - 1], rt[r], 1, q, k)] << endl;
}
return 0;
}

2.求区间前k大的和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
void build(int &rt, int l, int r) {
rt = ++sz;
sum[rt] = 0;
ans[rt] = 0;
if (l == r) return ;
int mid = l + r >> 1;
build(lc[rt], l, mid);
build(rc[rt], mid + 1, r);
}

int update(int o, int l, int r, int p) {
int oo = ++sz;
lc[oo] = lc[o], rc[oo] = rc[o], sum[oo] = sum[o] + 1, ans[oo] = ans[o] + b[p];
if (l == r) return oo;
int mid = l + r >> 1;
if (mid >= p) {
lc[oo] = update(lc[oo], l, mid, p);
} else {
rc[oo] = update(rc[oo], mid + 1, r, p);
}
return oo;
}
ll querySum(int u, int v, int l, int r, int k) {
int mid = l + r >> 1;
if (l == r) return b[l] * k;
// 当前为区间第k大
int x = sum[rc[v]] - sum[rc[u]];
if(x >= k) return querySum(rc[u], rc[v], mid + 1, r, k);
else return querySum(lc[u], lc[v], l, mid, k - x) + ans[rc[v]] - ans[rc[u]];
}


本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!