您的位置:首页 > 其它

libsvm使用简介

2013-09-06 16:32 225 查看
libsvm是support vector machine的一种开源实现,采用了smo算法。源代码编写有独到之处,值得一睹。

常用结构

svm_node结构

定义了构成输入特征向量的元素,index为索引(= -1为最后一个元素),value为值,

public class svm_node implements java.io.Serializable
{
public int index;
public double value;
}


借鉴了稀疏矩阵的表示方法。对于一个输入向量,定义为svm_node构成的一维数组

svm_node[] pa = {pa0, pa1};


所有输入序列有一个二维数组表示

svm_node[][] datas = {pa, pb};


标记序列

就是一个double数组,对应于输入序列datas的每一维。

double[] labels = {1.0, -1.0};


svm_problem结构

定义了(X, Y)的训练样本结构

public class svm_problem implements java.io.Serializable
{
public int l;
public double[] y;
public svm_node[][] x;
}


其中l是样本数量。

svm_parameter结构

定义了训练时的重要参数

public class svm_parameter implements Cloneable,java.io.Serializable
{
/* svm_type */
public static final int C_SVC = 0;
public static final int NU_SVC = 1;
public static final int ONE_CLASS = 2;
public static final int EPSILON_SVR = 3;
public static final int NU_SVR = 4;

/* kernel_type */
public static final int LINEAR = 0;
public static final int POLY = 1;
public static final int RBF = 2;
public static final int SIGMOID = 3;
public static final int PRECOMPUTED = 4;

public int svm_type;
public int kernel_type;
public int degree;    // for poly
public double gamma;    // for poly/rbf/sigmoid
public double coef0;    // for poly/sigmoid

// these are for training only
public double cache_size; // in MB
public double eps;    // stopping criteria
public double C;    // for C_SVC, EPSILON_SVR and NU_SVR
public int nr_weight;        // for C_SVC
public int[] weight_label;    // for C_SVC
public double[] weight;        // for C_SVC
public double nu;    // for NU_SVC, ONE_CLASS, and NU_SVR
public double p;    // for EPSILON_SVR
public int shrinking;    // use the shrinking heuristics
public int probability; // do probability estimates

public Object clone()
{
try
{
return super.clone();
} catch (CloneNotSupportedException e)
{
return null;
}
}

}


主要分为两大类参数:分类器的核函数性质和训练算法SMO的一些参数,包括精度啊等等

训练

通过调用svm.svm_train()训练模型

public static svm_model svm_train(svm_problem prob, svm_parameter param)


返回svm_model类对象表示训练得到的分类器

预测

通过svm.svm_predict()利用分类器进行预测

public static double svm_predict(svm_model model, svm_node[] x)


返回类别标记

实例代码如下,输入点pa = (10.0 10.0) ya = 1.0 pb = (-10.0, -10.0) yb = -1.0

测试点 (-0.1, 0)

import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;

public class SvmTest {
public static void main(String[] args) {

svm_node pa0 = new svm_node();
pa0.index = 0;
pa0.value = 10.0;

svm_node pa1 = new svm_node();
pa1.index = -1;
pa1.value = 10.0;

svm_node pb0 = new svm_node();
pb0.index = 0;
pb0.value = -10.0;

svm_node pb1 = new svm_node();
pb1.index = -1;
pb1.value = -10.0;

svm_node[] pa = {pa0, pa1};
svm_node[] pb = {pb0, pb1};

svm_node[][] datas = {pa, pb};

double[] labels = {1.0, -1.0};

svm_problem problem = new svm_problem();
problem.l = 2;
problem.x = datas;
problem.y = labels;

svm_parameter param = new svm_parameter();
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.LINEAR;
param.cache_size = 100;
param.eps = 0.00001;
param.C = 1;

System.out.println(svm.svm_check_parameter(problem, param));
svm_model model = svm.svm_train(problem, param);

svm_node pc0 = new svm_node();
pc0.index = 0;
pc0.value = -0.1;
svm_node pc1 = new svm_node();
pc1.index = -1;
pc1.value = 0;

svm_node[] pc = {pc0, pc1};

System.out.println(svm.svm_predict(model, pc));
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: