您的位置:首页 > 运维架构

tf.name_scope and tf.variable_scope

2017-11-02 16:13 344 查看
作者:C Li

链接:https://www.zhihu.com/question/54513728/answer/181819324

来源:知乎

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

在 tf.name_scope下时,tf.get_variable()创建的变量名不受 name_scope 的影响,而且在未指定共享变量时,如果重名会报错,tf.Variable()会自动检测有没有变量重名,如果有则会自行处理。

import tensorflow as tf

with tf.name_scope('name_scope_x'):
var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
var3 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)
var4 = tf.Variable(name='var2', initial_value=[2], dtype=tf.float32)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(var1.name, sess.run(var1))
print(var3.name, sess.run(var3))
print(var4.name, sess.run(var4))
# 输出结果:
# var1:0 [-0.30036557]   可以看到前面不含有指定的'name_scope_x'
# name_scope_x/var2:0 [ 2.]
# name_scope_x/var2_1:0 [ 2.]  可以看到变量名自行变成了'var2_1',避免了和'var2'冲突


如果使用tf.get_variable()创建变量,且没有设置共享变量,重名时会报错

import tensorflow as tf

with tf.name_scope('name_scope_1'):
var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
var2 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(var1.name, sess.run(var1))
print(var2.name, sess.run(var2))

# ValueError: Variable var1 already exists, disallowed. Did you mean
# to set reuse=True in VarScope? Originally defined at:
# var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)


所以要共享变量,需要使用tf.variable_scope()

import tensorflow as tf

with tf.variable_scope('variable_scope_y') as scope:
var1 = tf.get_variable(name='var1', shape=[1], dtype=tf.float32)
scope.reuse_variables()  # 设置共享变量
var1_reuse = tf.get_variable(name='var1')
var2 = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)
var2_reuse = tf.Variable(initial_value=[2.], name='var2', dtype=tf.float32)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(var1.name, sess.run(var1))
print(var1_reuse.name, sess.run(var1_reuse))
print(var2.name, sess.run(var2))
print(var2_reuse.name, sess.run(var2_reuse))
# 输出结果:
# variable_scope_y/var1:0 [-1.59682846]
# variable_scope_y/var1:0 [-1.59682846]   可以看到变量var1_reuse重复使用了var1
# variable_scope_y/var2:0 [ 2.]
# variable_scope_y/var2_1:0 [ 2.]


也可以这样

with tf.variable_scope('foo') as foo_scope:
v = tf.get_variable('v', [1])
with tf.variable_scope('foo', reuse=True):
v1 = tf.get_variable('v')
assert v1 == v


或者这样:

with tf.variable_scope('foo') as foo_scope:
v = tf.get_variable('v', [1])
with tf.variable_scope(foo_scope, reuse=True):
v1 = tf.get_variable('v')
assert v1 == v
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: