-
Notifications
You must be signed in to change notification settings - Fork 51
/
Copy pathWavelet.cpp
134 lines (128 loc) · 3.43 KB
/
Wavelet.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
/* Given array of size n.Now you will be given q queries of 1 type l r x,
you need to tell no of elements less than or equal to x in array in range l,r inclusive */
#include <bits/stdc++.h>
using namespace std;
#define gc getchar_unlocked
#define fo(i,n) for(i=0;i<n;i++)
#define Fo(i,k,n) for(i=k;k<n?i<n:i>n;k<n?i+=1:i-=1)
#define ll long long
#define si(x) scanf("%d",&x)
#define sl(x) scanf("%lld",&x)
#define ss(s) scanf("%s",s)
#define pi(x) printf("%d\n",x)
#define pl(x) printf("%lld\n",x)
#define ps(s) printf("%s\n",s)
#define deb(x) cout << #x << "=" << x << endl
#define deb2(x, y) cout << #x << "=" << x << "," << #y << "=" << y << endl
#define pb push_back
#define mp make_pair
#define F first
#define S second
#define all(x) x.begin(), x.end()
#define clr(x) memset(x, 0, sizeof(x))
#define sortall(x) sort(all(x))
#define tr(it, a) for(auto it = a.begin(); it != a.end(); it++)
#define PI 3.1415926535897932384626
typedef pair<int, int> pii;
typedef pair<ll, ll> pl;
typedef vector<int> vi;
typedef vector<ll> vl;
typedef vector<pii> vpii;
typedef vector<pl> vpl;
typedef vector<vi> vvi;
typedef vector<vl> vvl;
int mpow(int base, int exp);
void ipgraph(int m);
void dfs(int u, int par);
const int mod = 1000000007;
const int N = 3e5, M = N;
//=======================
const int MAX = 1e6;
vi g[N];
int a[N];
struct wavelet_tree{
#define vi vector<int>
#define pb push_back
int lo, hi;
wavelet_tree *l, *r;
vi b;
//nos are in range [x,y]
//array indices are [from, to)
wavelet_tree(int *from, int *to, int x, int y){
lo = x, hi = y;
if(lo == hi or from >= to) return;
int mid = (lo+hi)/2;
auto f = [mid](int x){
return x <= mid;
};
b.reserve(to-from+1);
b.pb(0);
for(auto it = from; it != to; it++)
b.pb(b.back() + f(*it));
//see how lambda function is used here
auto pivot = stable_partition(from, to, f);
l = new wavelet_tree(from, pivot, lo, mid);
r = new wavelet_tree(pivot, to, mid+1, hi);
}
//kth smallest element in [l, r]
int kth(int l, int r, int k){
if(l > r) return 0;
if(lo == hi) return lo;
int inLeft = b[r] - b[l-1];
int lb = b[l-1]; //amt of nos in first (l-1) nos that go in left
int rb = b[r]; //amt of nos in first (r) nos that go in left
if(k <= inLeft) return this->l->kth(lb+1, rb , k);
return this->r->kth(l-lb, r-rb, k-inLeft);
}
//count of nos in [l, r] Less than or equal to k
int LTE(int l, int r, int k) {
if(l > r or k < lo) return 0;
if(hi <= k) return r - l + 1;
int lb = b[l-1], rb = b[r];
return this->l->LTE(lb+1, rb, k) + this->r->LTE(l-lb, r-rb, k);
}
//count of nos in [l, r] equal to k
int count(int l, int r, int k) {
if(l > r or k < lo or k > hi) return 0;
if(lo == hi) return r - l + 1;
int lb = b[l-1], rb = b[r], mid = (lo+hi)/2;
if(k <= mid) return this->l->count(lb+1, rb, k);
return this->r->count(l-lb, r-rb, k);
}
~wavelet_tree(){
delete l;
delete r;
}
};
int main()
{
ios_base::sync_with_stdio(false);
cin.tie(NULL);
srand(time(NULL));
int i,n,k,j,q,l,r;
cin >> n;
fo(i, n) cin >> a[i+1];
wavelet_tree T(a+1, a+n+1, 1, MAX);
cin >> q;
while(q--){
int x;
cin >> x;
cin >> l >> r >> k;
if(x == 0){
//kth smallest
cout << "Kth smallest: ";
cout << T.kth(l, r, k) << endl;
}
if(x == 1){
//less than or equal to K
cout << "LTE: ";
cout << T.LTE(l, r, k) << endl;
}
if(x == 2){
//count occurence of K in [l, r]
cout << "Occurence of K: ";
cout << T.count(l, r, k) << endl;
}
}
return 0;
}