您的位置:首页 > 大数据 > 人工智能

LibSVM 3.12的源码分析Svm-train.c

2014-07-16 14:59 369 查看


注:本文非笔者原创,原文转载自:

http://jacoxu.com/?p=133

共涉及3个文件: Svm-train.c, Svm.cpp, Svm.h. 建议使用Source Insight软件对这3个文件建立工程. 方便代码阅读. 下面从Svm-train.c文件中的main()函数切入.

int main(int argc, <span

class="keyword">char **argv)

{

char input_file_name[1024]; //训练样本文件名

char model_file_name[1024]; //输出模型的文件名

const char *error_msg;

parse_command_line(argc, argv, input_file_name, model_file_name); //解析运行程序时,命令行输入的参数

read_problem(input_file_name); //读入训练样本,存入到struct
svm_problem prob结构体中

error_msg = svm_check_parameter(&prob,¶m); //检查训练样本数据格式是否正确

if(error_msg)

{

fprintf(stderr,“ERROR:
%s\n”,error_msg);

exit(1);

}

if(cross_validation)

{

do_cross_validation(); //根据设置进行交叉验证训练

}

else

{

model = svm_train(&prob,¶m); //根据问题数据(&prob)和参数(¶m)训练模型

if(svm_save_model(model_file_name,model))//保存模型到输出

文件中

{

fprintf(stderr, “can’t
save model to file %s\n”, model_file_name);

exit(1);

}

svm_free_and_destroy_model(&model); //释放模型结构空间

}

svm_destroy_param(¶m); //释放使用的其他结构空间

free(prob.y);

free(prob.x);

free(x_space);

free(line);

return 0;

}

下面分析一下main()函数中调用的主要函数程序, 命令行参数解析函数parse_command_line()代码及其注释如下:

void parse_command_line(int argc, <span

class="keyword">char **argv, char *input_file_name,char

*model_file_name)

{

int i;

void (*print_func)(const <span

class="keyword">char*) = NULL; //
default printing to stdout

//
default values

param.svm_type = C_SVC;

param.kernel_type = RBF;

param.degree = 3;

param.gamma = 0; //
1/num_features

param.coef0 = 0;

param.nu = 0.5;

param.cache_size = 100;

param.C = 1;

param.eps = 1e-3;

param.p = 0.1;

param.shrinking = 1;

param.probability = 0;

param.nr_weight = 0;

param.weight_label = NULL;

param.weight = NULL;

cross_validation = 0;

//
parse options

for(i=1;i<argc;i++) //argc中存放的是命令行程序运行时的参数

个数

{

if(argv[i][0]
!= ‘-’) break; <span

class="comment">//开头处是否为参数类型标识,若不是跳出循环

if(++i>=argc) //判断参数类型后是否有其他参数,如样本文件名

exit_with_help(); //如果没有则退出并打印帮助提示

switch(argv[i-1][1]) //根据参数标识,转换参数值为正确类型或相应设置

{

case ‘s’:

param.svm_type = atoi(argv[i]);

break;

case ‘t’:

param.kernel_type = atoi(argv[i]);

break;

case ‘d’:

param.degree = atoi(argv[i]);

break;

case ‘g’:

param.gamma = atof(argv[i]);

break;

case ‘r’:

param.coef0 = atof(argv[i]);

break;

case ‘n’:

param.nu = atof(argv[i]);

break;

case ‘m’:

param.cache_size = atof(argv[i]);

break;

case ‘c’:

param.C = atof(argv[i]);

break;

case ‘e’:

param.eps = atof(argv[i]);

break;

case ‘p’:

param.p = atof(argv[i]);

break;

case ‘h’:

param.shrinking = atoi(argv[i]);

break;

case ‘b’:

param.probability = atoi(argv[i]);

break;

case ‘q’:

print_func = &print_null;

i–;

break;

case ‘v’: //设置交叉验证的参数标识

cross_validation = 1;

nr_fold = atoi(argv[i]);

if(nr_fold
< 2)

{

fprintf(stderr,“n-fold
cross validation: n must >= 2\n”);

exit_with_help();

}

break;

case ‘w’:

++param.nr_weight;

param.weight_label = (int*)realloc(param.weight_label,<span

class="keyword">sizeof(int)*param.nr_weight);

param.weight = (double*)realloc(param.weight,<span

class="keyword">sizeof(double)*param.nr_weight);

param.weight_label[param.nr_weight-1] = atoi(&argv[i-1][2]);

param.weight[param.nr_weight-1] = atof(argv[i]);

break;

default:

fprintf(stderr,“Unknown
option: -%c\n”, argv[i-1][1]);

exit_with_help();

}

}

svm_set_print_string_function(print_func);

//
determine filenames

if(i>=argc)

exit_with_help();

strcpy(input_file_name, argv[i]); //将命令行中的训练文件名,赋值给main中的字符数组.

if(i<argc-1) //如果自定义了输出模型名,则赋值给变量,否则使用默认命名方式

构造文件名

strcpy(model_file_name,argv[i+1]);

else

{

char *p
= strrchr(argv[i],’/');

if(p==NULL)

p = argv[i];

else

++p;

sprintf(model_file_name,“%s.model”,p);

}

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息