KD-Tree 算法的 C++ 实现
2017-12-15 12:39
295 查看
KD-Tree 算法的 C++ 实现
阅读本文前,建议查阅相关资料,了解 KNN 算法与 KD 树。基础知识
如图所示,假设一个点a目前的最近邻点为
b,如果存在相对于
b离
a更近的点,那么这个点一定在以
a为圆心,
ab为半径的圆内。
现右侧的区域是未知的,如果
a到分界线的距离
l大于目前的最近距离
L(圆半径),则没有必要在右侧的未知区域继续寻找最近邻点(如图一),反之,则要继续寻找(如图二)。
相应的,投射到多维空间,假如切分边界为第
i维,切分点的值为
v(标量),当前最近邻点为
y(向量),如果目标点
x(向量) 到切分边界的距离 |x[i] - v| 满足以下关系
时,需要在另一侧继续搜索。
通常地,一个机器学习算法分为
fit和
predict两个阶段,基于线性搜索的
KNN是一种惰性算法,它将全部的计算任务放到了
predict阶段,
predict的时间复杂度为
O(n),KD 树之所以比线性搜索快,就是因为它将一部分任务放到了
fit(建立 KD 树) 阶段,从而在搜索时可以略去大量不必搜索的结点(最优情况下时间复杂度为
O(1))。
上面说的比较简单,关于 KNN 算法和 KD 树的详细内容,请参考李航博士的《统计学习方法》。
代码
我们给出部分关键性的代码。基本数据结构
训练集用一个一维数组double *data表示,它的长度为
n_samples * n_features,标签集也用一个一维数组
double *labels表示,它的长度为
n_samples。
树的结点用以下数据结构表示
cpp struct tree_node { size_t id; // 表示训练集中的第 i 个数据 size_t split; // 切分的维度 tree_node *left, *right; // 左、右子树 };
一个 KD 树的模型可用以下结构表示
cpp struct tree_model { tree_node *root; // 根结点 const double *datas; // X const double *labels; // y size_t n_samples; // 样例数 size_t n_features; // 每个样例的特征数 double p; // 距离度量 };
求 K-近邻时需要用到大顶堆,我们直接用 C++ 的优先队列来表示,堆内现有的
n(n <= k)个近邻点中,距离测试点最远的在堆顶
struct neighbor_heap_cmp { bool operator()(const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) { return std::get<1>(i) < std::get<1>(j); } }; typedef std::tuple<size_t, double> neighbor; typedef std::priority_queue<neighbor, std::vector<neighbor>, neighbor_heap_cmp> neighbor_heap_; neighbor_heap k_neighbor_heap_;
KD-Tree 类
我们用类KDTree表示一个 KD 树类,它应该具有的功能有
建树和
搜索。
//(简化的代码,完整的代码详见最后) class KDTree { public: // 建树 KDTree(const double *datas, const double *labels, size_t rows, size_t cols, double p) // 返回树 tree_node *GetRoot() { return root; } // 求一个测试点的 k 邻 std::vector<std::tuple<size_t, double>> FindKNearests(const double *coor, size_t k); private: tree_node *root_; }
寻找切分维和切分点
在建树之前,我们还要考虑如何选择切分维度和切分点。切分维度的选择有许多,一般的,可以取dim = floor % n_features,即当前树的层数对特征数取余,我们在这里使用
dim = argmax(nmax - nmin),即选取当前结点集合中极差最大的维度。
(这里是不完整的代码,有些工具函数的定义请详见完整源代码) size_t KDTree::FindSplitDim(const std::vector<size_t> &points) { if (points.size() == 1) return 0; size_t cur_best_dim = 0; double cur_largest_spread = -1; double cur_min_val; double cur_max_val; for (size_t dim = 0; dim < n_features; ++dim) { cur_min_val = GetDimVal(points[0], dim); cur_max_val = GetDimVal(points[0], dim); for (const auto &id : points) { if (GetDimVal(id, dim) > cur_max_val) cur_max_val = GetDimVal(id, dim); else if (GetDimVal(id, dim) < cur_min_val) cur_min_val = GetDimVal(id, dim); } if (cur_max_val - cur_min_val > cur_largest_spread) { cur_largest_spread = cur_max_val - cur_min_val; cur_best_dim = dim; } } return cur_best_dim; }
选择完切分维
k之后,我们需选取当前结点集合中的结点在第
k维的值的中位数
x作为切分点的值,除去该点之外的点,第
k维的值小于等于
x的,放入左子树,反之放入右子树。
在求中位数时,不要全排序,然后取中间的点,可以采用类似快排的方法,找到中位数时就停止排序,这里我们就不写算法了,直接用 C++ 的函数。
std::tuple<size_t, double> KDTree::MidElement(const std::vector<size_t> &points, size_t dim) { size_t len = points.size(); for (size_t i = 0; i < points.size(); ++i) get_mid_buf_[i] = std::make_tuple(points[i], GetDimVal(points[i], dim)); std::nth_element(get_mid_buf_, get_mid_buf_ + len / 2, get_mid_buf_ + len, [](const std::tuple<size_t, double> &i, const std::tuple<size_t, double> &j) { return std::get<1>(i) < std::get<1>(j); }); return get_mid_buf_[len / 2]; }
建树
建树直接按照建立二叉树的方法即可tree_node *KDTree::BuildTree(const std::vector<size_t> &points) { size_t dim = FindSplitDim(points); std::tuple<size_t, double> t = MidElement(points, dim); size_t arg_mid_val = std::get<0>(t); double mid_val = std::get<1>(t); tree_node *node = Malloc(tree_node, 1); node->left = nullptr; node->right = nullptr; node->id = arg_mid_val; node->split = dim; std::vector<size_t> left, right; for (auto &i : points) { if (i == arg_mid_val) continue; if (GetDimVal(i, dim) <= mid_val) left.emplace_back(i); else right.emplace_back(i); } if (!left.empty()) node->left = BuildTree(left); if (!right.empty()) node->right = BuildTree(right); return node; }
搜索 K-近邻的规则
一般书上所讲的都是搜索最近邻,但是我们这里是搜索 K-近邻,需要对书上的算法做少许的扩充。搜索最近邻时,我们一般设置两个变量
cur_min_id和
cur_min_dist,如果当前搜索到的点到测试点的距离
l < cur_min_dist时,我们将上述两个变量更新为新点的
id和
dist。
相应的,在搜索 K-近邻时,我们可以设置一个最多有
k个元素的大顶堆,这样,在搜索时,当堆满时,只需比较当前搜索点的
dist是否小于堆顶点的
dist,如果小于,堆顶出堆,并将当前搜索点压入,反之,则不变;当堆未满时,直接将该搜索点压入。
搜索 K-近邻的算法
我们直接使用二叉树深度优先遍历的非递归算法(具体的描述详见《统计学习方法》第 43 页算法 3.3)。std::vector<std::tuple<size_t, double>> KDTree::FindKNearests(const double *coor, size_t k) { std::memset(visited_buf_, 0, sizeof(bool) * n_samples); std::stack<tree_node *> paths; tree_node *p = root; while (p) { HeapStackPush(paths, p, coor, k); p = coor[p->split] <= GetDimVal(p->id, p->split) ? p = p->left : p = p->right; } while (!paths.empty()) { p = paths.top(); paths.pop(); if (!p->left && !p->right) continue; if (k_neighbor_heap_.size() < k) { if (p->left) HeapStackPush(paths, p->left, coor, k); if (p->right) HeapStackPush(paths, p->right, coor, k); } else { double node_split_val = GetDimVal(p->id, p->split); double coor_split_val = coor[p->split]; double heap_top_val = std::get<1>(k_neighbor_heap_.top()); if (coor_split_val > node_split_val) { if (p->right) HeapStackPush(paths, p->right, coor, k); if ((coor_split_val - node_split_val) < heap_top_val && p->left) HeapStackPush(paths, p->left, coor, k); } else { if (p->left) HeapStackPush(paths, p->left, coor, k); if ((node_split_val - coor_split_val) < heap_top_val && p->right) HeapStackPush(paths, p->right, coor, k); } } } std::vector<std::tuple<size_t, double>> res; while (!k_neighbor_heap_.empty()) { res.emplace_back(k_neighbor_heap_.top()); k_neighbor_heap_.pop(); } return res; }
完整代码
详见 https://github.com/WiseDoge/libkdtree完整代码中除了 KD-Tree 的代码外,还给出了测试代码和 Python 接口代码,以及一些调用第三方库来加速的手段。
原文地址
http://www.jianshu.com/p/80e41da2a397相关文章推荐
- KD-Tree 算法的 C++ 实现
- 反距离权重法生成DEM(利用KD-tree实现KNN算法)
- 20170219C++项目班02_02递归下降算法/解析器/Scanner实现
- 【Coursera】Algorithms, Part I 算法C++实现: Quick Union
- 【算法和数据结构】分治思想之二分查找(C++实现)
- 组合算法 C++高效实现 (二进制辅助法)
- 人狼羊白菜过河问题算法,C++代码实现
- SIFT特征点匹配中KD-tree与Ransac算法的使用
- C++ 实现k-means machine learning 算法 Computer Vision
- 简单插入算法的C++实现
- 二叉树(Binary Tree)相关算法的实现
- [教程] 卡尔曼滤波简介及其算法实现代码(C++/C/MATLAB)
- 数据结构与算法——不相交集类的C++实现
- 算法代码实现之堆排序,C/C++实现
- 古典密码算法的设计与实现(C++实现)
- C++实现&nbsp;贪心算法-区间覆盖问题
- 【LeetCode-面试算法经典-Java实现】【114-Flatten Binary Tree to Linked List(二叉树转单链表)】
- 图像锐化算法 C++ 实现
- KMP模式匹配算法 C++实现
- 算法-最长子序列和C/C++实现(三个复杂度)