您的位置:首页 > 编程语言 > Java开发

《统计学习方法》之感知机对偶形式 Java 实现代码

2016-12-19 22:08 681 查看
理论部分请参照李航博士的《统计学习方法》一书（第 2 章，算法 2.2：感知机学习算法的对偶形式）。
Point 类表示需要分类的样本点：
package com.czb.ganzhiji;

/**
 * A labelled sample: a feature vector {@code x} plus its class label {@code y}.
 */
public class Point {

    /** Feature vector; stored by reference, never copied. */
    double[] x;

    /** Class label (presumably +1 / -1 when used for perceptron training). */
    double y;

    /** Creates a sample whose feature vector is a two-element zero vector and whose label is 0. */
    public Point() {
        this(new double[2], 0.0);
    }

    /**
     * Creates a sample.
     *
     * @param x feature vector (kept by reference)
     * @param y class label
     */
    public Point(double[] x, double y) {
        this.x = x;
        this.y = y;
    }
}

/**
 * Perceptron implementation in its dual form.
 */
package com.czb.ganzhiji;

import java.util.ArrayList;
import java.util.Arrays;

/**
 * Perceptron trained in its dual form.
 *
 * <p>The decision function is {@code sign(sum_j a_j * y_j * (x_j . x) + b)}. Training only
 * updates the per-sample coefficients {@code a} and the bias {@code b}; the primal weight
 * vector {@code w} is recovered afterwards as {@code sum_j a_j * y_j * x_j}.
 */
public class Ganzhiji2 {
    /** Primal weight vector, recovered from {@code a} after training. */
    private double[] w;
    /** Bias term. */
    private double b = 0;
    /** Dual coefficients: a[i] grows by eta each time sample i triggers an update. */
    private double[] a;
    /** Learning rate. */
    private double eta;
    /** Training samples; labels presumably +1 / -1 (main uses those values). */
    ArrayList<Point> arrayList;

    /**
     * Creates a dual-form perceptron over the given samples.
     *
     * @param arrayList training samples; must be non-empty (the feature dimension is taken
     *     from the first sample)
     * @param eta learning rate
     */
    public Ganzhiji2(ArrayList<Point> arrayList, double eta) {
        this.arrayList = arrayList;
        w = new double[arrayList.get(0).x.length];
        a = new double[arrayList.size()];
        this.eta = eta;
    }

    /**
     * Creates a dual-form perceptron with learning rate 1.
     *
     * @param arrayList training samples; must be non-empty
     */
    public Ganzhiji2(ArrayList<Point> arrayList) {
        this(arrayList, 1);
    }

    /** Returns the inner product of two vectors; x2 must be at least as long as x1. */
    private double f(double[] x1, double[] x2) {
        double sum = 0;
        for (int i = 0; i < x1.length; i++) {
            sum += x1[i] * x2[i];
        }
        return sum;
    }

    /**
     * Builds the Gram matrix {@code G[i][j] = x_i . x_j}. The dual form only needs these
     * inner products, so computing them once avoids redoing every dot product on each check.
     */
    private double[][] gram() {
        int n = arrayList.size();
        double[][] gram = new double[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                double dot = f(arrayList.get(i).x, arrayList.get(j).x);
                gram[i][j] = dot;
                gram[j][i] = dot;
            }
        }
        return gram;
    }

    /**
     * Returns the functional margin {@code y_m * (sum_i a_i * y_i * G[i][m] + b)} of sample m.
     * A value {@code <= 0} means sample m is misclassified or lies on the hyperplane.
     */
    private double g(double[][] gram, int m) {
        double sum = 0;
        for (int i = 0; i < arrayList.size(); i++) {
            sum += a[i] * arrayList.get(i).y * gram[i][m];
        }
        return arrayList.get(m).y * (sum + b);
    }

    /**
     * Applies the dual-form update for misclassified sample m ({@code a_m += eta},
     * {@code b += eta * y_m}) and prints the current coefficients and bias as a trace line.
     */
    private void h(int m) {
        a[m] += eta;
        // The bias step must be scaled by eta like a_m; otherwise b is inconsistent with w
        // whenever eta != 1.
        b += eta * arrayList.get(m).y;
        System.out.println(trace());
    }

    /** Formats every dual coefficient followed by the bias, separated by single spaces. */
    private String trace() {
        StringBuilder sb = new StringBuilder();
        for (double ai : a) {
            sb.append(ai).append(' ');
        }
        sb.append(b);
        return sb.toString();
    }

    /**
     * Trains until every sample has a positive functional margin, then recovers and prints
     * the primal weights {@code w} followed by the bias {@code b}.
     *
     * <p>NOTE: training only terminates when the samples are linearly separable.
     */
    private void classify() {
        double[][] gram = gram();
        train(gram);
        computeWeights();
        System.out.println(Arrays.toString(w));
        System.out.println(b);
    }

    /**
     * Scans the samples from the first one and updates on the first misclassified sample,
     * restarting the scan after each update, until a full pass makes no update.
     */
    private void train(double[][] gram) {
        boolean updated = true;
        while (updated) {
            updated = false;
            for (int i = 0; i < arrayList.size(); i++) {
                if (g(gram, i) <= 0) {
                    h(i);
                    updated = true;
                    break;
                }
            }
        }
    }

    /**
     * Recovers {@code w = sum_i a_i * y_i * x_i}. w is cleared first so that repeated
     * calls do not accumulate.
     */
    private void computeWeights() {
        Arrays.fill(w, 0);
        for (int i = 0; i < arrayList.size(); i++) {
            Point p = arrayList.get(i);
            double coef = a[i] * p.y;
            for (int j = 0; j < w.length; j++) {
                w[j] += p.x[j] * coef;
            }
        }
    }

    /** Runs the textbook example; expected result is w = [1.0, 1.0], b = -3.0. */
    public static void main(String[] args) {
        Point point1 = new Point(new double[]{3, 3}, 1);
        Point point2 = new Point(new double[]{4, 3}, 1);
        Point point3 = new Point(new double[]{1, 1}, -1);

        ArrayList<Point> arrayList = new ArrayList<>();
        arrayList.add(point1);
        arrayList.add(point2);
        arrayList.add(point3);

        Ganzhiji2 ganzhiji2 = new Ganzhiji2(arrayList);
        ganzhiji2.classify();
    }

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息