您的位置:首页 > 其它

torch 的 forward 和 backward

2018-01-01 11:31 253 查看
Criterions有其forward和backward函数

https://github.com/torch/nn/blob/master/doc/criterion.md

Module也有其forward和backward函数

https://github.com/torch/nn/blob/master/doc/module.md

Module的forward函数最简单,就是输入input得到output

Module的backward看下这个线性回归的例子

require 'torch'
require 'nn'
require 'gnuplot'

month = torch.range(1,10)
price = torch.Tensor{28993,29110,29436,30791,33384,36762,39900,39972,40230,40146}

model = nn.Linear(1, 1)
criterion = nn.MSECriterion()

month_train = month:reshape(10,1)
price_train = price:reshape(10,1)

for i=1,1000 do
price_predict = model:forward(month_train) -- 输入 -> 输出
err = criterion:forward(price_predict, price_train) -- 输出,正确 -> loss值
print(i, err)
model:zeroGradParameters()
gradient = criterion:backward(price_predict, price_train) -- 输出,正确 -> 梯度
model:backward(month_train, gradient) -- 输入,梯度
model:updateParameters(0.01)
end

month_predict = torch.range(1,12)
local price_predict = model:forward(month_predict:reshape(12,1))
print(price_predict)

gnuplot.pngfigure('plot.png')
gnuplot.plot({month, price}, {month_predict, price_predict})
gnuplot.plotflush()
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: