
Learning the LDA Model (Part 3): Detours Along the Way

2012-03-15 21:47
To apply the LDA algorithm to text clustering, I really racked my brains. Besides studying the probability theory, stochastic processes, and advanced calculus that make my head spin, I also went online looking for already-implemented source code.

The first thing to give me a glimmer of hope was Mallet. I studied it for about a week and finally decided to give up, because the examples the Mallet authors provide are simply too few.

So I came back to this piece of source code I had found online:

/*
* (C) Copyright 2005, Gregor Heinrich (gregor :: arbylon : net) (This file is
* part of the org.knowceans experimental software packages.)
*/
/*
* LdaGibbsSampler is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License as published by the Free
* Software Foundation; either version 2 of the License, or (at your option) any
* later version.
*/
/*
* LdaGibbsSampler is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*/
/*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 59 Temple
* Place, Suite 330, Boston, MA 02111-1307 USA
*/

/*
* Created on Mar 6, 2005
*/
package com.xh.lda;

import java.text.DecimalFormat;
import java.text.NumberFormat;

/**
* Gibbs sampler for estimating the best assignments of topics for words and
* documents in a corpus. The algorithm is introduced in Tom Griffiths' paper
* "Gibbs sampling in the generative model of Latent Dirichlet Allocation"
* (2002).
*
* @author heinrich
*/
public class LdaGibbsSampler {

/**
* document data (term lists)
*/
int[][] documents;

/**
* vocabulary size
*/
int V;

/**
* number of topics
*/
int K;

/**
* Dirichlet parameter (document--topic associations)
*/
double alpha;

/**
* Dirichlet parameter (topic--term associations)
*/
double beta;

/**
* topic assignments for each word.
*/
int z[][];

/**
* cwt[i][j] number of instances of word i (term?) assigned to topic j.
*/
int[][] nw;

/**
* na[i][j] number of words in document i assigned to topic j.
*/
int[][] nd;

/**
* nwsum[j] total number of words assigned to topic j.
*/
int[] nwsum;

/**
* nasum[i] total number of words in document i.
*/
int[] ndsum;

/**
* cumulative statistics of theta
*/
double[][] thetasum;

/**
* cumulative statistics of phi
*/
double[][] phisum;

/**
* size of statistics
*/
int numstats;

/**
* sampling lag (?)
*/
private static int THIN_INTERVAL = 20;

/**
* burn-in period
*/
private static int BURN_IN = 100;

/**
* max iterations
*/
private static int ITERATIONS = 1000;

/**
* sample lag (if -1 only one sample taken)
*/
private static int SAMPLE_LAG;

private static int dispcol = 0;

/**
* Initialise the Gibbs sampler with data.
*
* @param documents
*            document data (term lists)
* @param V
*            vocabulary size
*/
public LdaGibbsSampler(int[][] documents, int V) {

this.documents = documents;
this.V = V;
}

/**
* Initialisation: Must start with an assignment of observations to topics.
* Many alternatives are possible; I chose to perform random assignments
* with equal probabilities.
*
* @param K
*            number of topics
*/
public void initialState(int K) {
int i;

int M = documents.length;

// initialise count variables.
nw = new int[V][K];
nd = new int[M][K];
nwsum = new int[K];
ndsum = new int[M];

// The z_i are initialised to values in [0,K-1] to determine the
// initial state of the Markov chain.

z = new int[M][];
for (int m = 0; m < M; m++) {
int N = documents[m].length;
z[m] = new int[N];
for (int n = 0; n < N; n++) {
int topic = (int) (Math.random() * K);
z[m][n] = topic;
// number of instances of word i assigned to topic j
nw[documents[m][n]][topic]++;
// number of words in document i assigned to topic j.
nd[m][topic]++;
// total number of words assigned to topic j.
nwsum[topic]++;
}
// total number of words in document i
ndsum[m] = N;
}
}

/**
* Main method: Select initial state. Repeat a large number of times: 1.
* Select an element. 2. Update conditional on other elements. If
* appropriate, output summary for each run.
*
* @param K
*            number of topics
* @param alpha
*            symmetric prior parameter on document--topic associations
* @param beta
*            symmetric prior parameter on topic--term associations
*/
public void gibbs(int K, double alpha, double beta) {
this.K = K;
this.alpha = alpha;
this.beta = beta;

// init sampler statistics
if (SAMPLE_LAG > 0) {
thetasum = new double[documents.length][K];
phisum = new double[K][V];
numstats = 0;
}

// initial state of the Markov chain:
initialState(K);

System.out.println("Sampling " + ITERATIONS
+ " iterations with burn-in of " + BURN_IN + " (B/S="
+ THIN_INTERVAL + ").");

for (int i = 0; i < ITERATIONS; i++) {

// for all z_i
for (int m = 0; m < z.length; m++) {
for (int n = 0; n < z[m].length; n++) {

// (z_i = z[m][n])
// sample from p(z_i|z_-i, w)
int topic = sampleFullConditional(m, n);
z[m][n] = topic;
}
}

if ((i < BURN_IN) && (i % THIN_INTERVAL == 0)) {
//                System.out.print("B");
dispcol++;
}
// display progress
if ((i > BURN_IN) && (i % THIN_INTERVAL == 0)) {
//                System.out.print("S");
dispcol++;
}
// get statistics after burn-in
if ((i > BURN_IN) && (SAMPLE_LAG > 0) && (i % SAMPLE_LAG == 0)) {
updateParams();
//                System.out.print("|");
if (i % THIN_INTERVAL != 0)
dispcol++;
}
if (dispcol >= 100) {
//                System.out.println();
dispcol = 0;
}
}
}

/**
* Sample a topic z_i from the full conditional distribution: p(z_i = j |
* z_-i, w) = (n_-i,j(w_i) + beta)/(n_-i,j(.) + W * beta) * (n_-i,j(d_i) +
* alpha)/(n_-i,.(d_i) + K * alpha)
*
* @param m
*            document
* @param n
*            word
*/
private int sampleFullConditional(int m, int n) {

// remove z_i from the count variables
int topic = z[m][n];
nw[documents[m][n]][topic]--;
nd[m][topic]--;
nwsum[topic]--;
ndsum[m]--;

// do multinomial sampling via cumulative method:
double[] p = new double[K];
for (int k = 0; k < K; k++) {
p[k] = (nw[documents[m][n]][k] + beta) / (nwsum[k] + V * beta)
* (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
}
// cumulate multinomial parameters
for (int k = 1; k < p.length; k++) {
p[k] += p[k - 1];
}
// scaled sample because of unnormalised p[]
double u = Math.random() * p[K - 1];
for (topic = 0; topic < p.length; topic++) {
if (u < p[topic])
break;
}

// add newly estimated z_i to count variables
nw[documents[m][n]][topic]++;
nd[m][topic]++;
nwsum[topic]++;
ndsum[m]++;

return topic;
}

/**
* Add to the statistics the values of theta and phi for the current state.
*/
private void updateParams() {
for (int m = 0; m < documents.length; m++) {
for (int k = 0; k < K; k++) {
thetasum[m][k] += (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
}
}
for (int k = 0; k < K; k++) {
for (int w = 0; w < V; w++) {
phisum[k][w] += (nw[w][k] + beta) / (nwsum[k] + V * beta);
}
}
numstats++;
}

/**
* Retrieve estimated document--topic associations. If sample lag > 0 then
* the mean value of all sampled statistics for theta[][] is taken.
*
* @return theta multinomial mixture of document topics (M x K)
*/
public double[][] getTheta() {
double[][] theta = new double[documents.length][K];

if (SAMPLE_LAG > 0) {
for (int m = 0; m < documents.length; m++) {
for (int k = 0; k < K; k++) {
theta[m][k] = thetasum[m][k] / numstats;
}
}

} else {
for (int m = 0; m < documents.length; m++) {
for (int k = 0; k < K; k++) {
theta[m][k] = (nd[m][k] + alpha) / (ndsum[m] + K * alpha);
}
}
}

return theta;
}

/**
* Retrieve estimated topic--word associations. If sample lag > 0 then the
* mean value of all sampled statistics for phi[][] is taken.
*
* @return phi multinomial mixture of topic words (K x V)
*/
public double[][] getPhi() {
System.out.println("K is:"+K+",V is:"+V);
double[][] phi = new double[K][V];
if (SAMPLE_LAG > 0) {
for (int k = 0; k < K; k++) {
for (int w = 0; w < V; w++) {
phi[k][w] = phisum[k][w] / numstats;
}
}
} else {
for (int k = 0; k < K; k++) {
for (int w = 0; w < V; w++) {
phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);
}
}
}
return phi;
}
/**
* Configure the gibbs sampler
*
* @param iterations
*            number of total iterations
* @param burnIn
*            number of burn-in iterations
* @param thinInterval
*            update statistics interval
* @param sampleLag
*            sample interval (-1 for just one sample at the end)
*/
public void configure(int iterations, int burnIn, int thinInterval,
int sampleLag) {
ITERATIONS = iterations;
BURN_IN = burnIn;
THIN_INTERVAL = thinInterval;
SAMPLE_LAG = sampleLag;
}

/**
* Driver with example data.
*
* @param args
*/
public static void main(String[] args) {

// words in documents
int[][] documents = {
{1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 6},
{2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2},
{1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 0},
{5, 6, 6, 2, 3, 3, 6, 5, 6, 2, 2, 6, 5, 6, 6, 6, 0},
{2, 2, 4, 4, 4, 4, 1, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 0},
{5, 4, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2}
};

// vocabulary
int V = 7;
int M = documents.length;
// # topics
int K = 2;
// good values alpha = 2, beta = .5
double alpha = 2;
double beta = .5;

System.out.println("Latent Dirichlet Allocation using Gibbs Sampling.");

LdaGibbsSampler lda = new LdaGibbsSampler(documents, V);
lda.configure(10000, 2000, 100, 10);
lda.gibbs(K, alpha, beta); // run the Gibbs sampler

double[][] theta = lda.getTheta(); // theta is the estimated document--topic distribution we are after
double[][] phi = lda.getPhi();

System.out.println();
System.out.println();
System.out.println("Document--Topic Associations, Theta[d][k] (alpha="
+ alpha + ")");
System.out.print("d\\k\t");
for (int m = 0; m < theta[0].length; m++) {
System.out.print("   " + m % 10 + "    ");
}
System.out.println();
for (int m = 0; m < theta.length; m++) {
System.out.print(m + "\t");
for (int k = 0; k < theta[m].length; k++) {
System.out.print(theta[m][k] + " ");
//                System.out.print(shadeDouble(theta[m][k], 1) + " ");
}
System.out.println();
}
System.out.println();
System.out.println("Topic--Term Associations, Phi[k][w] (beta=" + beta
+ ")");

System.out.print("k\\w\t");
for (int w = 0; w < phi[0].length; w++) {
System.out.print("   " + w % 10 + "    ");
}
System.out.println();
for (int k = 0; k < phi.length; k++) {
System.out.print(k + "\t");
for (int w = 0; w < phi[k].length; w++) {
System.out.print(phi[k][w] + " ");
//                System.out.print(shadeDouble(phi[k][w], 1) + " ");
}
System.out.println();
}
}

}
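Since what I am after is text clustering, the theta matrix is the useful output: theta[m][k] is the probability that document m belongs to topic k. One simple way to turn it into cluster labels (my own addition, not part of Heinrich's class) is to hard-assign each document to its most probable topic:

/**
 * My own helper, not part of the original class: treat topics as
 * clusters and hard-assign each document to its most probable one.
 */
static int[] clusterByTheta(double[][] theta) {
    int[] cluster = new int[theta.length];
    for (int m = 0; m < theta.length; m++) {
        int best = 0;
        for (int k = 1; k < theta[m].length; k++) {
            if (theta[m][k] > theta[m][best]) {
                best = k; // topic k is now the most probable for document m
            }
        }
        cluster[m] = best;
    }
    return cluster;
}

Calling clusterByTheta(lda.getTheta()) after gibbs() has finished gives one cluster label per document.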

I still have not fully worked out the math in this code, but for now being able to use it is enough.
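That said, here is at least my reading of the key step (my own transcription, using the code's V for vocabulary size where the javadoc writes W): sampleFullConditional draws each assignment z_i from the full conditional

$$p(z_i = k \mid \mathbf{z}_{-i}, \mathbf{w}) \;\propto\; \frac{n^{(w_i)}_{-i,k} + \beta}{n^{(\cdot)}_{-i,k} + V\beta} \cdot \frac{n^{(d_i)}_{-i,k} + \alpha}{n^{(d_i)}_{-i,\cdot} + K\alpha}$$

where the first factor corresponds to (nw[w_i][k] + beta) / (nwsum[k] + V * beta) after the counts for position i have been removed, and the second factor is (nd[m][k] + alpha) / (ndsum[m] + K * alpha) built from the per-document counts.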

Back to the parameters in main():

// vocabulary

int V = 7; // the total number of distinct words across all documents is 7

int M = documents.length; // the total number of documents

// # topics

int K = 2; // when used for clustering, this is the number of clusters, i.e. the number of topics

// good values alpha = 2, beta = .5

The two values below are the hyperparameters of the LDA model; you can leave them at these defaults for now.

double alpha = 2;

double beta = .5;

My approach is: segment the text into words, count the words, and then assign each word a numeric ID. That way the documents can be converted into the documents matrix!
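A minimal sketch of that preprocessing step (my own code, not from the original project; it assumes the documents have already been segmented into String arrays, and it assigns word IDs in order of first appearance):

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class CorpusBuilder {

    /**
     * Assign each distinct word a numeric ID and convert the tokenized
     * documents into the int[][] matrix that LdaGibbsSampler expects.
     */
    public static int[][] toDocumentMatrix(List<String[]> tokenized,
            Map<String, Integer> vocab) {
        int[][] documents = new int[tokenized.size()][];
        for (int m = 0; m < tokenized.size(); m++) {
            String[] words = tokenized.get(m);
            documents[m] = new int[words.length];
            for (int n = 0; n < words.length; n++) {
                // look up the word's ID, assigning a fresh one on first sight
                Integer id = vocab.get(words[n]);
                if (id == null) {
                    id = vocab.size();
                    vocab.put(words[n], id);
                }
                documents[m][n] = id;
            }
        }
        return documents;
    }

    public static void main(String[] args) {
        List<String[]> tokenized = Arrays.asList(
                new String[] {"lda", "topic", "model"},
                new String[] {"topic", "cluster", "lda"});
        Map<String, Integer> vocab = new HashMap<String, Integer>();
        int[][] documents = toDocumentMatrix(tokenized, vocab);
        // vocab.size() is the V to pass to new LdaGibbsSampler(documents, V)
        System.out.println("M = " + documents.length + ", V = " + vocab.size());
    }
}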