您的位置:首页 > 编程语言 > Java开发

感知机学习算法实现

2015-05-20 18:38 274 查看
(1)感知机是二类分类的线性分类模型,其输入是实例的特征向量,输出为实例的类型。

          感知机学习的目的是将训练数据进行线性划分的分离超平面,属于是判别模型。感知机学习采用基于误分类的损失函数,利用梯度下降法对损失函数进行极小化,最后求得感知机模型。

          强调:利用感知机学习策略进行训练的数据集要求必须是线性可分的,否则感知机学习算法不会收敛,迭代结果会发生震荡。

         本文对感知机的学习和实现是参考:《统计学习方法》一书中第二章,旨在实现书中提到的算法,并且通过实际编码加深理解。

(2)书中例子的算法实现(针对二维空间)

//感知机学习算法
//感知机模型 f(x) = sign(w*x+b)
public class Perceptron
{
int m_learnRate;//学习率
int m_w0;
int m_w1;
int m_b;

public Perceptron(int w0,int w1, int b0, int learnRate)
{
this.m_b = b0;
this.m_learnRate = learnRate;
this.m_w0 = w0;
this.m_w1 = w1;
}

/*
* 判断针对训练数据x 估测的模型与实际数据是否有误差
*/
private boolean judgeHasError(int[] x)
{
//如果表达式小于0,说明没有被正确分类
if((x[2]*(this.m_w0*x[0]+this.m_w1*x[1]+this.m_b)) <= 0)
return false;
return true;
}

/*
* 有误差的话,需要调整模型参数
*/
private void adjustParam(int[] x)
{
//根据梯度下降法调整参数 w b
this.m_w0 = this.m_w0 + this.m_learnRate*x[2]*x[0];
this.m_w1 = this.m_w1 + this.m_learnRate*x[2]*x[1];
this.m_b = this.m_b + this.m_learnRate*x[2];
return ;
}

public void TrainData(int data[][], int num) throws InterruptedException
{
int count = 0;
boolean isOver = false;
while(!isOver)
{
System.out.println("w0 w1 b: "+this.m_w0+" "+this.m_w1+" "+this.m_b);

for(int i=0; i<num; ++i)
{
if(!judgeHasError(data[i]))
{
System.out.println(i+"调整次数:"+(++count));
adjustParam(data[i]);
isOver = false;
break;
}
else
isOver = true;
}
}
//
System.out.println("w0 w1 b: "+this.m_w0+" "+this.m_w1+" "+this.m_b);
}

public static void main(String args[])
{
//data数组中包括 正实例点和负实例点,其中数组中最后一位元素代表其为何种实例点(1代表正实例,-1代表负实例)
//训练数据一共包括三组,前两组是正实例
 int data[][] = {{3,3,1},{4,3,1},{1,1,-1}};<pre name="code" class="java" style="white-space: pre-wrap; word-wrap: break-word;">        Perceptron  p = new Perceptron(0,0,0,1);
p.TrainData(data, 3);
}
}



                                            
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  java机器学习