您的位置:首页 > 其它

对正向传播、反向传播和精度检查的理解

2017-01-18 11:51 405 查看
1. 首先是概念,

// 格式化节点,每个节点中包含的计算值和梯度值
function Unit(value, grad) {
// 这个值是正向传播的值
this.value = value;
// 这个值是反向传播的值,
this.grad = grad;
}

function multiplyGate(){ };//乘法门
multiplyGate.prototype = {
forward: function(u0, u1) {
// 正向传播的输入值
this.u0 = u0;
this.u1 = u1;
this.utop = new Unit(u0.value * u1.value, 0.0);
return this.utop;//正向传播的输出值
},
backward: function() {
// u0 节点的梯度(导数)是输出节点的梯度(该梯度一个是反向传播计算来的,原理是链式法则), 乘以该乘法函数对u0的导数 , u0节点的梯度(导数)定义为 ax+bx+c 对 a 的求导
this.u0.grad += this.u1.value * this.utop.grad;
// u1 节点的梯度(导数)是输出节点的梯度(该梯度一个是反向传播计算来的,原理是链式法则), 乘以该乘法函数对u1的导数 , u1节点的梯度(导数)定义为 ax+bx+c 对 x 的求导
this.u1.grad += this.u0.value * this.utop.grad;
}
}

function addGate(){ };
addGate.prototype = {
forward: function(u0, u1) {
this.u0 = u0;
this.u1 = u1; //
this.utop = new Unit(u0.value + u1.value, 0.0);
return this.utop;
},
backward: function() {
// 加法函数对任意元素的导数都为1 ,那么在整个反向传播中, u0 的导数等于对整个函数的求导即 1 乘以 输出节点的导数
this.u0.grad += 1 * this.utop.grad;
this.u1.grad += 1 * this.utop.grad;
}
}

function Circuit() {
// 创建各类门(或者叫公式)
this.mulg0 = new multiplyGate();
this.mulg1 = new multiplyGate();
this.addg0 = new addGate();
this.addg1 = new addGate();
};
Circuit.prototype = {
forward: function(x,y,a,b,c) {
this.ax = this.mulg0.forward(a, x); // a*x
this.by = this.mulg1.forward(b, y); // b*y
this.axpby = this.addg0.forward(this.ax, this.by); // a*x + b*y
this.axpbypc = this.addg1.forward(this.axpby, c); // a*x + b*y + c
return this.axpbypc;
},
backward: function(gradient_top) { // 最后输出的梯度
this.axpbypc.grad = gradient_top;
this.addg1.backward(); // sets gradient in axpby and c
this.addg0.backward(); // sets gradient in ax and by
this.mulg1.backward(); // sets gradient in b and y
this.mulg0.backward(); // sets gradient in a and x
}

}

// SVM class
function SVM() {

// random initial parameter values
this.a = new Unit(1.0, 0.0);
this.b = new Unit(-2.0, 0.0);
this.c = new Unit(-1.0, 0.0);

this.circuit = new Circuit();
};
SVM.prototype = {
forward: function(x, y) { // assume x and y are Units
this.unit_out = this.circuit.forward(x, y, this.a, this.b, this.c);
return this.unit_out;
},
backward: function(label) { // label is +1 or -1

// reset pulls on a,b,c
this.a.grad = 0.0;
this.b.grad = 0.0;
this.c.grad = 0.0;

// compute the pull based on what the circuit output was
var pull = 0.0;
if(label === 1 && this.unit_out.value < 1) {
pull = 1; // the score was too low: pull up
}
if(label === -1 && this.unit_out.value > -1) {
pull = -1; // the score was too high for a positive example, pull down
}
this.circuit.backward(pull); // writes gradient into x,y,a,b,c

// add regularization pull for parameters: towards zero and proportional to value
this.a.grad += -this.a.value;
this.b.grad += -this.b.value;
},
learnFrom: function(x, y, label) {
this.forward(x, y); // forward pass (set .value in all Units)
this.backward(label); // backward pass (set .grad in all Units)
this.parameterUpdate(); // parameters respond to tug
},
parameterUpdate: function() {
var step_size = 0.01;
this.a.value += step_size * this.a.grad;
this.b.value += step_size * this.b.grad;
this.c.value += step_size * this.c.grad;
},
max:function(x){
var unit_out=Math.max(0, x);

return unit_out;
}
};

var data = []; var labels = [];
data.push([1.2, 0.7]); labels.push(1);
data.push([-0.3, -0.5]); labels.push(-1);
data.push([3.0, 0.1]); labels.push(1);
data.push([-0.1, -1.0]); labels.push(-1);
data.push([-1.0, 1.1]); labels.push(-1);
data.push([2.1, -3]); labels.push(1);
data.push([4.1, -0.1]); labels.push(1);
var svm = new SVM();

// the learning loop
for(var iter = 0; iter < 410; iter++) {//
// 随机取出数据点

var i = Math.floor(Math.random() * data.length);
var x = new Unit(data[i][0], 0.0);
var y = new Unit(data[i][1], 0.0);
var label = labels[i];
svm.learnFrom(x, y, label);

if(iter % 25 == 0) { // every 10 iterations...
console.log('training accuracy at iter ' + iter + ': ' + evalTrainingAccuracy(svm,data));
}
}

// 计算调整 ax+by+c 中的a,b,c 的参数使得 每个数据集预测的标签和标记的标签一致,如果所有数据集标签和标记的标签是一样的那么精度为1
function evalTrainingAccuracy (svm,data){//
var num_correct = 0;
for(var i = 0; i < data.length; i++) {
var x = new Unit(data[i][0], 0.0);
var y = new Unit(data[i][1], 0.0);
var true_label = labels[i];

// see if the prediction matches the provided label
var predicted_label = svm.forward(x, y).value > 0 ? 1 : -1;
if(predicted_label === true_label) {
num_correct++;
}
}

return num_correct / data.length;
};
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: