您的位置:首页 > 其它

spoj 3266 Cow School (splay 斜率优化)

2016-03-31 16:52 387 查看
#include<iostream>
#include<cstring>
#include<string>
#include<cstdio>
#include<stdio.h>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
#include<bitset>
#include<stack>
#include<set>
using namespace std;

// MSVC-only request for a very large stack; GCC/Clang ignore this pragma
// (possibly with a warning). The tree routines below are all iterative.
#pragma comment(linker, "/STACK:1024000000,1024000000")
// Sentinel "infinite" slope for hull endpoints; assumed to exceed any real
// slope between two input points.
#define inf 1e11
// Tolerance used by dcmp() for floating-point comparisons.
#define eps 1e-9
// pii, MP, ULL, M and mod are template leftovers not used by this file.
#define pii pair<int,int>
#define MP make_pair
#define LL  long long
#define ULL unsigned long long
// Capacity of every per-element array below (indices 0..n, requires n < N).
#define N ( 300000 + 10 )
#define M ( 200000 + 10)
#define mod  1000000007

// Per-index results, 1-based over the tests sorted by t/p ascending:
//   g[i]  = (t_i + ... + t_n) / (p_i + ... + p_n)
//   ma[i] = max over j <  i of t_j - g[i] * p_j   (-inf when i == 1)
//   mi[i] = min over j >= i of t_j - g[i] * p_j
double mi[N], ma[N], g[N];

// One input test (t, p), as read by main().
struct node {
    LL t, p;
    // Orders by the ratio t / p via exact cross-multiplication.
    // Assumes p > 0 and that the products fit in long long.
    bool operator < (const node &ot) const {
        return t * ot.p < p * ot.t;
    }
} a[N];

// Coordinates of the points inserted in the current pass
// (main() sets x = p and y = +t or -t; assumes p fits in int).
int x[N], y[N];
// Splay tree over hull vertices, keyed by x; main() arranges for
// node index == element index.
int ch[N][2];
// lk[u] / rk[u]: slope of the hull edge from u to its left / right neighbour
// (inf / -inf at the two ends).
double lk[N], rk[N];
int par[N];
int root, sz;
int n;

// Pass direction: true = ascending (newnode() counts up, find() reports 0 when
// empty); false = descending (newnode() counts down, find() reports n + 1).
bool e;
// Hands out the next splay-node index: pre-incremented sz in the ascending
// pass (1, 2, 3, ...), post-decremented sz in the descending pass.
int newnode() {
    return e ? ++sz : sz--;
}
// Tolerant sign: -1 if v <= -eps, 0 if |v| < eps, otherwise 1
// (NaN therefore also yields 1).
int dcmp(double v) {
    if (v <= -eps) return -1;
    if (fabs(v) < eps) return 0;
    return 1;
}

// Single splay-tree rotation: lifts x above its parent y while keeping the
// in-order (x-key) order, and re-links x under y's former parent.
// Does not update `root`; splay() does that.
void rot(int x) {
    int y = par[x], d = ch[y][1] == x;   // d = 1 when x is y's right child
    ch[y][d] = ch[x][!d];                // x's inner subtree moves under y
    if(ch[x][!d]) par[ch[x][!d]] = y;
    ch[x][!d] = y;                       // y becomes x's child on the other side
    par[x] = par[y];
    par[y] = x;
    // Grandparent still points at y; redirect that child slot to x.
    if(par[x]) ch[par[x]][ch[par[x]][1] == y] = x;
}

// Rotates x upward until its parent is `goal`; goal == 0 makes x the root
// and records it in `root`.
void splay(int x, int goal) {
    for (int f = par[x]; f != goal; f = par[x]) {
        int ff = par[f];
        if (ff != goal) {
            // Same-side (zig-zig): rotate the parent first; otherwise (zig-zag)
            // rotate x twice.
            const bool straight = (ch[ff][1] == f) == (ch[f][1] == x);
            rot(straight ? f : x);
        }
        rot(x);
    }
    if (goal == 0) root = x;
}

// Walks the tree to the vertex u whose edge slopes bracket the query slope
// (rk[u] <= slope <= lk[u], within eps). For an upper hull this is presumably
// the vertex maximising y - slope * x. On an empty tree it returns the pass's
// "none" sentinel instead: 0 when e is true, n + 1 otherwise. Returns 0 if the
// walk steps past a leaf.
int find(double slope) {
    if (root == 0) return e ? 0 : n + 1;
    int cur = root;
    while (cur != 0) {
        if (dcmp(slope - lk[cur]) > 0) {       // steeper than the left edge
            cur = ch[cur][0];
            continue;
        }
        if (dcmp(rk[cur] - slope) > 0) {       // shallower than the right edge
            cur = ch[cur][1];
            continue;
        }
        if (dcmp(lk[cur] - slope) >= 0 && dcmp(slope - rk[cur]) >= 0) return cur;
    }
    return cur;
}

// Adds element i as a new splay node keyed by x[i] (equal keys go left) and
// splays it to the root. The very first node gets sentinel slopes
// lk = inf, rk = -inf; for other nodes insert() leaves lk/rk unset.
void insert(int i) {
    if (root == 0) {
        root = newnode();
        lk[root] = inf;
        rk[root] = -inf;
        return;
    }
    int u = root;
    for (;;) {
        const int side = (x[i] > x[u]) ? 1 : 0;
        if (ch[u][side] == 0) {
            const int fresh = newnode();
            ch[u][side] = fresh;
            par[fresh] = u;
            splay(fresh, 0);
            return;
        }
        u = ch[u][side];
    }
}

double get(int i, int j) {
if(x[i] == x[j]) {
if(y[j] > y[i]) return inf;
else return -inf;
}
return (1.0*y[i] - y[j]) / (1.0*x[i] - x[j]);
}
int getl(int x) {
int t = ch[x][0];
int ans = t;
while(t) {
if(dcmp(lk[t] - get(t,x)) >= 0) ans = t, t = ch[t][1];
else t = ch[t][0];
}
return ans;
}
int getr(int x) {
int t = ch[x][1], ans = t;
while(t) {
if(dcmp(rk[t] - get(x,t))<= 0) ans = t, t = ch[t][0];
else t = ch[t][1];
}
return ans;

}

// Repairs the hull after insert() placed node x:
//  1. Left side: the node chosen by getl(x) becomes x's left neighbour. Its
//     right subtree (the vertices strictly between it and x) is detached, and
//     the shared edge slope is stored in rk[ls] and lk[x].
//  2. Right side: symmetric, using getr(x).
//  3. If lk[x] <= rk[x], x is not a vertex of the hull. It is unlinked and its
//     neighbours are joined directly, or the surviving side becomes the tree.
// Detached subtrees get par = 0 and become unreachable from `root`.
void maintain(int x) {
    splay(x, 0);  // x as root: its subtrees are exactly the points left/right of it
    if(ch[x][0]) {
        int ls = getl(x);
        splay(ls, root);  // ls directly under x; its right subtree lies between ls and x
        par[ch[ls][1]] = 0, ch[ls][1] = 0;  // drop those vertices
        rk[ls] = lk[x] = get(ls, x);
    }
    else lk[x] = inf;  // x is the leftmost vertex
    if(ch[x][1]) {
        int rs = getr(x);
        splay(rs, root);  // rs directly under x; its left subtree lies between x and rs
        par[ch[rs][0]] = 0, ch[rs][0] = 0;  // drop those vertices
        lk[rs] = rk[x] = get(x, rs);
    }
    else rk[x] = -inf;  // x is the rightmost vertex

    // Left edge not steeper than right edge: x is not a hull vertex, remove it.
    if(lk[x] <= rk[x]) {
        splay(x, 0);  // x is still the root here (splays above used goal == root)
        int ls = ch[x][0], rs = ch[x][1];
        if(!ls && !rs) {
            root = 0;
            return;
        }
        if(!ls) {
            // Only the right side remains; rs is now the leftmost vertex.
            root = rs;
            par[rs] = 0;
            lk[rs] = inf;
            return ;
        }
        if(!rs) {
            // Only the left side remains; ls is now the rightmost vertex.
            root = ls;
            par[ls] = 0;
            rk[ls] = -inf;
            return ;
        }
        // Both sides remain. ls has no right child after step 1, so the right
        // subtree can be hung directly under it.
        par[ch[x][0]] = 0, ch[x][0] = 0;
        splay(ls, 0);  // ls has no parent now; this just records it as root
        ch[ls][1] = ch[x][1], par[ch[x][1]] = ls;
        rk[ls] = lk[rs] = get(ls, rs);  // new hull edge ls -> rs
    }
}

// Empties the tree: clears `root` and the child/parent links of nodes 0..n.
// lk/rk are not reset here.
void init(){
    root = 0;
    for (int k = n; k >= 0; --k) {
        par[k] = 0;
        ch[k][1] = 0;
        ch[k][0] = 0;
    }
}

// Brute-force O(n^2) cross-check, not called by main(). For each i it prints
// mi[i] next to the direct minimum of t_j - g[i] * p_j over j >= i, and ma[i]
// next to the direct maximum over j < i.
void debug() {
    for (int i = 1; i <= n; ++i) {
        double mma = -inf, mmi = inf;
        for (int j = 1; j <= n; ++j) {
            const double v = a[j].t - g[i] * a[j].p;
            if (j < i) mma = max(v, mma);
            else mmi = min(v, mmi);
        }
        printf("i %d mi %.5lf %.5lf ma %.5lf %.5lf\n", i, mi[i], mmi, ma[i], mma);
    }
}

// Processes test cases until EOF. For each case:
//  * read n tests (t, p) and sort them by t/p ascending;
//  * g[i] = average of the tests kept after dropping the first i-1;
//  * ascending hull pass  (points (p_j, t_j), j <  i)  -> ma[i];
//  * descending hull pass (points (p_j, -t_j), j >= i) -> mi[i].
// Each i with ma[i] > mi[i] (beyond eps) is reported. Presumably this means a
// dropped test could be swapped with a kept one to raise the average. Output
// is the count of such i, then each i-1 on its own line.
int main() {
    //freopen("in.in", "r", stdin);
    //freopen("out.out", "w", stdout);
    // scanf returns 0 (not EOF) on a non-numeric token; `~scanf(...)` would
    // then loop forever on stale n, so require exactly one item read.
    while(scanf("%d", &n) == 1) {
        LL sumt = 0, sump = 0;
        for(int i = 1; i <= n; ++i)
            scanf("%lld%lld", &a[i].t, &a[i].p);
        sort(a+1, a+n+1);  // increasing t/p (node::operator<)

        // Suffix averages: g[i] = (t_i + ... + t_n) / (p_i + ... + p_n).
        for(int i = n; i >= 1; --i) {
            sumt += a[i].t;
            sump += a[i].p;
            g[i] = (double)sumt / sump;
        }

        // Pass 1: query before inserting i, so only j < i are in the hull.
        e = 1;
        init();
        sz = 0;
        for(int i = 1; i <= n; ++i) {
            int j = find(g[i]);
            if(j == 0) ma[i] = -inf;
            else
                ma[i] = a[j].t - g[i] * a[j].p;
            x[i] = a[i].p;
            y[i] = a[i].t;
            insert(i);
            maintain(i);
        }

        // Pass 2: points (p, -t) queried with slope -g[i]. Maximising
        // -t_j + g[i] * p_j minimises t_j - g[i] * p_j. Querying after inserting
        // i covers j >= i.
        e = 0;
        init();
        sz = n;
        for(int i = n; i >= 1; --i) {
            x[i] = a[i].p;
            y[i] = -a[i].t;
            insert(i);
            maintain(i);
            int j = find(-g[i]);
            if(j == n+1) mi[i] = inf;
            else mi[i] = a[j].t - g[i] * a[j].p;
        }

        int ans = 0;
        for(int i = 1; i <= n; ++i) {
            if(dcmp(ma[i] - mi[i]) > 0) ++ans;
        }
        printf("%d\n", ans);
        for(int i = 1; i <= n; ++i) {
            if(dcmp(ma[i] - mi[i]) > 0) printf("%d\n", i-1);
        }
    }
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: