梯度下降法的一个简单实验
2015-03-20 16:00
204 查看
import java.util.ArrayList;
import java.util.List;
/**使用梯度下降算法计算特征向量的权值**/
public class SolveEquations {
public double K[];//特征向量的权值
public SolveEquations(double[] K){
this.K=K;
}
public static double multiplication(double[] a,double[] b){
double rtn=0.0;
if(a.length!=b.length){
System.err.println("输入的两个数组维度不一致!");
throw new RuntimeException();
}
for(int i=0;i<a.length;i++){
rtn+=a[i]*b[i];
}
return rtn;
}
public double computePartialDerivative(List<INPUT> inputList,int j){//计算误差函数的偏导数
double rtn=0.0;
for (INPUT input : inputList) {
double temp=(multiplication(input.X, K)-input.Y)*input.X[j];
rtn+=temp;
}
return rtn;
}
public double J(List<INPUT> inputList){//误差估计函数
double rtn=0.0;
for (INPUT input : inputList) {
double temp=Math.pow(DimensionTool.multiplication(input.X, K),2);
rtn+=temp;
}
return rtn/2;
}
public void nextK(List<INPUT> inputList,double a){//输入参数为训练集合、迭代时的步长
for (int j=0;j<K.length;j++) {
double temp=a*computePartialDerivative(inputList, j);
K[j]=K[j]-temp;
}
}
public static void main(String[] args) {
//假想一个函数y=k0x0+k1*x1;其中k0=3,k1=4;
List<INPUT> inputList=new ArrayList<INPUT>();
double[] i0={0,0};
double[] i1={1,1};
double[] i2={1,2};
double[] i3={2,3};
double[] i4={3,4};
double[] i5={5,5};
//设定6个点,作为数据加入训练集合
inputList.add(new INPUT(i0, 0.0));
inputList.add(new INPUT(i1, 7.0));
inputList.add(new INPUT(i2, 11));
inputList.add(new INPUT(i3, 18.0));
inputList.add(new INPUT(i4, 25));
inputList.add(new INPUT(i5, 35));
//初始化特征权值(假想函数只有两个维度)为1.0,1.0
double K[]={1.0,1.0};
SolveEquations solveEquations=new SolveEquations(K);
//迭代一千次
for(int i=0;i<1000;i++){
solveEquations.nextK(inputList, 0.01);
//System.out.println(K[0]+" "+K[1]);
}
//输出通过梯度下降算法迭代出的k0和k1的值
for(int i=0;i<solveEquations.K.length;i++){
System.out.println(solveEquations.K[i]);
}
}
}
class INPUT{
double[] X;//一组X
double Y;//h(X)的值y
public INPUT(double X[],double Y){
this.X=X;
this.Y=Y;
}
}
import java.util.List;
/**使用梯度下降算法计算特征向量的权值**/
public class SolveEquations {
public double K[];//特征向量的权值
public SolveEquations(double[] K){
this.K=K;
}
public static double multiplication(double[] a,double[] b){
double rtn=0.0;
if(a.length!=b.length){
System.err.println("输入的两个数组维度不一致!");
throw new RuntimeException();
}
for(int i=0;i<a.length;i++){
rtn+=a[i]*b[i];
}
return rtn;
}
public double computePartialDerivative(List<INPUT> inputList,int j){//计算误差函数的偏导数
double rtn=0.0;
for (INPUT input : inputList) {
double temp=(multiplication(input.X, K)-input.Y)*input.X[j];
rtn+=temp;
}
return rtn;
}
public double J(List<INPUT> inputList){//误差估计函数
double rtn=0.0;
for (INPUT input : inputList) {
double temp=Math.pow(DimensionTool.multiplication(input.X, K),2);
rtn+=temp;
}
return rtn/2;
}
public void nextK(List<INPUT> inputList,double a){//输入参数为训练集合、迭代时的步长
for (int j=0;j<K.length;j++) {
double temp=a*computePartialDerivative(inputList, j);
K[j]=K[j]-temp;
}
}
public static void main(String[] args) {
//假想一个函数y=k0x0+k1*x1;其中k0=3,k1=4;
List<INPUT> inputList=new ArrayList<INPUT>();
double[] i0={0,0};
double[] i1={1,1};
double[] i2={1,2};
double[] i3={2,3};
double[] i4={3,4};
double[] i5={5,5};
//设定6个点,作为数据加入训练集合
inputList.add(new INPUT(i0, 0.0));
inputList.add(new INPUT(i1, 7.0));
inputList.add(new INPUT(i2, 11));
inputList.add(new INPUT(i3, 18.0));
inputList.add(new INPUT(i4, 25));
inputList.add(new INPUT(i5, 35));
//初始化特征权值(假想函数只有两个维度)为1.0,1.0
double K[]={1.0,1.0};
SolveEquations solveEquations=new SolveEquations(K);
//迭代一千次
for(int i=0;i<1000;i++){
solveEquations.nextK(inputList, 0.01);
//System.out.println(K[0]+" "+K[1]);
}
//输出通过梯度下降算法迭代出的k0和k1的值
for(int i=0;i<solveEquations.K.length;i++){
System.out.println(solveEquations.K[i]);
}
}
}
class INPUT{
double[] X;//一组X
double Y;//h(X)的值y
public INPUT(double X[],double Y){
this.X=X;
this.Y=Y;
}
}
相关文章推荐
- 【实验】【PROCEDURE】一个最简单的oracle存储过程"proc_helloworld"【转】
- FLEX小实验 一个简单的时钟
- 演示:一个最简单的IPv6实验
- 做ssh key分发实验中的一个简单问题
- 一个简单的ns2实验全过程
- 在Eclipse下,采用mulan多标签分类软件进行一个简单的测试实验
- 关于COMMIT与ROLLBACK的一个简单实验
- 【实验 1-1】编写一个简单的 TCP 服务器和 TCP 客户端程序。程序均为控制台程序窗口。
- oracle实验31:使用PL/SQL,书写一个最简单的块
- 【iOS开发-50】利用创建新的类实现代码封装,从而不知不觉实践一个简单的MVC实验,附带个动画
- 一个简单地次优路径小实验
- 深入浅出FPGA-15-xilinx_zynq7000_EPP上一个简单实验(PL)
- Java--第十三周实验--任务0--编写一个简单的Java应用程序
- 【实验 1-2】编写一个简单的 UDP 服务器和 UDPP 客户端程序。程序均为控制台程序窗口。
- 关于Cisco一个简单实验拓扑配置搭建与配置
- 《第五周实验报告2-1》---设计一个简单的分数类,完成对分数的几个运算
- 一个简单的实验,Java数组遍历
- 【实验】【PROCEDURE】一个最简单的oracle存储过程"proc_helloworld"