这大概是我目前学过的最难理解的知识点了吧(

概述

整体二分,意味着同时二分一切

这个算法适用于静态和动态区间第$k$大,以及一些区间询问问题.那么根据通常的思路,让我们先来介绍一下暴力,再来分析二者的区别.

静态区间第k大

让我们暴力地二分答案来做,应该怎么做呢?既然我们要求区间第$k$大,那么区间中就应该有$k-1$个数比答案大才对.所以我们二分答案$mid$,看看序列里有多少个比它小的数,然后缩小值域.

想象一下,如果对于每一个询问操作我们都这样做,会$T$成什么样子.对于每个操作都要做相同的二分,会出现大量的时间浪费.那么我们最好能一次性处理所有的操作.这时候就要用到整体二分了!

机制

对于答案的二分还是依旧,只不过这次不是针对一个操作了.我们把各个操作排成一列,和二分同步进行.这里我们设计一个函数$solve(ql,qr,l,r)$,表示我们把某些操作放到”符合答案在区间$[l,r]$中”的区域,根据值域进行二分答案,并把它们分别标号$[ql,qr]$来方便递归分组.现在我们来考虑怎么分组以及分组过程中的事情.

首先,如果$l=r$了的话,代表区间所属询问的答案在区间$[l,l]$中,那就直接记录答案即可.然后如果还没到最后,我们就要像暴力那样二分出mid,然后扫一遍对应的数看一看有多少数大于mid,然后接下来就是看操作如何分组了.

每个操作都询问一个区间,针对一个区间的时候我们可以直接做,那么针对许多区间呢?我们可以用树状数组的方式,通过前缀和相减来迅速确认有多少个是大于mid的.如果大于k,代表答案应该还在左边,所以把这个操作放到左边,否则放到右边,具体可以通过开两个数组暂存来实现.最后像归并排序一样把这些暂存器里的操作按已经分好的左右放回操作序列,并根据分的左右来递归左边多少,右边多少.

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
int sum[5000050],n,m,cnt,answer[5000050];
const int oo=1e9;
struct node
{
    int a,b,k,id,type;
}q[5000050],q1[5000050],q2[5000050];
/****************BIT*************/
int lowbit(int x)
{
    return x&(-x);
}
void add(int pos,int w)
{
    for(int i=pos;i<=n;i+=lowbit(i))sum[i]+=w;
}
int ask(int pos)
{
    int ans=0;
    for(int i=pos;i;i-=lowbit(i))ans+=sum[i];
    return ans;
}
/****************BIT**************/
void solve(int ql,int qr,int l,int r)//假设答案在[l,r]中,符合答案属于[l,r]的操作序列是[ql,qr] 
{
    if(ql>qr)return;
    if(l==r)//如果答案已经确定 
    {
        for(int i=ql;i<=qr;i++)//区间中所有的操作答案都已经确定 
            if(q[i].type==2)
                answer[q[i].id]=l;
        return ;
    }
    int cnt1=0,cnt2=0;
    int mid=(l+r)>>1;//二分答案 
    for(int i=ql;i<=qr;i++)//扫一遍这个操作和数的混合序列,进行分组和重新混合 
        if(q[i].type==1)//如果是数 
        {
            if(q[i].a<=mid)//考虑它应该放在哪个操作区间继续下传
                q1[++cnt1]=q[i],//q1是将要下放到左区间的暂时储存器 
                add(q[i].id,1);//加到树状数组,之后求前缀和 
            else 
                q2[++cnt2]=q[i];//放到右边先不管 
        }
        else//由于读入的顺序,数一定先于操作来处理. 
        {
            int tmp=ask(q[i].b)-ask(q[i].a-1);//用树状数组统计当前有多少小于mid的 
            if(q[i].k<=tmp)//如果小于tmp个,就代表答案在[l,mid] 
                q1[++cnt1]=q[i];
            else 
                q2[++cnt2]=q[i],//否则在[mid+1,r] 
                q2[cnt2].k-=tmp;//此时应减去cmp 
        }
    for(int i=1;i<=cnt1;i++){//还原树状数组 
        if(q1[i].type==2)
            break;
        add(q1[i].id,-1);
    }
    for(int i=1;i<=cnt1;i++)//赋值回原数组,类似于归并排序 
        q[i+ql-1]=q1[i];
    for(int i=1;i<=cnt2;i++)
        q[i+ql+cnt1-1]=q2[i];
    solve(ql,ql+cnt1-1,l,mid);//递归下去直到确定答案 
    solve(ql+cnt1,qr,mid+1,r);
    return; 
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",&q[++cnt].a),//读入序列 
        q[cnt].id=i,//储存下标 
        q[cnt].type=1;//代表这是一个数 
    for(int i=1;i<=m;i++)
        scanf("%d",&q[++cnt].a),//离线读入所有操作 
        scanf("%d",&q[cnt].b),
        scanf("%d",&q[cnt].k),
        q[cnt].id=i,//储存下标 
        q[cnt].type=2;//代表这是一个操作 
    solve(1,cnt,-oo,oo);
    for(int i=1;i<=m;i++)
        cout << answer[i],
        printf("\n");
    return 0;
}



动态区间K大

相比于静态,动态只需要加一点操作,就是当扫到修改的时候,如果修改后的值$\le mid$,才能放到左区间,而且要先执行以下免得后面的判断不符合事实.

其他问题[POI2011]MET-Meteors

在该问题中,我们首先要做的是断环成链,使得问题变成序列问题.操作区间右端点小于左端点时,将右端点+m即可.

这次我们二分的答案是每个国家最早什么时候收集到足够的陨石.因此我们二分出答案mid后,要让$[l,mid]$的陨石雨都落下来(也就是在树状数组上上传),然后按照哪些国家收集足够了为标准,对询问分类即可.值得注意的是,因为一个国家的空间站可能是分散在环上的,所以我们可以通过连边的方式来统一.同时因为是区间加,所以树状数组应维护差分数组.

在判断某个国家是否被满足的时候,应当扫过它的每个空间站$j$,然后将$sum(j)$与$sum(j+m)$相加来作为判断它有没有收集足够的标准.为什么呢?因为每场流星雨代表数列上一个点加$a$,一个点减去$a$,如果左端点小于右端点,那么$sum(j+m)$在后半段,一定是$0$,此时答案是$sum(j)$.如果左端点大于右端点,那么右端点在后半段,左右端点相差不到$m$.那么如果$j$在左端点右边,$j+m$一定在右端点右边,统计同上一种情况.如果$j$在左端点左边,$j+m$要么在右端点右边,$sum(j)=sum(j+m)$,符合这次陨石雨没有影响到这个空间站的情况,要么在右端点左边,此时$sum(j)$一定位0,答案是$sum(j+m)$.

#include <cmath>
#include <queue>
#include <deque>
#include <cctype>
#include <string>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 600005
#define int long long
using namespace std;

template<class T> inline void read(T &x) {
    x = 0;
    char ch = getchar(), w = 0;
    while (!isdigit(ch))
        w = (ch =='-'), ch = getchar();
    while (isdigit(ch))
        x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
    x = w ? -x : x;
    return;
}

struct node {
    int head, ind, tot;
}sta[N], stal[N], star[N];

struct NODE {
    int l, r, a;
}eve[N];

int n, m, noe, k;
int nxt[N], to[N];
int bit[N], ans[N];

inline int lowbit(int x) {
    return -x & x;
}

inline void add(int x, int v) {
    while (x <= m + m)
        bit[x] += v, x += lowbit(x);
    return;
}

inline int sum(int x) {
    int ret = 0;
    while (x) 
        ret += bit[x], x -= lowbit(x);
    return ret;
}

inline void addedge(int from, int t) {
    nxt[++noe] = sta[from].head;
    to[noe] = t;
    sta[from].head = noe;
    return;
}

void solve(int l, int r, int ql, int qr) {
    if (l > r) return;
    if (l == r) {
        for (int i = ql; i <= qr; ++i)
            ans[sta[i].ind] = l;
        return;
    }
    int mid = l + r >> 1, tl = 0, tr = 0;
    for (int i = l; i <= mid; ++i) {
        add(eve[i].l, eve[i].a);
        add(eve[i].r + 1, -eve[i].a);
    }
    for (int i = ql; i <= qr; ++i) {
        int temp = 0;
        for (int j = sta[i].head; j && temp <= sta[i].tot; j = nxt[j])
            temp += sum(to[j] + m) + sum(to[j]);
        if (temp >= sta[i].tot) 
            stal[++tl] = sta[i];
        else
            star[++tr] = sta[i],
            star[tr].tot -= temp;
    }
    for (int i = l; i <= mid; ++i) {
        add(eve[i].l, -eve[i].a);
        add(eve[i].r + 1, eve[i].a);
    }
    for (int i = 1; i <= tl; ++i)
        sta[ql + i - 1] = stal[i];
    for (int i = 1; i <= tr; ++i)
        sta[ql + tl + i - 1] = star[i];
    solve(l, mid, ql, ql + tl - 1),
    solve(mid + 1, r, ql + tl, qr);
    return;
}

signed main() {
    read(n), read(m);
    for (int i = 1, x; i <= m; ++i)
        read(x), addedge(x, i);
    for (int i = 1; i <= n; ++i)
        read(sta[i].tot), sta[i].ind = i;
    read(k);
    for (int i = 1; i <= k; ++i) {
        read(eve[i].l), read(eve[i].r), read(eve[i].a);
        if (eve[i].r < eve[i].l)
            eve[i].r += m;
    }
    solve(1, k + 1, 1, n);
    for (int i = 1; i <= n; ++i)
        if (ans[i] > k)
            printf("NIE\n");
        else 
            printf("%d\n", ans[i]);
    return 0;
}

感谢@SWK@Tian-Xing两位dalao的帮助和陪伴

——Gensokyo