您的位置:首页 > 其它

梯度下降法的一个简单实验

2015-03-20 16:00 204 查看
import java.util.ArrayList;

import java.util.List;

/**使用梯度下降算法计算特征向量的权值**/

public class SolveEquations {

public double K[];//特征向量的权值

public SolveEquations(double[] K){

this.K=K;

}

public static double multiplication(double[] a,double[] b){

double rtn=0.0;

if(a.length!=b.length){

System.err.println("输入的两个数组维度不一致!");

throw new RuntimeException();

}

for(int i=0;i<a.length;i++){

rtn+=a[i]*b[i];

}

return rtn;

}

public double computePartialDerivative(List<INPUT> inputList,int j){//计算误差函数的偏导数

double rtn=0.0;

for (INPUT input : inputList) {

double temp=(multiplication(input.X, K)-input.Y)*input.X[j];

rtn+=temp;

}

return rtn;

}

public double J(List<INPUT> inputList){//误差估计函数

double rtn=0.0;

for (INPUT input : inputList) {

double temp=Math.pow(DimensionTool.multiplication(input.X, K),2);

rtn+=temp;

}

return rtn/2;

}

public void nextK(List<INPUT> inputList,double a){//输入参数为训练集合、迭代时的步长

for (int j=0;j<K.length;j++) {

double temp=a*computePartialDerivative(inputList, j);

K[j]=K[j]-temp;

}

}

public static void main(String[] args) {

//假想一个函数y=k0x0+k1*x1;其中k0=3,k1=4;

List<INPUT> inputList=new ArrayList<INPUT>();

double[] i0={0,0};

double[] i1={1,1};

double[] i2={1,2};

double[] i3={2,3};

double[] i4={3,4};

double[] i5={5,5};

//设定6个点,作为数据加入训练集合

inputList.add(new INPUT(i0, 0.0));

inputList.add(new INPUT(i1, 7.0));

inputList.add(new INPUT(i2, 11));

inputList.add(new INPUT(i3, 18.0));

inputList.add(new INPUT(i4, 25));

inputList.add(new INPUT(i5, 35));

//初始化特征权值(假想函数只有两个维度)为1.0,1.0

double K[]={1.0,1.0};

SolveEquations solveEquations=new SolveEquations(K);

//迭代一千次

for(int i=0;i<1000;i++){

solveEquations.nextK(inputList, 0.01);

//System.out.println(K[0]+" "+K[1]);

}

//输出通过梯度下降算法迭代出的k0和k1的值

for(int i=0;i<solveEquations.K.length;i++){

System.out.println(solveEquations.K[i]);

}

}

}

class INPUT{

double[] X;//一组X

double Y;//h(X)的值y

public INPUT(double X[],double Y){

this.X=X;

this.Y=Y;

}

}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: