「51nod 1792」Jabby's segment tree

线段树是一种经典的数据结构,一颗 [1,n][1,n] 的线段树他的根是 [1,n][1,n] ,当一个线段树的结点是 [l,r][l,r] 时,设 mid=(l+r)/2mid=(l+r)/2 ,则这个结点的左儿子右儿子分别是 [l,mid][l,mid][mid+1,r][mid+1,r]

当我们在线段树上跑 [x,y][x,y] 询问时,一般是从根节点开始计算的,设现在所在结点是 [l,r][l,r] ,有以下几种分支:

  • [x,y][x,y] 包含 [l,r][l,r] ,计算结束。
  • 否则,若左儿子和 [x,y][x,y] 有交,计算左儿子,若右儿子和 [x,y][x,y] 有交,计算右儿子。

定义询问 [x,y][x,y] 的费用是询问时计算了几个结点。

给定 QQ 次询问,每次给定 ll , rr ,求满足 lxyrl\leq x\leq y\leq r(x,y)(x,y) 的费用之和。

Constraints

1n,Q1000001\leq n,Q \leq 100000

Solution

直接在线段树上维护答案。

假设当前节点编号为 xx ,代表的区间为 [l,r][l,r] ,则令: t(x)t(x) 表示两个端点都在区间 [l,r][l,r] 内的答案; tl(x)tl(x) 表示查询左端点为 ll ,右端点位于 [l,r)[l,r) 的答案; tr(x)tr(x) 表示查询右端点为 rr ,左端点位于 (l,r](l,r] 内的答案。

具体的统计方式见代码。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<cstdio>
#include<algorithm>
#include<cstring>
#define lc (x<<1)
#define rc (x<<1|1)
#define LL long long
using namespace std;
const int N=1e5+5;
const int mod=1e9+7;
int n,Q,L,R,ans;
int t[N*4],tl[N*4],tr[N*4];
void Mod(int &a,int b){a+=b;while(a>=mod)a-=mod;}
int read()
{
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
return x*f;
}
void update(int x,int l,int r)
{
int mid=(l+r)>>1;
Mod(tl[x],tl[lc]+1);
Mod(tl[x],tl[rc]+r-mid-1);
Mod(tl[x],r-l);
Mod(tr[x],tr[rc]+1);
Mod(tr[x],tr[lc]+mid-l);
Mod(tr[x],r-l);
Mod(t[x],t[lc]);
Mod(t[x],t[rc]);
Mod(t[x],1ll*tr[lc]*(r-mid)%mod);
Mod(t[x],1ll*tl[rc]*(mid-l+1)%mod);
Mod(t[x],r-mid-1);
Mod(t[x],mid-l);
int len=r-l+1;
Mod(t[x],1ll*len*(len+1)/2%mod);
}
void build(int x,int l,int r)
{
if(l==r){t[x]=1;return;}
int mid=(l+r)>>1;
build(lc,l,mid);
build(rc,mid+1,r);
update(x,l,r);
}
void calc(int x,int l,int r)
{
if(L<=l&&r<=R)
{
Mod(ans,t[x]);
if(L<l)Mod(ans,1ll*(tl[x]+1)*(l-L)%mod);
if(R>r)Mod(ans,1ll*(tr[x]+1)*(R-r)%mod);
if(L<l&&R>r)Mod(ans,1ll*(l-L)*(R-r)%mod);
return;
}
int ll=max(L,l),rr=min(R,r),len=rr-ll+1;
Mod(ans,1ll*len*(len+1)/2%mod);
if(L<l)Mod(ans,1ll*(l-L)*len%mod);
if(R>r)Mod(ans,1ll*(R-r)*len%mod);
int mid=(l+r)>>1;
if(L<=mid)calc(lc,l,mid);
if(R>mid)calc(rc,mid+1,r);
}
int main()
{
n=read();Q=read();
build(1,1,n);
while(Q--)
{
L=read();R=read();
ans=0;calc(1,1,n);
printf("%d\n",ans);
}
return 0;
}