Getting variables and operations from a TensorFlow graph
2017-10-19 13:19
TensorFlow provides a series of methods for retrieving and iterating over the variables and operations in a computation graph.
Getting a single operation/variable
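A specific tensor (for example, the output of a variable op) or operation in the graph can be retrieved by name with the following two methods:
1. tf.Graph.get_tensor_by_name(tensor_name)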
2. tf.Graph.get_operation_by_name(op_name)
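Both methods look items up by name. Tensor names have the form "<op name>:<output index>", while operation names have no suffix.

The sketch below illustrates this. It assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1, with tf.disable_v2_behavior() so that graph mode and reference variables match the TF 1.x examples in this post. The graph and the names v1/inc_v1 are illustrative.

```python
# Assumption: TensorFlow 1.15, or TF 2.x via the compat.v1 module.
import tensorflow.compat.v1 as tf

# Graph mode with reference variables, as in the TF 1.x examples.
tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    # Illustrative names, mirroring the examples in this post.
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
    inc_v1 = tf.assign(v1, v1 + 1, name="inc_v1")

# Tensor names are "<op name>:<output index>".
v1_tensor = graph.get_tensor_by_name("v1:0")
inc_tensor = graph.get_tensor_by_name("inc_v1:0")

# Operation names carry no ":<index>" suffix.
inc_op = graph.get_operation_by_name("inc_v1")

# Passing an operation name where a tensor name is expected raises ValueError.
lookup_failed = False
try:
    graph.get_tensor_by_name("inc_v1")
except ValueError:
    lookup_failed = True

print(v1_tensor.name, v1_tensor.shape)
print(inc_op.name, inc_op.type)
print(inc_tensor.op.name == inc_op.name)
print(lookup_failed)
```

Name-based lookup is typically how tensors are recovered after a saved graph has been imported, when the original Python handles are no longer available.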
Batch retrieval
The main ways to retrieve items in bulk are as follows:
1. graph.as_graph_def().node
```python
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
v3 = tf.get_variable("v3", shape=[4], initializer=tf.zeros_initializer())
inc_v1 = tf.assign(v1, v1 + 1, name='inc_v1')
dec_v2 = tf.assign(v2, v2 - 1, name='dec_v2')
dec_v3 = tf.assign(v3, v3 - 2, name='dec_v3')

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, initialize the variables and do some work.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    dec_v3.op.run()
    # Print the NodeDef of every node in the default graph.
    for n in tf.get_default_graph().as_graph_def().node:
        print(n)
```
Output (excerpt):
```
name: "v1/Initializer/zeros"
op: "Const"
attr {
  key: "_class"
  value {
    list {
      s: "loc:@v1"
    }
  }
}
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_FLOAT
      tensor_shape {
        dim {
          size: 3
        }
      }
      float_val: 0.0
    }
  }
}
```
This lists the full definition (name, op type and attributes) of every node in the graph.
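graph.as_graph_def().node yields NodeDef protocol buffers. Their fields (name, op, input, attr) can be used to filter the listing instead of printing everything.

The sketch below keeps only the variable nodes. It assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1, with tf.disable_v2_behavior() so that variables are reference variables, whose op type is "VariableV2". Resource variables use "VarHandleOp" instead. The variable names are illustrative.

```python
# Assumption: TensorFlow 1.15, or TF 2.x via the compat.v1 module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference variables (VariableV2 ops)

graph = tf.Graph()
with graph.as_default():
    for name, size in [("v1", 3), ("v2", 5), ("v3", 4)]:
        tf.get_variable(name, shape=[size], initializer=tf.zeros_initializer())

graph_def = graph.as_graph_def()

# Each NodeDef has a name, an op type string, its inputs and its attributes.
variable_nodes = [node.name for node in graph_def.node if node.op == "VariableV2"]
print(variable_nodes)
```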
2. graph.get_operations()
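```python
# Uses the graph built in the previous example.
for op in tf.get_default_graph().get_operations():
    print('name:' + op.name)
    # op.values() returns the tuple of tensors produced by this op.
    print('value:' + str(op.values()))
```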
Output (excerpt):
```
name:v1/Initializer/zeros
value:(<tf.Tensor 'v1/Initializer/zeros:0' shape=(3,) dtype=float32>,)
name:v1
value:(<tf.Tensor 'v1:0' shape=(3,) dtype=float32_ref>,)
```
op.values() returns the tensors produced by the op. Each tensor's name, shape and other properties can then be read from it.
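The sketch below goes one step further than printing op.values(). For every output tensor it records the producing op type, the static shape and the dtype.

It assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1, with tf.disable_v2_behavior() so that reference variables (dtype float32_ref) appear as in the output above. The graph is illustrative.

```python
# Assumption: TensorFlow 1.15, or TF 2.x via the compat.v1 module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference variables

graph = tf.Graph()
with graph.as_default():
    v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
    inc_v1 = tf.assign(v1, v1 + 1, name="inc_v1")

# Map each output tensor name to (producing op type, static shape, dtype name).
tensor_info = {}
for op in graph.get_operations():
    for t in op.values():  # the op's output tensors
        tensor_info[t.name] = (op.type, t.shape.as_list(), t.dtype.name)

for name, info in sorted(tensor_info.items()):
    print(name, info)
```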
3. tf.all_variables()
```python
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
v3 = tf.get_variable("v3", shape=[4], initializer=tf.zeros_initializer())
inc_v1 = tf.assign(v1, v1 + 1, name='inc_v1')
dec_v2 = tf.assign(v2, v2 - 1, name='dec_v2')
dec_v3 = tf.assign(v3, v3 - 2, name='dec_v3')

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, initialize the variables and do some work.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    dec_v3.op.run()
    # List every global variable in the default graph.
    for variable in tf.all_variables():
        print(variable)
        print(variable.name)
```
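Output:
```
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
v1:0
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
v2:0
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>
v3:0
```
This returns all global Variable objects in the default graph. tf.all_variables() is deprecated; its replacement is tf.global_variables(), which returns the same list.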
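tf.global_variables() is the non-deprecated spelling of tf.all_variables(). It reads the GLOBAL_VARIABLES collection, so the two calls below return the same list.

The sketch assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1. The variable names are illustrative.

```python
# Assumption: TensorFlow 1.15, or TF 2.x via the compat.v1 module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference variables

graph = tf.Graph()
with graph.as_default():
    for name, size in [("v1", 3), ("v2", 5), ("v3", 4)]:
        tf.get_variable(name, shape=[size], initializer=tf.zeros_initializer())
    # Replacement for the deprecated tf.all_variables().
    global_vars = tf.global_variables()
    # Equivalent: read the GLOBAL_VARIABLES collection directly.
    same_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

print([v.name for v in global_vars])
```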
4. tf.get_collection(collection_key)
```python
# Read the predefined GLOBAL_VARIABLES collection of the default graph.
for variable in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(variable)
```
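Output:
```
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>
```
tf.get_collection(key) returns the list of objects stored in the collection named by key.
TensorFlow predefines a set of standard collection keys in the class tf.GraphKeys, documented as follows: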
Standard names to use for graph collections.

The standard library uses various well-known names to collect and retrieve values associated with a graph. For example, the tf.Optimizer subclasses default to optimizing the variables collected under tf.GraphKeys.TRAINABLE_VARIABLES if none is specified, but it is also possible to pass an explicit list of variables.

The following standard keys are defined:
● GLOBAL_VARIABLES: the default collection of Variable objects, shared across distributed environment (model variables are a subset of these). See tf.global_variables for more details. Commonly, all TRAINABLE_VARIABLES variables will be in MODEL_VARIABLES, and all MODEL_VARIABLES variables will be in GLOBAL_VARIABLES.
● LOCAL_VARIABLES: the subset of Variable objects that are local to each machine. Usually used for temporary variables, like counters. Note: use tf.contrib.framework.local_variable to add to this collection.
● MODEL_VARIABLES: the subset of Variable objects that are used in the model for inference (feed forward). Note: use tf.contrib.framework.model_variable to add to this collection.
● TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer. See tf.trainable_variables for more details.
● SUMMARIES: the summary Tensor objects that have been created in the graph. See tf.summary.merge_all for more details.
● QUEUE_RUNNERS: the QueueRunner objects that are used to produce input for a computation. See tf.train.start_queue_runners for more details.
● MOVING_AVERAGE_VARIABLES: the subset of Variable objects that will also keep moving averages. See tf.moving_average_variables for more details.
● REGULARIZATION_LOSSES: regularization losses collected during graph construction.

The following standard keys are defined, but their collections are not automatically populated as many of the others are:
● WEIGHTS
● BIASES
● ACTIVATIONS
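The documentation above notes that trainable variables are a subset of the global ones. The sketch below shows this by creating one variable with trainable=False: it lands in GLOBAL_VARIABLES but not in TRAINABLE_VARIABLES.

It assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1. The variables weights and step are hypothetical examples.

```python
# Assumption: TensorFlow 1.15, or TF 2.x via the compat.v1 module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference variables

graph = tf.Graph()
with graph.as_default():
    # Hypothetical model weight: trainable by default.
    weights = tf.get_variable("weights", shape=[2, 2])
    # Hypothetical step counter: excluded from training via trainable=False.
    step = tf.get_variable("step", shape=[], dtype=tf.int64,
                           initializer=tf.zeros_initializer(), trainable=False)

    global_names = [v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)]
    trainable_names = [v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]

print("global:", global_names)
print("trainable:", trainable_names)
```

An optimizer's minimize() call with no explicit var_list would therefore update weights but leave step alone.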
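Besides the predefined collections, TensorFlow also supports custom collections through tf.add_to_collection(name, value) and tf.get_collection(key). Collections provide a global, per-graph storage mechanism whose keys are not affected by name scopes. For example: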
```python
import tensorflow as tf

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer())
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer())
v3 = tf.get_variable("v3", shape=[4], initializer=tf.zeros_initializer())
inc_v1 = tf.assign(v1, v1 + 1, name='inc_v1')
dec_v2 = tf.assign(v2, v2 - 1, name='dec_v2')
dec_v3 = tf.assign(v3, v3 - 2, name='dec_v3')

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Launch the model, initialize the variables and do some work.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    dec_v3.op.run()
    # Add two variables and one tensor to a custom collection named 'test'.
    tf.add_to_collection('test', v1)
    tf.add_to_collection('test', v2)
    tf.add_to_collection('test', inc_v1)
    for element in tf.get_collection('test'):
        print(element)
```
Output:
```
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
Tensor("inc_v1:0", shape=(3,), dtype=float32_ref)
```
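A collection key is a plain string and does not pick up the surrounding name scope. However, tf.get_collection also accepts an optional scope argument, which keeps only the items whose name matches it at the start (a regular-expression match on each item's name).

The sketch below shows both behaviors. It assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1. The scope "block", the variables a and b, and the key "test" are illustrative.

```python
# Assumption: TensorFlow 1.15, or TF 2.x via the compat.v1 module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference variables

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("block"):
        a = tf.get_variable("a", shape=[2])
        # Same key "test" even though we are inside the "block" scope.
        tf.add_to_collection("test", a)
    b = tf.get_variable("b", shape=[2])
    tf.add_to_collection("test", b)

    # Everything stored under the key, regardless of scope.
    all_items = [x.name for x in tf.get_collection("test")]
    # Only items whose name starts with "block".
    block_items = [x.name for x in tf.get_collection("test", scope="block")]

print(all_items)
print(block_items)
```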
TensorFlow also provides a way to list every collection in a graph, via Graph.get_all_collection_keys():
```python
for key in tf.get_default_graph().get_all_collection_keys():
    print('key:' + key)
    for element in tf.get_collection(key):
        print(element)
```
Output:
```
key:variables
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>
key:trainable_variables
<tf.Variable 'v1:0' shape=(3,) dtype=float32_ref>
<tf.Variable 'v2:0' shape=(5,) dtype=float32_ref>
<tf.Variable 'v3:0' shape=(4,) dtype=float32_ref>
```
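The keys printed above, "variables" and "trainable_variables", are the string values of tf.GraphKeys.GLOBAL_VARIABLES and tf.GraphKeys.TRAINABLE_VARIABLES. Graph.get_all_collection_keys() returns only string keys.

The sketch below checks this. It assumes TensorFlow 1.15 or 2.x through tensorflow.compat.v1, and the graph is illustrative.

```python
# Assumption: TensorFlow 1.15, or TF 2.x via the compat.v1 module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # graph mode + reference variables

# String values behind a few GraphKeys constants.
key_map = {
    "GLOBAL_VARIABLES": tf.GraphKeys.GLOBAL_VARIABLES,
    "TRAINABLE_VARIABLES": tf.GraphKeys.TRAINABLE_VARIABLES,
    "LOCAL_VARIABLES": tf.GraphKeys.LOCAL_VARIABLES,
}
for const, value in key_map.items():
    print(const, "->", value)

graph = tf.Graph()
with graph.as_default():
    tf.get_variable("v1", shape=[3])

# Collection keys that exist in this graph after creating one variable.
keys = graph.get_all_collection_keys()
print(keys)
```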