「雅礼集训 2018-03-27」Subset

给你三个 11nn 的排列 aia_ibib_icic_i

称三元组 (x,y,z)(x,y,z) 是合法的,当且仅当存在一个下标集合 SS 满足 (x,y,z)=(max  ai,max  bi,max  ci)(x,y,z)=(max \ \ a_i,max \ \ b_i,max \ \ c_i) (iS)(i\subseteq S)

询问合法三元组的数量。

Constraints

$ n \leq 10^5 $

Solution

对于每一个合法三元组对应的 SS ,只保留对三元组有贡献的下标,可以得到 S3|S|\leq 3 ,且与合法三元组一一对应。问题转化为统计 SS 的数量。

S=1|S|=1 时,所有下标集合都是合法的。

S=2|S|=2 时,可以用总的下标集合数 - 非法下标集合数(即其中一个在 a,b,ca,b,c 里都比另一个大)来统计,这一步可以用 cdq 分治来完成。

S=3|S|=3 时,同样考虑非法下标集合数。分为几种情况进行讨论:

1.1. 存在一个下标在 a,b,ca,b,c 中都是最大的,同样可以用 cdq 分治来完成,记为 AA

2.2. 一个下标在 a,b,ca,b,c 中的两个最大,另一个下标在 a,b,ca,b,c 中的一个最大。考虑枚举 a,b,ca,b,c 中的两个,计算有多少个下标集合满足其中一个在对应枚举的排列里是最大的,计总和为 BB 。可以注意到, BB 中还包含着不合法的 3A3\cdot A 种情况,所以实际上 B=B3AB=B-3\cdot A

最后 (n2)(n1)nAB(n-2)\cdot (n-1)\cdot n-A-B 即可得到 S=3|S|=3 时的下标集合数。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#include<cstdio>
#include<algorithm>
#include<cstring>
#define LL long long
using namespace std;
const int N=1e5+5;
int n,temp,t[N],mn[N];
LL A,B,C,ans;
struct node{int a,b,c;}a[N],tmp[N];
int read()
{
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
return x*f;
}
bool cmp(node a,node b){return a.a<b.a;}
bool cmp2(node a,node b){return a.b<b.b;}
int lowbit(int x){return x&(-x);}
void add(int x,int val){while(x<=n)t[x]+=val,x+=lowbit(x);}
int query(int x){int ans=0;while(x)ans+=t[x],x-=lowbit(x);return ans;}
void cdq(int l,int r)
{
if(l==r)return;
int mid=(l+r)>>1;
for(int i=l;i<=r;i++)
if(a[i].c<=mid)add(a[i].b,1);
else mn[a[i].a]+=query(a[i].b);
int t1=l,t2=mid+1;
for(int i=l;i<=r;i++)
if(a[i].c<=mid)add(a[i].b,-1),tmp[t1++]=a[i];
else tmp[t2++]=a[i];
for(int i=l;i<=r;i++)a[i]=tmp[i];
cdq(l,mid);cdq(mid+1,r);
}
int main()
{
n=read();
for(int i=1;i<=n;i++)a[i].a=read();
for(int i=1;i<=n;i++)a[i].b=read();
for(int i=1;i<=n;i++)a[i].c=read();
sort(a+1,a+n+1,cmp);cdq(1,n);
for(int i=1;i<=n;i++)
A+=1ll*mn[i]*(mn[i]-1)/2,C+=mn[i];
sort(a+1,a+n+1,cmp);
memset(t,0,sizeof(t));
for(int i=1;i<=n;i++)
{
temp=query(a[i].b);
B+=1ll*temp*(temp-1)/2;
add(a[i].b,1);
}
memset(t,0,sizeof(t));
for(int i=1;i<=n;i++)
{
temp=query(a[i].c);
B+=1ll*temp*(temp-1)/2;
add(a[i].c,1);
}
sort(a+1,a+n+1,cmp2);
memset(t,0,sizeof(t));
for(int i=1;i<=n;i++)
{
temp=query(a[i].c);
B+=1ll*temp*(temp-1)/2;
add(a[i].c,1);
}
ans=1ll*(n-2)*(n-1)*n/6-(B-A*2);
ans+=1ll*n*(n-1)/2-C;
printf("%lld",ans+n);
return 0;
}