
TensorFlow variables: creation, initialization, and sharing

2017-10-20 15:18

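Creating variables

Variables are best created with **tf.get_variable("name", shape, dtype=tf.int32, initializer=tf.zeros_initializer)**, which has two advantages:

1. A variable name is required, which keeps the graph definition tidy.

2. Variables can be reused.

Difference between tf.Variable() and tf.get_variable():

tf.Variable() resolves name collisions automatically, whereas tf.get_variable() raises an error when a name collides and reuse has not been set.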

import tensorflow as tf
w_1 = tf.Variable(3, name="w_1")
w_2 = tf.Variable(1, name="w_1")
print(w_1.name)
print(w_2.name)
# Output:
# w_1:0
# w_1_1:0

w_1 = tf.get_variable(name="w_1", initializer=1)
w_2 = tf.get_variable(name="w_1", initializer=2)
# Error: ValueError: Variable w_1 already exists, disallowed. Did
# you mean to set reuse=True in VarScope?
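
The two naming policies can be modelled in a few lines of plain Python. This is only a toy sketch of the behaviour described above, not TensorFlow's implementation: `ToyGraph`, `variable` and `get_variable` are hypothetical names, and the suffix rule (`_1`, `_2`, ...) is a simplification.

```python
class ToyGraph:
    """Toy model of the two naming policies; not TensorFlow code."""

    def __init__(self):
        self.op_names = set()        # names handed out by variable()
        self.variable_store = set()  # names registered by get_variable()

    def variable(self, name):
        # Like tf.Variable(): on a collision, append _1, _2, ... until unique.
        candidate, suffix = name, 0
        while candidate in self.op_names:
            suffix += 1
            candidate = "%s_%d" % (name, suffix)
        self.op_names.add(candidate)
        return candidate + ":0"

    def get_variable(self, name):
        # Like tf.get_variable() without reuse: a collision is an error.
        if name in self.variable_store:
            raise ValueError("Variable %s already exists, disallowed." % name)
        self.variable_store.add(name)
        return name + ":0"


graph = ToyGraph()
first = graph.variable("w_1")
second = graph.variable("w_1")
print(first, second)

other_graph = ToyGraph()
other_graph.get_variable("w_1")
try:
    other_graph.get_variable("w_1")
    collision_error = None
except ValueError as err:
    collision_error = str(err)
print(collision_error)
```

The first policy never fails but can silently create a second variable; the second policy makes an accidental duplicate visible immediately.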


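TensorFlow provides a collection mechanism so that stored variables can be accessed globally. By default, a variable is added to two predefined collections: tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.TRAINABLE_VARIABLES.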

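Initializing variables

Variables must be initialized before they are used. This can be done as follows:

session.run(tf.global_variables_initializer())


This initializes every variable in tf.GraphKeys.GLOBAL_VARIABLES in one call, but the order in which the variables are initialized cannot be controlled! Take care when the initial value of one variable depends on another.

You can also control initialization yourself:

session.run(my_variable.initializer)


To check for variables that have not been initialized yet, TensorFlow provides print(session.run(tf.report_uninitialized_variables())).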
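
When some initial values depend on other variables, one way to stay in control is to work out an order first and then run each `initializer` in that order. The sketch below only computes such an order in plain Python. The variable names and the `depends_on` mapping are made up for illustration, and the final `session.run(...)` step is left as a comment.

```python
def initialization_order(depends_on):
    """Return variable names ordered so that dependencies come first.

    depends_on maps a variable name to the names its initial value uses.
    """
    order, done, in_progress = [], set(), set()

    def visit(name):
        if name in done:
            return
        if name in in_progress:
            raise ValueError("circular dependency involving %r" % name)
        in_progress.add(name)
        for dependency in depends_on.get(name, ()):
            visit(dependency)
        in_progress.discard(name)
        done.add(name)
        order.append(name)

    for name in sorted(depends_on):
        visit(name)
    return order


# Hypothetical example: "scaled" is initialized from "base",
# and "offset" is initialized from "scaled".
depends_on = {"offset": ["scaled"], "scaled": ["base"], "base": []}
order = initialization_order(depends_on)
print(order)
# In TensorFlow one would then run, for each variable in this order:
#     session.run(variable.initializer)
```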

Sharing variables

Example 1:

import tensorflow as tf

with tf.name_scope('test_name_scope'):
    initializer = tf.constant_initializer(value=1)
    var1 = tf.get_variable(name='var', shape=[1], dtype=tf.float32, initializer=initializer)
    var21 = tf.Variable(name='var', dtype=tf.float32, initial_value=[1.])
    var22 = tf.Variable(name='var1', dtype=tf.float32, initial_value=[2.])

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated; use the current name.
    sess.run(tf.global_variables_initializer())
    print(var1.name)
    print(sess.run(var1))
    print(var21.name)
    print(sess.run(var21))
    print(var22.name)
    print(sess.run(var22))


Output:

var:0
[ 1.]
test_name_scope/var:0
[ 1.]
test_name_scope/var1:0
[ 2.]


The names of variables created with tf.get_variable() are not affected by name_scope(). Variables created with tf.Variable() get the name_scope() prefix, and tf.Variable() resolves duplicate names automatically.

Example 2:

with tf.variable_scope('test_name_scope'):
    initializer = tf.constant_initializer(value=1)
    var11 = tf.get_variable(name='var', shape=[1], dtype=tf.float32, initializer=initializer)
    # Second get_variable() with the same name and no reuse: this raises.
    var12 = tf.get_variable(name='var', shape=[1], dtype=tf.float32, initializer=initializer)
    var21 = tf.Variable(name='var', dtype=tf.float32, initial_value=[1.])
    var22 = tf.Variable(name='var1', dtype=tf.float32, initial_value=[2.])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var11.name)
    print(sess.run(var11))
    print(var12.name)
    print(sess.run(var12))
    print(var21.name)
    print(sess.run(var21))
    print(var22.name)
    print(sess.run(var22))


Result:

Error: Variable test_name_scope/var already exists


After the fix:

with tf.variable_scope('test_variable_scope') as scope:
    initializer = tf.constant_initializer(value=1)
    var11 = tf.get_variable(name='var', shape=[1], dtype=tf.float32, initializer=initializer)
    # Variable sharing method #1: call scope.reuse_variables() directly.
    scope.reuse_variables()
    var12 = tf.get_variable(name='var')
    var21 = tf.Variable(name='var', dtype=tf.float32, initial_value=[1.])
    var22 = tf.Variable(name='var1', dtype=tf.float32, initial_value=[2.])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var11.name)
    print(sess.run(var11))
    print(var12.name)
    print(sess.run(var12))
    print(var21.name)
    print(sess.run(var21))
    print(var22.name)
    print(sess.run(var22))


It can also be written as follows:

with tf.variable_scope('test_variable_scope') as scope:
    initializer = tf.constant_initializer(value=1)
    var11 = tf.get_variable(name='var', shape=[1], dtype=tf.float32, initializer=initializer)
    var21 = tf.Variable(name='var', dtype=tf.float32, initial_value=[1.])
    var22 = tf.Variable(name='var1', dtype=tf.float32, initial_value=[2.])

# Variable sharing method #2: create a scope with the same name and reuse set to True.
with tf.variable_scope('test_variable_scope', reuse=True) as scope:
    initializer = tf.constant_initializer(value=1)
    var12 = tf.get_variable(name='var', shape=[1], dtype=tf.float32, initializer=initializer)
    # The tf.Variable() calls are not repeated here: re-entering the scope by
    # name opens a fresh name scope, so they would create new variables with
    # different names and overwrite var21 and var22.

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(var11.name)
    print(sess.run(var11))
    print(var12.name)
    print(sess.run(var12))
    print(var21.name)
    print(sess.run(var21))
    print(var22.name)
    print(sess.run(var22))


Output:

test_variable_scope/var:0
[ 1.]
test_variable_scope/var:0
[ 1.]
test_variable_scope/var_1:0
[ 1.]
test_variable_scope/var1:0
[ 2.]


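variable_scope() prefixes the names of all variables with the scope name. Unless reuse is set, get_variable() raises an error when it meets a duplicate name. with variable_scope(name) can be used in different places to create variables in the same variable scope.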
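
The lookup rule behind these results can be summarized with a toy model in plain Python. This is a sketch of the behaviour only, not TensorFlow's implementation: `ToyVariableStore` and its `get_variable` method are hypothetical, and the scope is passed as a plain string argument for brevity.

```python
class ToyVariableStore:
    """Toy model of get_variable() under a variable scope; not TensorFlow code."""

    def __init__(self):
        self.variables = {}  # full name -> value

    def get_variable(self, scope, name, initial_value=None, reuse=False):
        # The scope name becomes a prefix of the variable name.
        full_name = "%s/%s" % (scope, name) if scope else name
        if reuse:
            # With reuse, the variable must already exist and is returned as-is.
            if full_name not in self.variables:
                raise ValueError("Variable %s does not exist" % full_name)
            return full_name, self.variables[full_name]
        if full_name in self.variables:
            # Without reuse, a duplicate name is an error.
            raise ValueError("Variable %s already exists" % full_name)
        self.variables[full_name] = initial_value
        return full_name, initial_value


store = ToyVariableStore()
name_a, value_a = store.get_variable("test_variable_scope", "var", [1.0])
name_b, value_b = store.get_variable("test_variable_scope", "var", reuse=True)
print(name_a, name_b, value_a is value_b)

try:
    store.get_variable("test_variable_scope", "var", [1.0])
    duplicate_error = None
except ValueError as err:
    duplicate_error = str(err)
print(duplicate_error)
```

With reuse, both lookups resolve to the same stored object, which is why var11 and var12 print the same name in the output above.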

Example 3:

import tensorflow as tf

def conv_relu(input, kernel_shape, bias_shape):
    # Create variable named "weights".
    weights = tf.get_variable("weights", kernel_shape,
                              initializer=tf.random_normal_initializer())
    # Create variable named "biases".
    biases = tf.get_variable("biases", bias_shape,
                             initializer=tf.constant_initializer(0.0))
    conv = tf.nn.conv2d(input, weights,
                        strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(conv + biases)

input1 = tf.random_normal([1, 10, 10, 32])
input2 = tf.random_normal([1, 20, 20, 32])
x = conv_relu(input1, kernel_shape=[5, 5, 32, 32], bias_shape=[32])
x = conv_relu(x, kernel_shape=[5, 5, 32, 32], bias_shape=[32])  # This fails.

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())


Output:

Error: Variable weights already exists


The second call to conv_relu tries to create variables named "weights" and "biases" in the graph a second time, hence the error.

References

https://www.tensorflow.org/programmers_guide/variables