您的位置:首页 > 编程语言

编程之美读书笔记_2.5 寻找最大的K个数 测试代码

2010-05-31 23:59 274 查看
//2.5_nth.cpp   by  flyinghearts#qq.com 
#include <iostream>
#include <vector>
#include <queue>
#include <set>
#include <utility>
#include <algorithm>
#include <ctime>
#include <cassert>
#include <cstdlib>
using namespace std;

const int check_result=0;  // whether to verify each algorithm's output (non-zero enables the check)

// Top-K selectors under benchmark. Each takes (source, source length,
// destination, destination length) and stores the dst_size largest source
// values in the destination buffer.
template<typename T> void nth_mset(const T*, size_t, T*, size_t);
template<typename T> void nth_queue(const T*, size_t, T*, size_t);
template<typename T> void nth_stl(const T*, size_t, T*, size_t);
template<typename T> void nth_partial(const T*, size_t, T*, size_t);
template<typename T> void nth_partial_copy(const T*, size_t, T*, size_t);
template<typename T> void nth_heap_1(const T*, size_t, T*, size_t);
template<typename T> void nth_heap_2(const T*, size_t, T*, size_t);
template<typename T> void nth_heap_3(const T*, size_t, T*, size_t);
template<typename T> void nth_heapsort(const T*, size_t, T*, size_t);

// Bucket-counting selection; only valid for 32-bit integers, and any single
// value must occur fewer than 2^32 times.
void nth_count(const int *src, size_t src_size, int *dst, size_t dst_size);

// Pick the M largest out of N numbers (values lie in (-2^BITS, 2^BITS)).
template<typename T, typename R, size_t size>
void test(R (&ff)[size], size_t count=1, size_t N=1e8,
          size_t M=1e4, size_t BITS=31);

// Generator functor returning pseudo-random ints in the open range
// (-2^bits, 2^bits), built from three rand() calls.
//
// The previous mask expression, ((1<<i)-1)&(1<<31), always evaluated to 0, so
// every generated number was 0. The shifts were also performed on signed
// values and could overflow. All bit manipulation is now done on unsigned.
class Rand_num {
  unsigned mask;   // selects the low `bits` magnitude bits
 public:
  // i: number of magnitude bits, 1..31; any other value falls back to 31.
  Rand_num(int i=31){ if(i>31 || i<1) i=31;  mask=(1u<<i)-1u; }

  // Returns the next number. Only the low 15 bits of each rand() result are
  // used because RAND_MAX is only guaranteed to be at least 32767.
  int operator()(){
    const unsigned low = static_cast<unsigned>(rand()) & 0x7FFFu;
    const unsigned mid = static_cast<unsigned>(rand()) & 0x7FFFu;
    const unsigned top = static_cast<unsigned>(rand()) & 3u;
    const unsigned bits = (top<<30) | (mid<<15) | low;
    // mask never exceeds 0x7FFFFFFF, so the magnitude always fits in an int.
    const int magnitude = static_cast<int>(bits & mask);
    // Bit 31 is never part of the magnitude; use it as the sign.
    return (bits & 0x80000000u) ? -magnitude : magnitude;
  }
};

// One benchmark entry: a display label paired with the selector to time.
struct Func {
  // Selector signature: (source, source length, destination, destination length).
  typedef void (*Selector)(const int *, size_t, int *, size_t);

  const char *name;   // label printed next to the timing
  Selector func;      // routine invoked by the benchmark driver
};

// Benchmark driver: times every registered selector on several (N, M, BITS)
// configurations. Calls that omit BITS use the default declared for test().
int main()
{
  // Label / implementation pairs, run in this order for every configuration.
  Func ff[]={
    {"nth_element",nth_stl},          // label typo "nth_elmemnt" corrected
    {"nth_count",nth_count},
    {"priority_queue",nth_queue},
    {"partial_sort",nth_partial},
    {"partial_sort_copy",nth_partial_copy},
    {"heap(pop/push) ",nth_heap_1},
    {"heap(pop/copy)",nth_heap_2},
    {"heap(custom_pop)",nth_heap_3},
    {"multiset",nth_mset},
    {"heap_sort",nth_heapsort},
  };
  // Small input where most of the elements are selected.
  test<int>(ff,1,1e6,8e5);
  // N = 1e8 with growing M; each M is run with BITS=10 (many duplicates)
  // and with the default BITS.
  test<int>(ff,1,1e8,1e4,10);
  test<int>(ff,1,1e8,1e4);
  test<int>(ff,1,1e8,1e5,10);
  test<int>(ff,1,1e8,1e5);
  test<int>(ff,1,1e8,1e6,10);
  test<int>(ff,1,1e8,1e6);
  test<int>(ff,1,1e8,5e6,10);
  test<int>(ff,1,1e8,5e6);
}

// Top-K by partial heap sort: copy the input, build a max-heap over all of
// it, then pop dst_size times. Each pop parks the current maximum just past
// the shrinking heap, so the tail of the buffer ends up holding the dst_size
// largest values in ascending order; that tail is copied to dst.
template<typename T>
void nth_heapsort(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  std::vector<T> heap(src, src+src_size);
  std::make_heap(heap.begin(), heap.end());
  typename std::vector<T>::iterator heap_end = heap.end();
  for (size_t popped=0; popped != dst_size; ++popped, --heap_end)
    std::pop_heap(heap.begin(), heap_end);
  // [heap_end, end) now contains the extracted maxima.
  std::copy(heap_end, heap.end(), dst);
}

// Top-K using a multiset seeded with the first dst_size inputs. Every later
// input that exceeds the current minimum is inserted and the minimum is
// evicted, keeping the set size constant. The result is written to dst in
// ascending order.
template<typename T>
void nth_mset(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(src);
  assert(dst);
  std::multiset<T> best(src, src+dst_size);
  for (size_t idx=dst_size; idx<src_size; ++idx){
    const T &candidate=src[idx];
    if (candidate > *best.begin()){
      // Insert first, then drop one instance of the smallest element.
      best.insert(candidate);
      best.erase(best.begin());
    }
  }
  std::copy(best.begin(), best.end(), dst);
}

// Top-K using a min-ordered priority_queue of dst_size elements: any input
// greater than the queue's smallest element replaces it. The survivors are
// copied out straight from the queue's backing storage, so dst receives them
// in heap order, not sorted.
template<typename T>
void nth_queue(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(src);
  assert(dst);
  typedef std::greater<T> Above;
  typedef std::priority_queue<T, std::vector<T>, Above > MinQueue;
  Above above;
  MinQueue best(src, src+dst_size);
  const T * const src_end=src+src_size;
  for (const T *cur=src+dst_size; cur<src_end; ++cur){
    if (!above(*cur, best.top())) continue;
    best.pop();
    best.push(*cur);
  }
  // top() refers to the first element of the underlying vector, whose
  // storage is contiguous, so the whole container can be read from there.
  const T * const storage=&(best.top());
  std::copy(storage, storage+dst_size, dst);
}

// Top-K via std::nth_element on a private copy of the input, partitioned in
// descending order around position dst_size. The leading dst_size elements
// are copied to dst in whatever order the partition left them.
template<typename T>
void nth_stl(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(src);
  assert(dst);
  std::vector<T> work(src, src+src_size);
  const typename std::vector<T>::iterator cut=work.begin()+dst_size;
  std::nth_element(work.begin(), cut, work.end(), std::greater<T>());
  std::copy(work.begin(), cut, dst);
}

// Top-K via std::partial_sort on a private copy of the input; dst receives
// the dst_size largest values in descending order.
template<typename T>
void nth_partial(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(src);
  assert(dst);
  std::vector<T> work(src, src+src_size);
  const typename std::vector<T>::iterator cut=work.begin()+dst_size;
  std::partial_sort(work.begin(), cut, work.end(), std::greater<T>());
  std::copy(work.begin(), cut, dst);
}

// Top-K via std::partial_sort_copy, which writes the dst_size largest values
// directly into dst in descending order without an intermediate copy.
template<typename T>
void nth_partial_copy(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(src);
  assert(dst);
  const T * const src_end=src+src_size;
  T * const dst_end=dst+dst_size;
  std::partial_sort_copy(src, src_end, dst, dst_end, std::greater<T>());
}

// Top-K with a min-heap kept directly inside dst: seed it with the first
// dst_size inputs, then for each later input greater than the heap minimum,
// pop the minimum, overwrite the vacated last slot with the input and push it
// back. dst is left in heap order, not sorted.
template<typename T>
void nth_heap_1(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  std::greater<T> above;
  T * const heap_end=dst+dst_size;
  std::copy(src, src+dst_size, dst);
  std::make_heap(dst, heap_end, above);
  T smallest=dst[0];
  for (size_t idx=dst_size; idx<src_size; ++idx){
    if (!above(src[idx], smallest)) continue;
    std::pop_heap(dst, heap_end, above);    // minimum moves to the last slot
    heap_end[-1]=src[idx];                  // replace it with the newcomer
    std::push_heap(dst, heap_end, above);
    smallest=dst[0];
  }
}

// Top-K with a min-heap kept in a scratch buffer that is copied to dst at the
// end. dst receives the dst_size largest values in heap order, not sorted.
//
// Fixes relative to the earlier version:
//  * the scratch buffer was a vector<int>, so the template only compiled for
//    T == int; it is now vector<T>;
//  * it was filled from [src, src+dst_size+1), reading one element past the
//    input whenever dst_size == src_size;
//  * the "write candidate after the heap, then pop_heap over dst_size+1
//    elements" shortcut handed pop_heap a range that is not a heap, which
//    violates its precondition and gives wrong results with some standard
//    library implementations. The minimum is now replaced with a conforming
//    pop_heap / overwrite / push_heap sequence.
template<typename T>
void nth_heap_2(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(src);
  assert(dst);
  std::greater<T> cmp;                       // greater<> makes it a min-heap
  std::vector<T> tmp(src, src+dst_size);
  T * const first=&tmp[0];
  T * const last=first+dst_size;
  std::make_heap(first, last, cmp);
  const T * const end=src+src_size;
  for (const T *p=src+dst_size; p<end; ++p){
    if (cmp(*p, *first)){                    // candidate beats current minimum
      std::pop_heap(first, last, cmp);       // minimum moves to last-1
      *(last-1)=*p;
      std::push_heap(first, last, cmp);
    }
  }
  std::copy(first, last, dst);
}

// Top-K with a min-heap stored in dst and a hand-written replace-root step:
// each input larger than the current minimum is sifted down from the root.
// dst is left in heap order, not sorted.
template<typename T>
void nth_heap_3(const T *src, size_t src_size, T *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  std::copy(src, src+dst_size, dst);
  std::make_heap(dst, dst+dst_size, std::greater<T>());
  const size_t last_idx=dst_size-1;   // index of the final heap slot
  T smallest=dst[0];
  for (size_t k=dst_size; k<src_size; ++k){
    const T &cand=src[k];
    if (!(cand > smallest)) continue;
    dst[0]=cand;
    size_t hole=0;
    size_t child=1;
    // While the hole has two children, pull the smaller one up unless the
    // candidate already belongs at the hole.
    for (; child<last_idx; child=2*hole+1){
      if (dst[child] > dst[child+1]) ++child;
      if (dst[child] >= cand) break;
      dst[hole]=dst[child];
      hole=child;
    }
    // A lone left child occupying the final slot is handled separately.
    if (child==last_idx && dst[child]<cand){
      dst[hole]=dst[child];
      hole=child;
    }
    dst[hole]=cand;
    smallest=dst[0];
  }
}

// Top-K for 32-bit ints using two 65536-bucket counting passes.
//
//  1. Histogram the high 16 bits of every value and walk the buckets from the
//     top to find the "boundary" bucket in which the dst_size-th largest
//     value lies.
//  2. Values in higher buckets are copied to dst unchanged. Values in the
//     boundary bucket are histogrammed by their low 16 bits and the largest
//     of them are regenerated until dst holds dst_size values.
//
// The output order is unspecified. dst may alias src: writes never overtake
// the read position. Each bucket counter is an unsigned, so the counts must
// fit in 32 bits.
//
// Changes relative to the earlier version:
//  * the "whole boundary bucket is needed" fast path was guarded by a
//    condition that could never be true; it is now reachable;
//  * values are bucketed through an order-preserving unsigned key (sign bit
//    flipped) instead of shifting negative ints left and right.
void nth_count(const int *src, size_t src_size, int *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(sizeof(int)==4);
  const unsigned TOTAL=0x10000;          // number of buckets
  const unsigned SIGN_BIT=0x80000000u;   // XOR maps int order onto unsigned order
  unsigned count[TOTAL]={0};
  const int * const end=src+src_size;

  // Pass 1: histogram of the high halves of the keys.
  for (const int *p=src; p<end; ++p)
    ++count[(static_cast<unsigned>(*p)^SIGN_BIT)>>16];

  // Accumulate buckets from the top until at least dst_size values are covered.
  size_t sum=0;
  unsigned pos=TOTAL;
  while (sum < dst_size) sum += count[--pos];
  const size_t in_bucket=count[pos];
  const size_t needed=in_bucket-(sum-dst_size);   // always in 1..in_bucket

  int *q=dst;
  if (needed == in_bucket) {
    // Everything from the boundary bucket upwards is exactly the answer.
    for (const int *p=src; p<end; ++p) {
      const int v=*p;
      if (((static_cast<unsigned>(v)^SIGN_BIT)>>16) >= pos) *q++=v;
    }
    return;
  }

  // Pass 2: emit the higher buckets, histogram the boundary bucket's low halves.
  std::fill(count, count+TOTAL, 0u);
  for (const int *p=src; p<end; ++p) {
    const int v=*p;   // read before any write in case dst aliases src
    const unsigned key=static_cast<unsigned>(v)^SIGN_BIT;
    const unsigned bucket=key>>16;
    if (bucket > pos) *q++=v;
    else if (bucket == pos) ++count[key & 0xFFFFu];
  }

  // Regenerate the largest `needed` values of the boundary bucket.
  const unsigned base=(pos<<16)^SIGN_BIT;   // high half converted back to int bits
  size_t remaining=needed;
  unsigned low=TOTAL;
  while (remaining > 0) {
    --low;
    size_t take=count[low];
    if (take == 0) continue;
    if (take > remaining) take=remaining;
    std::fill_n(q, take, static_cast<int>(base | low));
    q += take;
    remaining -= take;
  }
}

// Benchmark driver: fills an array of N random values `count` times and, for
// each fill, times every entry of `ff` selecting the M largest values.
//
//  ff    - array of entries exposing `.name` and `.func(src, n, dst, m)`
//  count - number of random refills
//  N, M  - input size and number of values to select (requires 0 < M <= N)
//  BITS  - magnitude bits handed to Rand_num
//
// When check_result is non-zero, each result is sorted in descending order
// and compared against a partial_sort_copy reference.
//
// The parameter types are size_t so that this definition matches the earlier
// declaration (which carries the default arguments). With `unsigned`
// parameters it defined a different overload wherever size_t is not unsigned
// int, leaving the declared template undefined. Timings are converted from
// clock ticks to milliseconds to match the printed unit.
template<typename T, typename R, size_t size>
void test(R (&ff)[size], size_t count, size_t N, size_t M, size_t BITS)
{
  assert( N>=M && M>0);
  assert( size>0);
  const double ms_per_tick=1000.0/CLOCKS_PER_SEC;
  std::vector<T> arr(N),result(M),tmp(M);
  std::greater<T> cmp;
  std::cout<<"\nN,M: "<<N<<" "<<M<<"    "<< 1.0*M/N<< std::endl;
  clock_t tb;
  srand(static_cast<unsigned>(time(0)));
  Rand_num rand_num(static_cast<int>(BITS));
  for (size_t j=0; j<count; ++j){
    tb=clock();
    std::generate(arr.begin(),arr.end(), rand_num);
    tb=clock()-tb;
    std::cout<<"Randomizing: "<<tb*ms_per_tick<<" ms\n";
    if(check_result){
      // Reference answer: the M largest values in descending order.
      tb=clock();
      std::partial_sort_copy(arr.begin(), arr.end(), tmp.begin(), tmp.end(), cmp);
      tb=clock()-tb;
      std::cout<<"partial_sort_copy: "<<tb*ms_per_tick<<" ms\n";
    }
    for (size_t i=0; i<size; ++i){
      tb=clock();
      ff[i].func(&arr[0],N,&result[0],M);
      tb=clock()-tb;
      std::cout<< ff[i].name<<"  "<< tb*ms_per_tick<<"  ms\n";
      if(check_result){
        // Selectors may return their values in any order; normalise first.
        std::sort(result.begin(),result.end(), cmp);
        if( !std::equal(result.begin(), result.end(), tmp.begin())) {
          std::cout<< ff[i].name<< " failed to pass!\nresult:\n";
        }
      }
    }
    std::cout<<"\n";
  }
}








桶排序nth_count和标准库的nth_element比较:



#include<iostream>
#include<vector>
#include<algorithm>
#include<cstring>
#include<ctime>
#include<cassert>
using namespace std;

// Top-K for 32-bit ints using two 65536-bucket counting passes.
//
//  1. Histogram the high 16 bits of every value and walk the buckets from the
//     top to find the "boundary" bucket in which the dst_size-th largest
//     value lies.
//  2. Values in higher buckets are copied to dst unchanged. Values in the
//     boundary bucket are histogrammed by their low 16 bits and the largest
//     of them are regenerated until dst holds dst_size values.
//
// The output order is unspecified. dst may alias src: writes never overtake
// the read position. Each bucket counter is an unsigned, so the counts must
// fit in 32 bits.
//
// Changes relative to the earlier version:
//  * the "whole boundary bucket is needed" fast path was guarded by a
//    condition that could never be true; it is now reachable;
//  * values are bucketed through an order-preserving unsigned key (sign bit
//    flipped) instead of shifting negative ints left and right.
void nth_count(const int *src, size_t src_size, int *dst, size_t dst_size)
{
  assert( dst_size >0 && dst_size <= src_size);
  assert(sizeof(int)==4);
  const unsigned TOTAL=0x10000;          // number of buckets
  const unsigned SIGN_BIT=0x80000000u;   // XOR maps int order onto unsigned order
  unsigned count[TOTAL]={0};
  const int * const end=src+src_size;

  // Pass 1: histogram of the high halves of the keys.
  for (const int *p=src; p<end; ++p)
    ++count[(static_cast<unsigned>(*p)^SIGN_BIT)>>16];

  // Accumulate buckets from the top until at least dst_size values are covered.
  size_t sum=0;
  unsigned pos=TOTAL;
  while (sum < dst_size) sum += count[--pos];
  const size_t in_bucket=count[pos];
  const size_t needed=in_bucket-(sum-dst_size);   // always in 1..in_bucket

  int *q=dst;
  if (needed == in_bucket) {
    // Everything from the boundary bucket upwards is exactly the answer.
    for (const int *p=src; p<end; ++p) {
      const int v=*p;
      if (((static_cast<unsigned>(v)^SIGN_BIT)>>16) >= pos) *q++=v;
    }
    return;
  }

  // Pass 2: emit the higher buckets, histogram the boundary bucket's low halves.
  std::fill(count, count+TOTAL, 0u);
  for (const int *p=src; p<end; ++p) {
    const int v=*p;   // read before any write in case dst aliases src
    const unsigned key=static_cast<unsigned>(v)^SIGN_BIT;
    const unsigned bucket=key>>16;
    if (bucket > pos) *q++=v;
    else if (bucket == pos) ++count[key & 0xFFFFu];
  }

  // Regenerate the largest `needed` values of the boundary bucket.
  const unsigned base=(pos<<16)^SIGN_BIT;   // high half converted back to int bits
  size_t remaining=needed;
  unsigned low=TOTAL;
  while (remaining > 0) {
    --low;
    size_t take=count[low];
    if (take == 0) continue;
    if (take > remaining) take=remaining;
    std::fill_n(q, take, static_cast<int>(base | low));
    q += take;
    remaining -= take;
  }
}

// Compares std::nth_element with nth_count on N random ints, selecting the M
// largest, and repeats the comparison `count` times. Timings are printed as
// raw clock() tick differences. Returns silently when M is 0, M exceeds N, or
// count is 0.
//
// nth_count is timed twice per round, both times in place on the same
// buffer: once right after nth_element has partitioned it, and once after
// the buffer has been shuffled again.
//
// The output strings previously ended in the two characters "/n" rather than
// a newline escape; they now emit real line breaks.
void test(size_t N=1e8, size_t M=1e6, size_t count=1)
{
  if (M>N || M==0 || count==0) return;
  cout<<"N,M: "<<N<<" "<<M<<"\n";
  vector<int> a(N);
  size_t src_size=N, dst_size=M;
  int *src=&a[0];
  clock_t tb;

  while (count--) {
    generate(a.begin(),a.end(), rand);
    tb=clock();
    nth_element(a.begin(),a.begin()+dst_size, a.end(), greater<int>());
    tb=clock()-tb;
    cout<< "nth_element: " << tb <<"\n";

    tb=clock();
    nth_count(src,src_size,src,dst_size);
    tb=clock()-tb;
    cout<< "nth_count(after nth_element): "<< tb <<"\n";

    // NOTE(review): std::random_shuffle was removed in C++17; building in
    // that mode needs std::shuffle from <random> instead.
    random_shuffle(a.begin(),a.end());
    tb=clock();
    nth_count(src,src_size,src,dst_size);
    tb=clock()-tb;
    cout<< "nth_count(after random_shuffle): "<< tb <<"\n\n";
  }
}

// Runs the comparison on 2e8 inputs for M = 19e7, 18e7, ..., 1e7, with three
// rounds per M.
int main()
{
  const size_t total=200000000;   // N = 2e8
  const size_t step=10000000;     // M decreases by 1e7 each iteration
  for (size_t m=19*step; m>0; m-=step)
    test(total, m, 3);
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: