您的位置:首页 > 其它

挑战程序竞赛系列(36):3.3线段树和平方分割

2017-08-25 12:58 686 查看

挑战程序竞赛系列(36):3.3线段树和平方分割

详细代码可以fork下Github上leetcode项目,不定期更新。

练习题如下:

POJ 2104: K-th Number

POJ 3264: Balanced Lineup

POJ 3368: Frequent values

POJ 3470: Walls

POJ 1201: Intervals

UVA 11990: Inversion

分桶法和平方分割

具体可以参考《挑战》P183页,这里简单说说思想。

我的理解:空间换时间,举个例子:

1 2 3 4 5 6 7 8 9 10

求指定区间内的最小值

区间 [1, 3]中的最小值为1
区间 [4, 8]中的最小值为4


传统做法,遍历指定区间需要O(n)次,能够求出答案,但由于频繁查询可能需要O(m)次,所以整体时间复杂度为O(nm)次,有没有办法把时间复杂度降低一些?平方分桶法可以降低到O(mn√)。

说白了,上述情况的每个结点维护自己的信息,分桶法的思想是:

组合几个个体成一个桶,由桶统一维护信息,所以对我们来说,它的呈现形式是多个个体和一个个桶,也就是所谓的空间换时间。

比如上述例子:

{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}

如果给定查询区间[2, 9]

那就意味着要遍历个体元素2和个体元素9,但查询区间中,有几个桶完全包含了该区间,如{3, 4}

那么,我们就可以直接拿出桶维护的最小元素信息

所以遍历时,我们只需要遍历 2个元素 + 3个桶
传统做法需要遍历 8个元素

谁快?


当然分桶法实现比线段树简单,但划分整体比线段树粗暴,所以时间复杂度略慢于线段树。

POJ 2104: K-th Number

思路很简单,根据分桶法,可以把它们放在一个个桶内单独维护,在区间内的桶,因为全部包含,所以排序后可以用二分快速找出答案,而桶不完全包含在区间内的,需要单独计算。整体再采用二分,在所有候选答案中猜出即可。

具体思路可以参考《挑战》P186页,代码如下:

static final int MAX_N = 100000 + 16;
static final int B = 1000;
List<Integer>[] bucket = new ArrayList[MAX_N / B + 1];

void solve() {
int n = ni();
int m = ni();
int[] sort = new int
;
int[] arra = new int
;

for (int i = 0; i <= n / B; ++i) bucket[i] = new ArrayList<Integer>();

for (int i = 0; i < n; ++i){
arra[i] = ni();
bucket[i / B].add(arra[i]);
sort[i] = arra[i];
}

for (int i = 0; i < n / B; ++i){
Collections.sort(bucket[i]);
}

Arrays.sort(sort);

for (int t = 0; t < m; ++t){
int i = ni();
int j = ni();
int k = ni();
i--;
j--;

int lf = -1, rt = n - 1;
while (rt - lf > 1){
int l = i;
int r = j;

int s = l / B;
int e = r / B;
int mid = (lf + rt) / 2;

int key = sort[mid];

int x = 0;
if (e - s <= 1){
for (int y = l; y <= r; ++y){
if (arra[y] <= key) x++;
}
}
else{

while (l < n && l / B == s){
if (arra[l] <= key) x++;
l++;
}

while (r >= 0 && r / B == e){
if (arra[r] <= key) x++;
r--;
}

for (int y = s + 1; y < e; ++y){
x += binarySearch(bucket[y], key) + 1;
}
}
if (x < k){
lf = mid;
}
else{
rt = mid;
}
}
out.println(sort[rt]);
}

}

public int binarySearch(List<Integer> aux, int key){
int lf = 0, rt = aux.size() - 1;
while (lf < rt){
int mid = lf + (rt - lf + 1) / 2;
if (aux.get(mid) > key){
rt = mid - 1;
}
else lf = mid;
}
if (aux.get(lf) <= key) return lf;
return -1;
}


TLE了,显然此题用分桶法还不够快,因此我们采用线段树来解决,线段树维护的独立个体是自底向上慢慢长大的,所以空间复杂度更高,但速度会更快,代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.InputMismatchException;
import java.util.List;

public class Main{
InputStream is;
PrintWriter out;
String INPUT = "./data/judge/201708/P2104.txt";

static final int SIZE = (1 << 18) - 1;
List<Integer>[] dat = new ArrayList[SIZE];
int[] A;

void solve(){
int N = ni();
int M = ni();
A = new int
;
int[] sort = new int
;
for (int i = 0; i < N; ++i){
A[i] = ni();
sort[i] = A[i];
}
for (int i = 0; i < dat.length; ++i) dat[i] = new ArrayList<Integer>();

Arrays.sort(sort);
init(0, 0, N);

for (int t = 0; t < M; ++t){
int i = ni();
int j = ni();
int k = ni();
i--;
int lf = -1, rt = N - 1;
while (rt - lf > 1){
int mid = (lf + rt) / 2;
int query = query(0, i, j, sort[mid], 0, N);
if (query < k){
lf = mid;
}
else{
rt = mid;
}
}

out.println(sort[rt]);
}
}

/******************以下是线段树******************/

/***
* 区间 [l, r)
* @param k
* @param l
* @param r
*/
public void init(int k, int l, int r){
if (r - l == 1){
dat[k].add(A[l]);
}
else{
int lch = 2 * k + 1;
int rch = 2 * k + 2;
init(lch, l, (l + r) / 2); //为了能够准确的划分区间
init(rch, (l + r) / 2, r);

merge (dat[lch], dat[rch], dat[k]);
}
}

/**
* 查询区间 [i, j)
* 线段树区间 [l, r)
* @param k
* @param i
* @param j
* @param x
* @param l
* @param r
* @return
*/
public int query(int k, int i, int j, int x, int l, int r){
if (j <= l || i >= r) return 0;
else if (i <= l && j >= r){
return binarySearch(dat[k], x) + 1;
}else{
int ans = 0;
ans += query(2 * k + 1, i, j, x, l, (l + r) / 2);
ans += query(2 * k + 2, i, j, x, (l + r) / 2, r);
return ans;
}
}

public void merge(List<Integer> lch, List<Integer> rch, List<Integer> k){
int l = 0, r = 0;
while (l < lch.size() && r < rch.size()){
if (lch.get(l) <= rch.get(r)){
k.add(lch.get(l));
l++;
}
else{
k.add(rch.get(r));
r++;
}
}

while (l < lch.size()) k.add(lch.get(l++));
while (r < rch.size()) k.add(rch.get(r++));
}

public int binarySearch(List<Integer> aux, int key){
int lf = 0, rt = aux.size() - 1;
while (lf < rt){
int mid = lf + (rt - lf + 1) / 2;
if (aux.get(mid) > key){
rt = mid - 1;
}
else lf = mid;
}
if (aux.get(lf) <= key) return lf;
return -1;
}

void run() throws Exception {
is = oj ? System.in : new FileInputStream(new File(INPUT));
out = new PrintWriter(System.out);

long s = System.currentTimeMillis();
solve();
out.flush();
tr(System.currentTimeMillis() - s + "ms");
}

public static void main(String[] args) throws Exception {
new Main().run();
}

private byte[] inbuf = new byte[1024];
public int lenbuf = 0, ptrbuf = 0;

private int readByte() {
if (lenbuf == -1)
throw new InputMismatchException();
if (ptrbuf >= lenbuf) {
ptrbuf = 0;
try {
lenbuf = is.read(inbuf);
} catch (IOException e) {
throw new InputMismatchException();
}
if (lenbuf <= 0)
return -1;
}
return inbuf[ptrbuf++];
}

private boolean isSpaceChar(int c) {
return !(c >= 33 && c <= 126);
}

private int skip() {
int b;
while ((b = readByte()) != -1 && isSpaceChar(b))
;
return b;
}

private double nd() {
return Double.parseDouble(ns());
}

private char nc() {
return (char) skip();
}

private String ns() {
int b = skip();
StringBuilder sb = new StringBuilder();
while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
// ')
sb.appendCodePoint(b);
b = readByte();
}
return sb.toString();
}

private char[] ns(int n) {
char[] buf = new char
;
int b = skip(), p = 0;
while (p < n && !(isSpaceChar(b))) {
buf[p++] = (char) b;
b = readByte();
}
return n == p ? buf : Arrays.copyOf(buf, p);
}

private char[][] nm(int n, int m) {
char[][] map = new char
[];
for (int i = 0; i < n; i++)
map[i] = ns(m);
return map;
}

private int[] na(int n) {
int[] a = new int
;
for (int i = 0; i < n; i++)
a[i] = ni();
return a;
}

private int ni() {
int num = 0, b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private long nl() {
long num = 0;
int b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

private void tr(Object... o) {
if (!oj)
System.out.println(Arrays.deepToString(o));
}
}


POJ 3264: Balanced Lineup

水题,思路很直接,关键怎么加快速度,采用分桶法,可以直接参考《挑战》P187页代码。

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
InputStream is;
PrintWriter out;
String INPUT = "./data/judge/201708/P3264.txt";

static final int B = 1000;
static final int MAX_N = 50000 * 2;
int[] max = new int[MAX_N / B + 1];
int[] min = new int[MAX_N / B + 1];

static final int INF = 1 << 29;
void solve() {
int N = ni();
int Q = ni();

int[] cows = new int
;

Arrays.fill(max, -INF);
Arrays.fill(min, INF);

for (int i = 0; i < N; ++i){
cows[i] = ni();
max[i / B] = Math.max(max[i / B], cows[i]);
min[i / B] = Math.min(min[i / B], cows[i]);
}

for (int q = 0; q < Q; ++q){
int i = ni();
int j = ni();
i--;
// [i, j)
int minHeight = INF;
int maxHeight = -INF;

int l = i, r = j;
while (l < r && l % B != 0){
minHeight = Math.min(minHeight, cows[l]);
maxHeight = Math.max(maxHeight, cows[l++]);
}

while (l < r && r % B != 0){
minHeight = Math.min(minHeight, cows[--r]);
maxHeight = Math.max(maxHeight, cows[r]);
}

while (l < r){
int b = l / B;
minHeight = Math.min(minHeight, min);
maxHeight = Math.max(maxHeight, max[b]);
l += B;
}

out.println(maxHeight - minHeight);
}
}

void run() throws Exception {
is = oj ? System.in : new FileInputStream(new File(INPUT));
out = new PrintWriter(System.out);

long s = System.currentTimeMillis();
solve();
out.flush();
tr(System.currentTimeMillis() - s + "ms");
}

public static void main(String[] args) throws Exception {
new Main().run();
}

private byte[] inbuf = new byte[1024];
public int lenbuf = 0, ptrbuf = 0;

private int readByte() {
if (lenbuf == -1)
throw new InputMismatchException();
if (ptrbuf >= lenbuf) {
ptrbuf = 0;
try {
lenbuf = is.read(inbuf);
} catch (IOException e) {
throw new InputMismatchException();
}
if (lenbuf <= 0)
return -1;
}
return inbuf[ptrbuf++];
}

private boolean isSpaceChar(int c) {
return !(c >= 33 && c <= 126);
}

private int skip() {
int b;
while ((b = readByte()) != -1 && isSpaceChar(b))
;
return b;
}

private double nd() {
return Double.parseDouble(ns());
}

private char nc() {
return (char) skip();
}

private String ns() {
int b = skip();
StringBuilder sb = new StringBuilder();
while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
// ')
sb.appendCodePoint(b);
b = readByte();
}
return sb.toString();
}

private char[] ns(int n) {
char[] buf = new char
;
int b = skip(), p = 0;
while (p < n && !(isSpaceChar(b))) {
buf[p++] = (char) b;
b = readByte();
}
return n == p ? buf : Arrays.copyOf(buf, p);
}

private char[][] nm(int n, int m) {
char[][] map = new char
[];
for (int i = 0; i < n; i++)
map[i] = ns(m);
return map;
}

private int[] na(int n) {
int[] a = new int
;
for (int i = 0; i < n; i++)
a[i] = ni();
return a;
}

private int ni() {
int num = 0, b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private long nl() {
long num = 0;
int b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

private void tr(Object... o) {
if (!oj)
System.out.println(Arrays.deepToString(o));
}
}


[b]POJ 3368: Frequent values

思路:一开始采用分桶法,用Map记录每个元素的个数,最后再拼接答案,但这种做法超时了,后来想了下,因为统计个数这操作太慢,且并没有充分利用原数组的非递减性质。

超时代码如下:

static final int B = 1000;
static final int MAX_N = 100000 + 2000;
Map<Integer, Integer>[] bucket = new HashMap[MAX_N / B];

void solve() {
while (true){
int N = ni();
if (N == 0) break;

int Q = ni();
int[] A = new int
;

for (int i = 0; i <= N / B; ++i) bucket[i] = new HashMap<Integer, Integer>();
for (int i = 0; i < N; ++i){
A[i] = ni();
int b = i / B;
if (!bucket.containsKey(A[i])) bucket[b].put(A[i], 0);
bucket[b].put(A[i], bucket[b].get(A[i]) + 1);
}

for (int q = 0; q < Q; ++q){
int i = ni();
int j = ni();
i--;

int l = i, r = j;
Map<Integer, Integer> map = new HashMap<Integer, Integer>();
int max = 0;
while (l < r && l % B != 0){
int key = A[l++];
if (!map.containsKey(key)) map.put(key, 0);
map.put(key, map.get(key) + 1);
}

while (l < r && r % B != 0){
int key = A[--r];
if (!map.containsKey(key)) map.put(key, 0);
map.put(key, map.get(key) + 1);
}

while (l < r){
int b = l / B;
for (int key : bucket[b].keySet()){
if (!map.containsKey(key)) map.put(key, 0);
map.put(key, map.get(key) + bucket[b].get(key));
}
l += B;
}

for (int key : map.keySet()){
max = Math.max(map.get(key), max);
}

out.println(max);
}
}
}


此题采用了线段树,我们维护三元组分别表示为{当前区间的最大频次,左边界元素的频次,右边界出现的频次},这样我们就可以从下往上构造每个区间的三元组了,且能够由左孩子和右孩子不断向上合并,用分治的手段解决了统计频次问题。

参考至:http://www.hankcs.com/program/algorithm/poj-3368-frequent-values-am.html

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class Main{
InputStream is;
PrintWriter out;
String INPUT = "./data/judge/201708/P3368.txt";

static final int SIZE = (1 << 18) - 1;
class Pair{
int max;
int left;
int right;

public Pair(int max, int left, int right){
this.max = max;
this.left = left;
this.right = right;
}

@Override
public String toString() {
return max + " " + left + " " + right;
}
}

Pair[] dat = new Pair[SIZE];
int[] A;
void solve() {
while (true){
int N = ni();
if (N == 0) break;

int Q = ni();
A = new int
;

for (int i = 0; i < N; ++i){
A[i] = ni();
}

init(0, 0, N);
for (int q = 0; q < Q; ++q){
int i = ni();
int j = ni();
i--;
out.println(query(0, i, j, 0, N).max);
}

}
}

// 区间 [l, r)
public void init(int k, int l, int r){
if (r - l == 1){
dat[k] = new Pair(1, 1, 1);
}
else{
int lch = 2 * k + 1;
int rch = 2 * k + 2;
init(lch, l, (l + r) / 2);
init(rch, (l + r) / 2, r);

dat[k] = new Pair(0, 0, 0);
dat[k].max = Math.max(dat[lch].max, dat[rch].max);

int mid = (l + r) / 2;
if (A[mid - 1] == A[mid]){
dat[k].max = Math.max(dat[k].max, dat[lch].right + dat[rch].left);
}

if (A[l] == A[mid]){
dat[k].left = dat[lch].left + dat[rch].left;
}
else{
dat[k].left = dat[lch].left;
}

if (A[r - 1] == A[mid - 1]){
dat[k].right = dat[lch].right + dat[rch].right;
}
else{
dat[k].right = dat[rch].right;
}
}
}

// 查询
public Pair query(int k, int i, int j, int l, int r){
if (j <= l || i >= r) return new Pair(0, 0, 0);
else if (i <= l && j >= r){
return dat[k];
}
else{
int mid = (l + r) / 2;
Pair lch = query(2 * k + 1, i, j, l, mid);
Pair rch = query(2 * k + 2, i, j, mid, r);

Pair ans = new Pair(Math.max(lch.max, rch.max), lch.left, rch.right);

if (A[mid] == A[mid - 1]){
ans.max = Math.max(ans.max, lch.right + rch.left);
}

if (A[l] == A[mid]) ans.left += rch.left;
if (A[r - 1] == A[mid - 1]) ans.right += lch.right;

return ans;
}
}

void run() throws Exception {
is = oj ? System.in : new FileInputStream(new File(INPUT));
out = new PrintWriter(System.out);

long s = System.currentTimeMillis();
solve();
out.flush();
tr(System.currentTimeMillis() - s + "ms");
}

public static void main(String[] args) throws Exception {
new Main().run();
}

private byte[] inbuf = new byte[1024];
public int lenbuf = 0, ptrbuf = 0;

private int readByte() {
if (lenbuf == -1)
throw new InputMismatchException();
if (ptrbuf >= lenbuf) {
ptrbuf = 0;
try {
lenbuf = is.read(inbuf);
} catch (IOException e) {
throw new InputMismatchException();
}
if (lenbuf <= 0)
return -1;
}
return inbuf[ptrbuf++];
}

private boolean isSpaceChar(int c) {
return !(c >= 33 && c <= 126);
}

private int skip() {
int b;
while ((b = readByte()) != -1 && isSpaceChar(b))
;
return b;
}

private double nd() {
return Double.parseDouble(ns());
}

private char nc() {
return (char) skip();
}

private String ns() {
int b = skip();
StringBuilder sb = new StringBuilder();
while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
// ')
sb.appendCodePoint(b);
b = readByte();
}
return sb.toString();
}

private char[] ns(int n) {
char[] buf = new char
;
int b = skip(), p = 0;
while (p < n && !(isSpaceChar(b))) {
buf[p++] = (char) b;
b = readByte();
}
return n == p ? buf : Arrays.copyOf(buf, p);
}

private char[][] nm(int n, int m) {
char[][] map = new char
[];
for (int i = 0; i < n; i++)
map[i] = ns(m);
return map;
}

private int[] na(int n) {
int[] a = new int
;
for (int i = 0; i < n; i++)
a[i] = ni();
return a;
}

private int ni() {
int num = 0, b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private long nl() {
long num = 0;
int b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

private void tr(Object... o) {
if (!oj)
System.out.println(Arrays.deepToString(o));
}
}


[b]POJ 3470: Walls

累觉不爱,待解决……

题解参考博文:http://www.hankcs.com/program/algorithm/poj-3470-walls.html

POJ 1201: Intervals

思路:排序+贪心+归简

首先按照右区间进行从小到达排序,这样开始选第一个区间时,选择最大的几个数,可以证明这种与后续出现区间存在交集的“可能性”最大,接着再考虑第二个区间时,把选择过的元素排除,继续取剩余最大的几个数,这样一来问题规模逐步缩小,完美解决。

如果快速求解区间内有多少个元素被选?BIT或线段树,BIT更简洁易懂。

代码如下:

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.InputMismatchException;

public class SolutionDay24_P1201 {
InputStream is;
PrintWriter out;
String INPUT = "./data/judge/201708/P1201.txt";

class Range implements Comparable<Range>{
int l;
int r;
int c;

public Range(int l, int r, int c){
this.l = l;
this.r = r;
this.c = c;
}

@Override
public int compareTo(Range that) {
return this.r - that.r;
}

@Override
public String toString() {
return l + " " + r + " " + c;
}
}

void solve() {
int n = ni();
init();
Range[] intervals = new Range
;
for (int i = 0; i < n; ++i){
intervals[i] = new Range(ni(), ni(), ni());
}

Arrays.sort(intervals);
boolean[] visited = new boolean[MAX_N];
int res = 0;
for (int i = 0; i < n; ++i){
Range now = intervals[i];
int picked = sum(now.l, now.r);
if (picked == 0){
res += now.c;
for (int j = 0; j < now.c; ++j){
add(now.r - j, 1);
visited[now.r - j] = true;
}
}
else{
int choose = now.c - picked;
if (choose <= 0) continue;
res += choose;
int pos = now.r;
while (choose > 0){
if (visited[pos]){
pos --;
}
else{
add(pos, 1);
visited[pos] = true;
pos --;
choose --;
}
}
}
}

out.println(res);

}

/*********************BIT************************/
int MAX_N = 2 * (50000 + 16);
int[] BIT;

public void init(){
BIT = new int[MAX_N];
}

public void add(int i, int val){
while (i <= MAX_N){
BIT[i] += val;
i += i & -i;
}
}

public int sum(int i){
int res = 0;
while (i > 0){
res += BIT[i];
i -= i & -i;
}
return res;
}

//区间 [l, r]
public int sum(int l, int r){
return sum(r) - sum(l - 1);
}

void run() throws Exception {
is = oj ? System.in : new FileInputStream(new File(INPUT));
out = new PrintWriter(System.out);

long s = System.currentTimeMillis();
solve();
out.flush();
tr(System.currentTimeMillis() - s + "ms");
}

public static void main(String[] args) throws Exception {
new SolutionDay24_P1201().run();
}

private byte[] inbuf = new byte[1024];
public int lenbuf = 0, ptrbuf = 0;

private int readByte() {
if (lenbuf == -1)
throw new InputMismatchException();
if (ptrbuf >= lenbuf) {
ptrbuf = 0;
try {
lenbuf = is.read(inbuf);
} catch (IOException e) {
throw new InputMismatchException();
}
if (lenbuf <= 0)
return -1;
}
return inbuf[ptrbuf++];
}

private boolean isSpaceChar(int c) {
return !(c >= 33 && c <= 126);
}

private int skip() {
int b;
while ((b = readByte()) != -1 && isSpaceChar(b))
;
return b;
}

private double nd() {
return Double.parseDouble(ns());
}

private char nc() {
return (char) skip();
}

private String ns() {
int b = skip();
StringBuilder sb = new StringBuilder();
while (!(isSpaceChar(b))) { // when nextLine, (isSpaceChar(b) && b != '
// ')
sb.appendCodePoint(b);
b = readByte();
}
return sb.toString();
}

private char[] ns(int n) {
char[] buf = new char
;
int b = skip(), p = 0;
while (p < n && !(isSpaceChar(b))) {
buf[p++] = (char) b;
b = readByte();
}
return n == p ? buf : Arrays.copyOf(buf, p);
}

private char[][] nm(int n, int m) {
char[][] map = new char
[];
for (int i = 0; i < n; i++)
map[i] = ns(m);
return map;
}

private int[] na(int n) {
int[] a = new int
;
for (int i = 0; i < n; i++)
a[i] = ni();
return a;
}

private int ni() {
int num = 0, b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private long nl() {
long num = 0;
int b;
boolean minus = false;
while ((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'))
;
if (b == '-') {
minus = true;
b = readByte();
}

while (true) {
if (b >= '0' && b <= '9') {
num = num * 10 + (b - '0');
} else {
return minus ? -num : num;
}
b = readByte();
}
}

private boolean oj = System.getProperty("ONLINE_JUDGE") != null;

private void tr(Object... o) {
if (!oj)
System.out.println(Arrays.deepToString(o));
}
}


UVA 11990: Inversion

求逆序对个数的做法有很多种,可以采用分治合并,BST,线段树等等均可。

可以参考http://blog.csdn.net/u014688145/article/details/72864156

但此题除了求逆序对之外,还会动态删减,这样就需要使用一种好的数据结构去维护一些内部信息,并且在删除时,能够准确的表达出来。

此处我们采用分桶法,或者是一种分治的手段。我们把每个点映射到二维的平面上去(i, A[i]),这样一来,逆序对的个数为其左上点的个数和右下点的个数总和。如何表示动态删除?

因为在二维平面上,可以方便的表达某个点的左上和右下区域,所以删除也是很自然的事情,至于点没有了(即不参与计算),可以用(-1,-1)表示。

之所以可以这么做,因为题目给了一下额外性质:

permutation,范围是在1 ~ N,且小标在0 ~ N - 1,所以这些值可以方便的映射到二维坐标平面(都不需要离散化处理)

当然在做题时,很重要的一点在于下标自然的排序了,所以这给我们累加逆序对的个数也带来了极大好处,时间复杂度只需要O(n)

JAVA 代码如下:

import java.util.Arrays;
import java.util.Scanner;

public class Main{

static final int MAX_N = 200000 + 16;
static final int MAX_M = 200000 + 16;
static final int BUCKET_SIZE = 450;

static int[] A;
static int[] POS;
static int N, M;

static class Bucket{
int count;
int prefix_sum;
}
static Bucket[][] buckets;

static class Space{
int[] X;
int[] Y;

public Space(){
X = new int[MAX_N];
Y = new int[MAX_N];

Arrays.fill(X, -1);
Arrays.fill(Y, -1);
}

public void add(int x, int y){
X[y] = x;
Y[x] = y;
}

public void remove(int x, int y){
X[y] = -1;
Y[x] = -1;
}

public int getX(int y){
return X[y];
}

public int getY(int x){
return Y[x];
}
}
static Space space;

public static void update_prefix_sum(int bx, int by){
int len = buckets[0].length;
int sum = (bx > 0 ? buckets[bx - 1][by].prefix_sum : 0);
for (int i = bx; i < len; ++i){
sum += buckets[i][by].count;
buckets[i][by].prefix_sum = sum;
}
}

public static void add(int x, int y){
space.add(x, y);
int bx = x / BUCKET_SIZE;
int by = y / BUCKET_SIZE;
++buckets[bx][by].count;
update_prefix_sum(bx, by);
}

public static void remove(int x, int y){
space.remove(x, y);
int bx = x / BUCKET_SIZE;
int by = y / BUCKET_SIZE;
--buckets[bx][by].count;
update_prefix_sum(bx, by);
}

// 统计区间 [0,0] 到 [x, y] 的点的个数
public static int sum(int x, int y){
int bx = x / BUCKET_SIZE;
int by = y / BUCKET_SIZE;

int count = 0;
for (int i = 0; i < by; ++i){
if (bx > 0)
count += buckets[bx - 1][i].prefix_sum;
}

for (int py = by * BUCKET_SIZE; py < y; ++py){
if (space.getX(py) != -1 && space.getX(py) < x) count++;
}

for (int px = bx * BUCKET_SIZE; px < x; ++px){
if (space.getY(px) != -1 && space.getY(px) < by * BUCKET_SIZE) count++;
}
return count;
}

public static int sum_inversion(int x, int y){
int res = 0;
int intersection = sum(x, y);
res += sum(x, N) - intersection;
res += sum(N, y) - intersection;
return res;
}

public static void main(String[] args) {
Scanner in = new Scanner(System.in);
while (in.hasNext()){
N = in.nextInt();
M = in.nextInt();
A = new int
;
POS = new int
;
for (int i = 0; i < N; ++i){
A[i] = in.nextInt();
A[i]--;
POS[A[i]] = i;
}

long res = 0;
space = new Space();
buckets = new Bucket[MAX_N / BUCKET_SIZE + 1][MAX_N / BUCKET_SIZE + 1];
int len = buckets.length;
for (int i = 0; i < len; ++i){
for (int j = 0; j < len; ++j){
buckets[i][j] = new Bucket();
}
}

for (int i = 0; i < N; ++i){
add(i, A[i]);
res += sum_inversion(i, A[i]);
}

for (int i = 0; i < M; ++i){
int m = in.nextInt();
m--;
System.out.println(res);
res -= sum_inversion(POS[m], m);
remove(POS[m], m);
}
}
in.close();
}

}


中间利用了一些加速的手段,如前缀和,但总体就是分治的一种迭代版本。。。可惜还是TLE了,改成C++版本,能过,蛋疼。

代码如下:

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define MAX_N 200000 + 16
#define MAX_M 100000 + 16
#define BUCKET_SIZE 450 // sqrt(MAX_N) = 447

int A[MAX_N], N, M;
struct Bucket
{
int count;      // 内部数字的个数
int prefix_sum; // 前缀和
}bucket[BUCKET_SIZE][BUCKET_SIZE];

// 平面坐标快速查询
struct Space
{
int X[MAX_N], Y[MAX_N];

void insert(const int& x, const int& y)
{
X[y] = x;
Y[x] = y;
}

void remove(const int& x, const int& y)
{
X[y] = -1;
Y[x] = -1;
}

int getX(const int& y)
{
return X[y];
}

int getY(const int& x)
{
return Y[x];
}
void init()
{
memset(X, -1, sizeof(X)); memset(Y, -1, sizeof(Y));
}
} space;

void update_prefix_sum(int bx, int by)
{
int sum = (bx > 0 ? bucket[bx - 1][by].prefix_sum : 0);
for (int i = bx; i < BUCKET_SIZE; ++i)
{
sum += bucket[i][by].count;
bucket[i][by].prefix_sum = sum;
}
}

// 加入一个点
void add(int x, int y)
{
space.insert(x, y);
int bx = x / BUCKET_SIZE;
int by = y / BUCKET_SIZE;

++bucket[bx][by].count;
update_prefix_sum(bx, by);
}

// 删除一个点
void remove(int x, int y)
{
space.remove(x, y);
int bx = x / BUCKET_SIZE;
int by = y / BUCKET_SIZE;

--bucket[bx][by].count;
update_prefix_sum(bx, by);
}

// (0,0)与(x,y)围起来的矩形区域的点的个数
int count_sum(int x, int y)
{
int block_w = x / BUCKET_SIZE;
int block_h = y / BUCKET_SIZE;

int count = 0;
// 完全在内部的桶
for (int i = 0; i < block_h; ++i)
{
if (block_w > 0)
{
count += bucket[block_w - 1][i].prefix_sum;
}
}
// 其他
for (int i = block_w * BUCKET_SIZE; i < x; ++i)
{
if (space.getY(i) != -1 && space.getY(i) < block_h * BUCKET_SIZE) count++;
}
for (int i = block_h * BUCKET_SIZE; i < y; ++i)
{
if (space.getX(i) != -1 && space.getX(i) < x) count++;
}
return count;
}

// (x,y)的左上和右下方块内部点的个数就是逆序数对的个数
int count_inversion(int x, int y)
{
int count = 0;
int intersection = count_sum(x, y);
count += count_sum(x, N) - intersection;    // 左上
count += count_sum(N, y) - intersection;    // 右下
return count;
}

///////////////////////////SubMain//////////////////////////////////
int main(int argc, char *argv[])
{
#ifndef ONLINE_JUDGE
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
while (scanf("%d %d", &N, &M) != EOF)
{
space.init();
memset(bucket, 0, sizeof(bucket));
for (int i = 0; i < N; ++i)
{
scanf("%d", &A[i]);
--A[i];
}
long long inversion = 0;
for (int i = 0; i < N; ++i)
{
add(i, A[i]);
inversion += count_inversion(i, A[i]);
}
for (int i = 0; i < M; ++i)
{
int q;
scanf("%d", &q);
--q;
printf("%lld\n", inversion);
inversion -= count_inversion(space.getX(q), q);
remove(space.getX(q), q);
}
}
#ifndef ONLINE_JUDGE
fclose(stdin);
fclose(stdout);
system("out.txt");
#endif
return 0;
}


参考了:http://www.hankcs.com/program/algorithm/uva-11990-inversion.html

但count_sum的函数做了一些改动,以自己的方式计算了(0,0)到(x,y)的个数,大同小异,注意一些边界即可。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: