-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathWavelet Tree
More file actions
113 lines (109 loc) · 3.08 KB
/
Wavelet Tree
File metadata and controls
113 lines (109 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#include<bits/stdc++.h>
using namespace std;
// Capacity of the input buffer a[]: 3e5 elements plus a little slack.
const int MAXN = 300000 + 9;
// Largest absolute element value the tree is built for. main() uses
// [-MAXV, MAXV] as the value range, so negative inputs are covered too;
// adjust the bounds passed to init() if the data needs a different range.
const int MAXV = 1000000000 + 9;
/// Pointer-based wavelet tree over the value range [lo, hi].
///
/// Each node stores, for the subsequence of elements routed to it:
///   * b[i] - how many of the first i elements go to the left child
///            (value <= midpoint of [lo, hi]);
///   * c[i] - the sum of the first i elements (64-bit, so it cannot overflow
///            for int inputs the way an int prefix sum would).
///
/// All query positions are 1-based and inclusive. Every query returns 0 for an
/// empty or invalid position range (l > r, l < 1, r past the end, or a tree
/// that has not been built).
///
/// A node owns its arrays and its children; copying is disabled because a
/// shallow copy would free the same memory twice. Moving is supported.
struct wavelet_tree {
    int lo, hi;            // value range covered by this node
    wavelet_tree *l, *r;   // owned children: [lo, mid] and [mid + 1, hi]
    int *b;                // owned; b[i] = count of first i elements going left
    long long *c;          // owned; c[i] = sum of the first i elements
    int bsz, csz;          // number of valid entries in b and c

    /// Creates an empty tree; lo > hi marks the "no range" state.
    wavelet_tree()
        : lo(1), hi(0), l(nullptr), r(nullptr), b(nullptr), c(nullptr), bsz(0), csz(0) {}

    wavelet_tree(const wavelet_tree &) = delete;
    wavelet_tree &operator=(const wavelet_tree &) = delete;

    wavelet_tree(wavelet_tree &&other) noexcept : wavelet_tree() { swap_state(other); }

    wavelet_tree &operator=(wavelet_tree &&other) noexcept {
        if (this != &other) {
            release();
            swap_state(other);
        }
        return *this;
    }

    ~wavelet_tree() { release(); }

    /// Builds the tree for the elements in [from, to) with values in [x, y].
    ///
    /// The input range is reordered in place (stable partition at every
    /// level), so the caller's array is not in its original order afterwards.
    /// Any previously built tree is discarded first, so init() may be called
    /// again on the same object. An empty input or x > y leaves the tree empty.
    void init(int *from, int *to, int x, int y) {
        release();
        lo = x;
        hi = y;
        if (from >= to || lo > hi) return;

        const int mid = split_point();
        const auto goes_left = [mid](int value) { return value <= mid; };

        const std::size_t len = static_cast<std::size_t>(to - from);
        b = new int[len + 1];
        c = new long long[len + 1];
        b[0] = 0;
        c[0] = 0;
        bsz = 1;
        csz = 1;
        for (const int *it = from; it != to; ++it) {
            b[bsz] = b[bsz - 1] + (goes_left(*it) ? 1 : 0);
            c[csz] = c[csz - 1] + *it;
            ++bsz;
            ++csz;
        }

        if (lo == hi) return;  // leaf: a single value, nothing left to split

        int *pivot = std::stable_partition(from, to, goes_left);
        l = new wavelet_tree();
        l->init(from, pivot, lo, mid);
        r = new wavelet_tree();
        r->init(pivot, to, mid + 1, hi);
    }

    /// k-th smallest element (1-based k) among positions [l, r].
    /// For the array [1,2,1,3,5] the 2nd smallest is 1 and the 3rd is 2.
    /// Returns 0 when the range is invalid or k is not in [1, r - l + 1].
    int kth(int l, int r, int k) const {
        if (!in_range(l, r) || k < 1 || k > r - l + 1) return 0;
        if (lo == hi) return lo;
        const int lb = b[l - 1];
        const int rb = b[r];
        const int inLeft = rb - lb;
        if (k <= inLeft) return this->l->kth(lb + 1, rb, k);
        return this->r->kth(l - lb, r - rb, k - inLeft);
    }

    /// Count of elements in positions [l, r] that are less than or equal to k.
    int LTE(int l, int r, int k) const {
        if (!in_range(l, r) || k < lo) return 0;
        if (hi <= k) return r - l + 1;
        const int lb = b[l - 1];
        const int rb = b[r];
        return this->l->LTE(lb + 1, rb, k) + this->r->LTE(l - lb, r - rb, k);
    }

    /// Count of elements in positions [l, r] that are equal to k.
    int count(int l, int r, int k) const {
        if (!in_range(l, r) || k < lo || k > hi) return 0;
        if (lo == hi) return r - l + 1;
        const int lb = b[l - 1];
        const int rb = b[r];
        if (k <= split_point()) return this->l->count(lb + 1, rb, k);
        return this->r->count(l - lb, r - rb, k);
    }

    /// Sum of the elements in positions [l, r] that are less than or equal to k.
    /// Returned as long long because the sum of int elements can exceed int.
    long long sum(int l, int r, int k) const {
        if (!in_range(l, r) || k < lo) return 0;
        if (hi <= k) return c[r] - c[l - 1];
        const int lb = b[l - 1];
        const int rb = b[r];
        return this->l->sum(lb + 1, rb, k) + this->r->sum(l - lb, r - rb, k);
    }

private:
    /// Midpoint of [lo, hi], computed in 64 bits so it cannot overflow and
    /// does not rely on right-shifting a negative value.
    int split_point() const {
        return static_cast<int>(lo + (static_cast<long long>(hi) - lo) / 2);
    }

    /// True when [first, last] is a non-empty 1-based range inside this node.
    /// An unbuilt or empty node has bsz == 0, so every range is rejected there.
    bool in_range(int first, int last) const {
        return first >= 1 && first <= last && last < bsz;
    }

    /// Frees the children and both arrays and returns to the empty state.
    void release() noexcept {
        delete l;
        delete r;
        delete[] b;
        delete[] c;
        l = nullptr;
        r = nullptr;
        b = nullptr;
        c = nullptr;
        bsz = 0;
        csz = 0;
    }

    /// Exchanges the complete state of two nodes.
    void swap_state(wavelet_tree &other) noexcept {
        std::swap(lo, other.lo);
        std::swap(hi, other.hi);
        std::swap(l, other.l);
        std::swap(r, other.r);
        std::swap(b, other.b);
        std::swap(c, other.c);
        std::swap(bsz, other.bsz);
        std::swap(csz, other.csz);
    }
};
// Input buffer; main() stores the elements at indices 1..n.
int a[MAXN];
// The single tree instance that main() builds and queries.
wavelet_tree t;
int main() {
int i, j, k, n, m, q, l, r;
cin >> n;
for(i = 1; i <= n; i++) cin >> a[i];
t.init(a + 1, a + n + 1, -MAXV, MAXV);
//beware! after the init() operation array a[] will not be same
cin >> q;
while(q--) {
int x;
cin >> x;
cin >> l >> r >> k;
if(x == 0) {
//kth smallest
cout << t.kth(l, r, k) << endl;
} else if(x == 1) {
//less than or equal to K
cout << t.LTE(l, r, k) << endl;
} else if(x == 2) {
//count occurence of K in [l, r]
cout << t.count(l, r, k) << endl;
}
if(x == 3) {
//sum of elements less than or equal to K in [l, r]
cout << t.sum(l, r, k) << endl;
}
}
return 0;
}