LDA理解以及源码分析(二)
2015-12-09 17:31
387 查看
LDA系列的讲解分多个博文给出,主要大纲如下:
LDA相关的基础知识
什么是共轭
multinomial分布
Dirichlet分布
LDA in text
LAD的概率图模型
LDA的参数推导
伪代码
GibbsLDA++-0.2源码分析
Python实现GibbsLDA
参考资料
工具包里docs文件夹里有说明文件GibbsLDA++Manual.pdf,按照要求编译就可以使用,很方便。(具体使用方法后面给出)
代码在文件夹src中,主要有这么几个类:dataset,model, strtokenizer,utils以及lda.cpp, constants.h文件。
dataset
utils
strtokenizer
model类是最重要的类,实现核心代码
下面详细分析estimate()函数和sampling()函数
LDA相关的基础知识
什么是共轭
multinomial分布
Dirichlet分布
LDA in text
LAD的概率图模型
LDA的参数推导
伪代码
GibbsLDA++-0.2源码分析
Python实现GibbsLDA
参考资料
GibbsLDA++-0.2源码分析
GibbsLDA++-0.2工具包下载地址为:下载工具包里docs文件夹里有说明文件GibbsLDA++Manual.pdf,按照要求编译就可以使用,很方便。(具体使用方法后面给出)
代码在文件夹src中,主要有这么几个类:dataset,model, strtokenizer,utils以及lda.cpp, constants.h文件。
dataset
//两个全局变量,分别存储word和id的对应。 typedef map<string, int> mapword2id; typedef map<int, string> mapid2word; //类document class document { public: //保存每个word对应的id int * words; string rawstr; //文章的words总数 int length; document() {} document(int length) {} document(int length, int * words) {} document(int length, int * words, string rawstr) {} document(vector<int> & doc) {} document(vector<int> & doc, string rawstr) {} ~document() {} }; class dataset { public: document ** docs; document ** _docs; // used only for inference map<int, int> _id2id; // also used only for inference int M; // documents总数 int V; // words总数 dataset() {} dataset(int M) {} ~dataset() {} void deallocate() {} void add_doc(document * doc, int idx) {} void _add_doc(document * doc, int idx) {} //根据pword2id写wordmap,文件每行都是“word id”的格式 static int write_wordmap(string wordmapfile, mapword2id * pword2id); //读wordmap中的内容,存储到pword2id中 static int read_wordmap(string wordmapfile, mapword2id * pword2id); static int read_wordmap(string wordmapfile, mapid2word * pid2word); int read_trndata(string dfile, string wordmapfile);//读训练文件 int read_newdata(string dfile, string wordmapfile);//读推断的新文件 int read_newdata_withrawstrs(string dfile, string wordmapfile); };
utils
class utils { public: // 解析命令行参数 static int parse_args(int argc, char ** argv, model * pmodel); // 读<model_name>.others文件并解析模型参数 static int read_and_parse(string filename, model * model); // 为当前的迭代生成模型名字,命令行有个参数会指定什么时候需要保存模型,iter=-1时最后才保存 static string generate_model_name(int iter); // 排序 static void sort(vector<double> & probs, vector<int> & words); static void quicksort(vector<pair<int, double> > & vect, int left, int right); };
strtokenizer
class strtokenizer { protected: vector<string> tokens; //存储分后的词 int idx;//tokens的索引 public: strtokenizer(string str, string seperators = " ");//对str按照seperators分词,存在tokens中 void parse(string str, string seperators);//同上 int count_tokens(); //返回tokens的大小 string next_token(); //返回idx当前索引的token值 void start_scan(); //idx = 0 string token(int i); // 返回tokens[i]的值 };
model类是最重要的类,实现核心代码
class model { public: // fixed options string wordmapfile; // file that contains word map [string -> integer id] string trainlogfile; // training log file string tassign_suffix; // suffix for topic assignment file string theta_suffix; // suffix for theta file string phi_suffix; // suffix for phi file string others_suffix; // suffix for file containing other parameters string twords_suffix; // suffix for file containing words-per-topics string dir; // model directory string dfile; // data file string model_name; // model name int model_status; // model status:本代码提供est,estc和inf三种模式 dataset * ptrndata; // 指针,指向训练文档集 dataset * pnewdata; // 指针,指向推断的新文档集 mapid2word id2word; // word map [int => string] /*下面是模型参数和变量*/ int M; // documents数量 int V; // words数量(不重复) int K; // topics数量 double alpha, beta; // LDA超参数 int niters; // Gibbs采样迭代的次数 int liter; // 需要将模型保存为文件的迭代次数 int savestep; // saving period int twords; // 输出每个topic的前twords个词 int withrawstrs; /*下面是LDA核心算法中的变量*/ double * p; // 采样的临时变量 int ** z; // M x doc.size(), 文档中words的topic分布 int ** nw; // V x K,nw[i][j]词i在主题j上出现的次数 int ** nd; // M x K,nd[i][j]: 文章i中属于主题j的词的个数 int * nwsum; // K, nwsum[j]: 属于主题j的词的数量 int * ndsum; // M, ndsum[i]: 文章i中词的数量 double ** theta; // M x K, document-topic分布 double ** phi; // K x V, topic-word分布 /*下面的变量只有在推断时用到*/ int inf_liter; int newM; int newV; int ** newz; int ** newnw; int ** newnd; int * newnwsum; int * newndsum; double ** newtheta; double ** newphi; // -------------------------------------- model() { ~model(); // 对变量设置初值,构造函数中调用 void set_default_values(); // 解析命令行得到LDA参数 int parse_args(int argc, char ** argv); // 初始化模型,调用后面的init()设置初值 int init(int argc, char ** argv); // 加载已有的LDA模型继续训练或者推断 int load_model(string model_name); // 保存LDA模型文件 int save_model(string model_name); int save_model_tassign(string filename); int save_model_theta(string filename); int save_model_phi(string filename); int save_model_others(string filename); int save_model_twords(string filename); // 保存inf的结果 int save_inf_model(string model_name); int save_inf_model_tassign(string filename); int save_inf_model_newtheta(string filename); int save_inf_model_newphi(string filename); int save_inf_model_others(string filename); int save_inf_model_twords(string filename); // 训练模型前的初始化,init()函数中会调用 int init_est(); int init_estc(); // LDA核心算法!!!下面详细介绍 void estimate(); int sampling(int m, int n); //计算迭代之后的theta和phi void compute_theta(); void compute_phi(); // 推断初始化,在init()中调用 int init_inf(); // 在已知LDA模型中,对未知的新文档进行推断 void inference(); int inf_sampling(int m, int n); void compute_newtheta(); void compute_newphi(); };
下面详细分析estimate()函数和sampling()函数
void model::estimate() { if (twords > 0) {//如果要求输出主题下的词,则读wordmap文件排序后输出 dataset::read_wordmap(dir + wordmapfile, &id2word); } printf("Sampling %d iterations!\n", niters); int last_iter = liter;//记录迭代次数 for (liter = last_iter + 1; liter <= niters + last_iter; liter++) { printf("Iteration %d ...\n", liter); // 对所有的z[m] 采样一个主题 for (int m = 0; m < M; m++) { for (int n = 0; n < ptrndata->docs[m]->length; n++) { // sampling根据Gibbs采样公式p(z_i|z_-i, w)采样 int topic = sampling(m, n); z[m] = topic; } } //判断是否需要保存这次迭代的模型 if (savestep > 0) { if (liter % savestep == 0) { printf("Saving the model at iteration %d ...\n", liter); compute_theta(); compute_phi(); save_model(utils::generate_model_name(liter)); } } } //迭代结束,根据theta和phi的公式计算值,并且保存模型文件 printf("Gibbs sampling completed!\n"); printf("Saving the final model!\n"); compute_theta(); compute_phi(); liter--; save_model(utils::generate_model_name(-1)); }
int model::sampling(int m, int n) { //采样算法,排除当前词,根据其他词的主题分布计算当前词的分布 int topic = z[m] ;//获得当前主题 int w = ptrndata->docs[m]->words ; //获得当前词的id //首先,排除当前词,即涉及的统计变量个数减一 nw[w][topic] -= 1; nd[m][topic] -= 1; nwsum[topic] -= 1; ndsum[m] -= 1; //临时计算变量 double Vbeta = V * beta; double Kalpha = K * alpha; /*通过累加来采样*/ //计算当前词在每个主题下的概率 for (int k = 0; k < K; k++) { p[k] = (nw[w][k] + beta) / (nwsum[k] + Vbeta) * (nd[m][k] + alpha) / (ndsum[m] + Kalpha); } // 概率和累加 for (int k = 1; k < K; k++) { p[k] += p[k - 1]; } //随机生成小数 double u = ((double)random() / RAND_MAX) * p[K - 1]; //根据随机生成的值,采样主题编号 for (topic = 0; topic < K; topic++) { if (p[topic] > u) { break; } } //新的topic对应的参数增加1 nw[w][topic] += 1; nd[m][topic] += 1; nwsum[topic] += 1; ndsum[m] += 1; return topic; }
相关文章推荐
- 解决Eclipse官网下的自带Eclipse编辑器不能自动代码提示的问题。
- android中树形json解析为对象,并通过dialog显示,多级列表
- 带有控制信息的单链表
- 【个人重构】数据库设计(1)
- java获取当前ip
- pip装了一个包,但是python里Import的时候找不到怎么办?
- Zookeeper基本原理
- Spark Streaming和Kafka整合开发指南(二)
- oracle更改字符集AL32UTF8为ZHS16GBK
- Android NDK: 怎么减少APK大小
- new/delete 和malloc/free 的区别
- linux 和 window 的EOF
- 在没有导航控制器的情况下,如何实现页面的跳转
- Java设置session超时(失效)的三种方式
- 网络编程之HTTP
- UICollectionView可移动item
- mybitis学习笔记
- 深入理解C#(一)
- Kylin 大数据时代的OLAP利器
- OpenCV获取与设置像素点的值的几个方法