您的位置:首页 > 编程语言 > C语言/C++

KD-Tree 算法的 C++ 实现

2017-12-15 12:39 295 查看

KD-Tree 算法的 C++ 实现

阅读本文前,建议查阅相关资料,了解 KNN 算法与 KD 树。

基础知识

如图所示,假设一个点
a
目前的最近邻点为
b
,如果存在相对于
b
a
更近的点,那么这个点一定在以
a
为圆心,
ab
为半径的圆内。

现右侧的区域是未知的,如果
a
到分界线的距离
l
大于目前的最近距离
L
(圆半径),则没有必要在右侧的未知区域继续寻找最近邻点(如图一),反之,则要继续寻找(如图二)。

相应的,投射到多维空间,假如切分边界为第
i
维,切分点的值为
v
(标量),当前最近邻点为
y
(向量),如果目标点
x
(向量) 到切分边界的距离 |x[i] - v| 满足以下关系



时,需要在另一侧继续搜索。





通常地,一个机器学习算法分为
fit
predict
两个阶段,基于线性搜索的
KNN
是一种惰性算法,它将全部的计算任务放到了
predict
阶段,
predict
的时间复杂度为
O(n)
,KD 树之所以比线性搜索快,就是因为它将一部分任务放到了
fit
(建立 KD 树) 阶段,从而在搜索时可以略去大量不必搜索的结点(最优情况下时间复杂度为
O(1)
)。

上面说的比较简单,关于 KNN 算法和 KD 树的详细内容,请参考李航博士的《统计学习方法》。

代码

我们给出部分关键性的代码。

基本数据结构

训练集用一个一维数组
double *data
表示,它的长度为
n_samples * n_features
,标签集也用一个一维数组
double *labels
表示,它的长度为
n_samples


树的结点用以下数据结构表示

cpp

struct tree_node

{

size_t id;               // 表示训练集中的第 i 个数据

size_t split;            // 切分的维度

tree_node *left, *right; // 左、右子树

};


一个 KD 树的模型可用以下结构表示

cpp

struct tree_model

{

tree_node *root;        // 根结点

const double *datas;    // X

const double *labels;   // y

size_t n_samples;       // 样例数

size_t n_features;      // 每个样例的特征数

double p;               // 距离度量

};


求 K-近邻时需要用到大顶堆,我们直接用 C++ 的优先队列来表示,堆内现有的
n(n <= k)
个近邻点中,距离测试点最远的在堆顶

struct neighbor_heap_cmp {
bool operator()(const std::tuple<size_t, double> &i,
const std::tuple<size_t, double> &j) {
return std::get<1>(i) < std::get<1>(j);
}
};

typedef std::tuple<size_t, double> neighbor;
typedef std::priority_queue<neighbor,
std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap_;

neighbor_heap k_neighbor_heap_;


KD-Tree 类

我们用类
KDTree
表示一个 KD 树类,它应该具有的功能有
建树
搜索


//(简化的代码,完整的代码详见最后)
class KDTree {
public:
// 建树
KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p)
// 返回树
tree_node *GetRoot() { return root; }
// 求一个测试点的 k 邻
std::vector<std::tuple<size_t, double>> FindKNearests(const double *coor, size_t k);
private:
tree_node *root_;
}


寻找切分维和切分点

在建树之前,我们还要考虑如何选择切分维度和切分点。切分维度的选择有许多,一般的,可以取
dim = floor % n_features
,即当前树的层数对特征数取余,我们在这里使用
dim = argmax(nmax - nmin)
,即选取当前结点集合中极差最大的维度。

(这里是不完整的代码,有些工具函数的定义请详见完整源代码)
size_t KDTree::FindSplitDim(const std::vector<size_t> &points) {
if (points.size() == 1)
return 0;
size_t cur_best_dim = 0;
double cur_largest_spread = -1;
double cur_min_val;
double cur_max_val;
for (size_t dim = 0; dim < n_features; ++dim) {
cur_min_val = GetDimVal(points[0], dim);
cur_max_val = GetDimVal(points[0], dim);
for (const auto &id : points) {
if (GetDimVal(id, dim) > cur_max_val)
cur_max_val = GetDimVal(id, dim);
else if (GetDimVal(id, dim) < cur_min_val)
cur_min_val = GetDimVal(id, dim);
}

if (cur_max_val - cur_min_val > cur_largest_spread) {
cur_largest_spread = cur_max_val - cur_min_val;
cur_best_dim = dim;
}
}
return cur_best_dim;
}


选择完切分维
k
之后,我们需选取当前结点集合中的结点在第
k
维的值的中位数
x
作为切分点的值,除去该点之外的点,第
k
维的值小于等于
x
的,放入左子树,反之放入右子树。

在求中位数时,不要全排序,然后取中间的点,可以采用类似快排的方法,找到中位数时就停止排序,这里我们就不写算法了,直接用 C++ 的函数。

std::tuple<size_t, double> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) {
size_t len = points.size();
for (size_t i = 0; i < points.size(); ++i)
get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim));
std::nth_element(get_mid_buf_,
get_mid_buf_ + len / 2,
get_mid_buf_ + len,
[](const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) {
return std::get<1>(i) < std::get<1>(j);
});
return get_mid_buf_[len / 2];
}


建树

建树直接按照建立二叉树的方法即可

tree_node *KDTree::BuildTree(const std::vector<size_t> &points) {
size_t dim = FindSplitDim(points);
std::tuple<size_t, double> t = MidElement(points, dim);
size_t arg_mid_val = std::get<0>(t);
double mid_val = std::get<1>(t);

tree_node *node = Malloc(tree_node, 1);
node->left = nullptr;
node->right = nullptr;
node->id = arg_mid_val;
node->split = dim;
std::vector<size_t> left, right;
for (auto &i : points) {
if (i == arg_mid_val)
continue;
if (GetDimVal(i, dim) <= mid_val)
left.emplace_back(i);
else
right.emplace_back(i);
}
if (!left.empty())
node->left = BuildTree(left);
if (!right.empty())
node->right = BuildTree(right);
return node;
}


搜索 K-近邻的规则

一般书上所讲的都是搜索最近邻,但是我们这里是搜索 K-近邻,需要对书上的算法做少许的扩充。

搜索最近邻时,我们一般设置两个变量
cur_min_id
cur_min_dist
,如果当前搜索到的点到测试点的距离
l < cur_min_dist
时,我们将上述两个变量更新为新点的
id
dist


相应的,在搜索 K-近邻时,我们可以设置一个最多有
k
个元素的大顶堆,这样,在搜索时,当堆满时,只需比较当前搜索点的
dist
是否小于堆顶点的
dist
,如果小于,堆顶出堆,并将当前搜索点压入,反之,则不变;当堆未满时,直接将该搜索点压入。

搜索 K-近邻的算法

我们直接使用二叉树深度优先遍历的非递归算法(具体的描述详见《统计学习方法》第 43 页算法 3.3)。

std::vector<std::tuple<size_t, double>> KDTree::FindKNearests(const double *coor, size_t k) {
std::memset(visited_buf_, 0, sizeof(bool) * n_samples);
std::stack<tree_node *> paths;
tree_node *p = root;

while (p) {
HeapStackPush(paths, p, coor, k);
p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right;
}
while (!paths.empty()) {
p = paths.top();
paths.pop();

if (!p->left && !p->right)
continue;

if (k_neighbor_heap_.size() < k) {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if (p->right)
HeapStackPush(paths, p->right, coor, k);
} else {
double node_split_val = GetDimVal(p->id, p->split);
double coor_split_val = coor[p->split];
double heap_top_val = std::get<1>(k_neighbor_heap_.top());
if (coor_split_val > node_split_val) {
if (p->right)
HeapStackPush(paths, p->right, coor, k);
if ((coor_split_val - node_split_val) < heap_top_val && p->left)
HeapStackPush(paths, p->left, coor, k);
} else {
if (p->left)
HeapStackPush(paths, p->left, coor, k);
if ((node_split_val - coor_split_val) < heap_top_val && p->right)
HeapStackPush(paths, p->right, coor, k);
}
}
}
std::vector<std::tuple<size_t, double>> res;

while (!k_neighbor_heap_.empty()) {
res.emplace_back(k_neighbor_heap_.top());
k_neighbor_heap_.pop();
}
return res;
}


完整代码

详见 https://github.com/WiseDoge/libkdtree

完整代码中除了 KD-Tree 的代码外,还给出了测试代码和 Python 接口代码,以及一些调用第三方库来加速的手段。

原文地址

http://www.jianshu.com/p/80e41da2a397
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息