torch入门笔记18: Torch 中的引用、深拷贝 以及 getParameters 获取参数的探讨
2016-11-02 18:28
471 查看
reference: http://blog.csdn.net/u010167269/article/details/52073136 一.关于 Torch 中赋值引用、深拷贝的问题二.关于
getParameters()获取参数引发的问题。Torch 中的输入 Tensor,经过一个网络层之后,输出的结果,其 Size 会随着之后输入数据 Size 的变化而变化。这个问题,我在知乎上提问了:https://www.zhihu.com/question/48986099。中科院自动化所的博士@feanfrog 给我做了解答。在这里,我再总结一下。输入一个数据,比如说随机生成:
x1 = torch.randn(3, 128, 128))经过一网络
convNet,如卷积层:<span style="font-family:microsoft yahei;">convNet = nn.SpatialConvolution(3,64, 3,3, 1,1, 1,1))</span>
<span style="font-family: "microsoft yahei"; background-color: rgb(255, 255, 255);"></span><span style="font-family: "microsoft yahei"; background-color: rgb(255, 255, 255);">进行 </span><code style="font-family: "Source Code Pro", monospace; padding: 2px 4px; font-size: 12.6px; color: rgb(63, 63, 63); white-space: nowrap; background-color: rgb(255, 255, 255);">forward</code><span style="font-family: "microsoft yahei"; background-color: rgb(255, 255, 255);"> 之后,其输出结果为 </span><code style="font-family: "Source Code Pro", monospace; padding: 2px 4px; font-size: 12.6px; color: rgb(63, 63, 63); white-space: nowrap; background-color: rgb(255, 255, 255);">y1</code><span style="font-family: "microsoft yahei"; background-color: rgb(255, 255, 255);">,其 </span><code style="font-family: "Source Code Pro", monospace; padding: 2px 4px; font-size: 12.6px; color: rgb(63, 63, 63); white-space: nowrap; background-color: rgb(255, 255, 255);">Size</code><span style="font-family: "microsoft yahei"; background-color: rgb(255, 255, 255);">: </span>64×128×128 . 但输出的这个 Size 大小,会随着之后的输入数据的 Size 的变化而变化!这是诡异的地方……如又输入:
x2 = torch.randn(5, 3, 128, 128)这个
x2经过
convNet:forward:
>y2 = convNet:forward(x2)>y2:size()564128128[torch.LongStorage of size 4]这个
y2的
size为 5×64×128×128 ,这是应该的。 但是请看
y1:size():
y1的
size也变成了 5×64×128×128 !原来 Torch 中为了提高速度,
model:forward()操作之后赋予的变量是不给这个变量开盘新的存储空间的,而是 引用。就相当于 起了个别名。不光这里,torch里面向量或是矩阵的赋值是指向同一内存的,这种策略不同于 Matlab。如果想不想引用,可以用
clone()进行 深拷贝,如下的例子: 当改变变量
v的第一个元素的值时,变量
t也随之变化。
第二个坑: Torch 中 getParameters 获取参数引起的疑惑
当有如下的代码:require 'nn'local convNet = nn.Sequential()convNet:add(nn.Linear(2, 3))convNet:add(nn.Tanh())local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')local params, gradParams = convNet:getParameters()params:fill(0)print(convNet2:get(1).weight)此时的输出为: 感觉输出结果很显然的样子。但当将上述的代码做一下微调:
require 'nn'local convNet = nn.Sequential()convNet:add(nn.Linear(2, 3))convNet:add(nn.Tanh())local params, gradParams = convNet:getParameters()local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')params:fill(0)print(convNet2:get(1).weight)输出的结果为: 仅仅将代码中下面的两行做了对调:
local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')local params, gradParams = convNet:getParameters()结果就是两种不同的结果。 看一下别人的解释:This is not really a bug though.
getParameters()isabit subtle, and should be documented properly.It gets and flattens all the parameters of any given module, and insures that the set of parameters, as well as all the sharing in place within that module, remains consistent.In the example you show, you’re grabbing the parameters of ‘convNet’, but
getParameters()doesn’tknow about the external convNet2. So sharing will be lost.我自己的理解是: 在第一段代码中,顺序是:
local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')local params, gradParams = convNet:getParameters()是先进行的拷贝
clone(),是 深拷贝。
convNet2与
convNet并不是同一个存储。之后再
getParameters(要所有的参数 拉平)的时候,已经不关
convNet2的事了。 这时候再通过
params:fill(0)赋值的时候(因为
getParameters()得到的只是参数的引用,与原先参数指向的同一块内存,所以可以通过
params:fill(0)这种方式给
convNet网络赋值),对
convNet2已经没有影响了。所以
convNet2保持原先的值。而第二段代码,顺序是:
local params, gradParams = convNet:getParameters()local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')注意这时候,
convNet中的存储结构,已经被
getParameters()函数给 拉平 了,相当于是
convNet的结构已经被破坏了。 因为在官方的
Module文档中的
getParameters()函数这块,有这么一句话:This function will go over all the weights and gradWeights and make them view into a single tensor (one for weights and one for gradWeights). Since the storage of every weightand gradWeight is changed, this function should be called only once on a given network.下面
convNet2再
clone('weight',...)这样拷贝,已经失效了。所以,实际上,这时候
convNet2进行的所谓的 深拷贝,并不是真正的 深拷贝,而是 失效的深拷贝 。下面我们可以通过加一句话验证一下,上面的深拷贝是失效的:
require 'nn'local convNet = nn.Sequential()convNet:add(nn.Linear(2, 3))convNet:add(nn.Tanh())local params, gradParams = convNet:getParameters()local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')-- 加上下面这一句, 这一句的拷贝不指定拷贝的参数, 如 weight, bias 这些-- 而是默认的进行深拷贝local convNet3 = convNet:clone()params:fill(0)print(convNet2:get(1).weight)print('------------------------')print(convNet3:get(1).weight)我们在中间加了一句:
local convNet3 = convNet:clone(),自动的进行深拷贝,而不是指定参数。看输出结果: 看到了吗?!
convNet3的参数与
convNet不是同一块存储地址,深拷贝成功。而
convNet2的深拷贝失效,所以当
params:fill(0)的时候,
convNet2的参数也变了。但
convNet3的深拷贝成功!总结一下: 实验证明,我的猜想是成功的。由于
getParameters()获取参数使得
convNet的网络参数被 拉平 了,所以
convNet2的深拷贝方式就已经失效了,
convNet2本质上跟
convNet还是共用的一块内存地址:
local convNet2 = convNet:clone('weight', 'bias', 'gradWeight', 'gradBias')反而不指定参数的
convNet3的深拷贝方式,反而保持有效:
local convNet3 = convNet:clone()
Reference
推荐一个博客,我在这个博客中也找到了 Torch 里这个坑的叙述,写的也不错:http://blog.csdn.net/hungryof?viewmode=contents,以及 学习Torch框架的该看的资料汇总(不断更新)Torch 的 Github Issues 版块上的讨论:Problem with getParameters?Google groups 中 Torch 板块的一篇讨论贴: Triplet Net, Parallel table and weights sharing最后,nn Modules的官方文档还是得多读读:https://github.com/torch/nn/blob/master/doc/module.md
相关文章推荐
- Torch 中的引用、深拷贝 以及 getParameters 获取参数的探讨
- SQL获取所有数据库名、表名、储存过程以及参数列表
- 深度学习第二课 改善深层神经网络:超参数调试、正则化以及优化 第二周Mini_batch+优化算法 笔记和作业
- (18)servletContext应用:获取web应用的初始化参数、实现servlet转发、利用servletContext对象读取资源文件
- OC_语法入门_day5_内存管理_计数器/set方法/property的参数/循环引用/自动释放池
- restlet2.1 学习笔记(四) 获取、返回XML类型参数
- SPSS(|PASW)18 学习笔记(1):入门示例-克山病例
- unity3D-游戏/AR/VR在线就业班 C#入门值类型和引用类型学习笔记
- oc开发笔记3 录音时频率获取 以及声像显示
- 【Halcon笔记1】基于Halcon软件的【摄像机标定】以及【内部参数】和【外部参数】的求解过程【原理细节详解】
- javascript入门系列演示·函数的定义以及简单参数使用,调用函数
- 【学习笔记】day1_快速入门 14_电话拨号器定义布局&获取组件对象
- [原创]java WEB学习笔记98:Spring学习---Spring Bean配置及相关细节:如何在配置bean,Spring容器(BeanFactory,ApplicationContext),如何获取bean,属性赋值(属性注入,构造器注入),配置bean细节(字面值,包含特殊字符,引用bean,null值,集合属性list map propert),util 和p 命名空间
- TestNG入门笔记[4]: testng.xml 执行case —— 参数的传递
- 指针作函数参数,引用作函数参数以及内存释放
- PhalApi框架脱坑笔记(二:get请求的参数获取)
- sublime text3入门笔记以及屏蔽sublime自动升级检测更新
- SQL获取所有数据库名、表名、储存过程以及参数列表
- c++中参数传递的三种方式,以及用法。传值,传址,传引用
- 使用jquery获取url以及jquery获取url参数的方法