Tensorflow-变量保存与导入
2017-09-20 11:39
453 查看
1、基本用法
#!/usr/bin/python3 # -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf # import glob import numpy as np logdir='./output/' with tf.variable_scope('conv'): w=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer) b=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer) sess=tf.InteractiveSession() saver=tf.train.Saver([w]) # 参数为空,默认保存所有变量,此处只保存变量w tf.global_variables_initializer().run() # 初始化所有变量 # 验证之前是否已经保存了检查点文件 ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: saver.restore(sess, ckpt.model_checkpoint_path) tf.variables_initializer([b]) # 初始化变量b saver.save(sess,logdir+'model.ckpt') print('w',w.eval()) print('-----------') print('b',b.eval()) sess.close()
2、变量从ckpt中提取,没有的需初始化
第一运行#!/usr/bin/python3 # -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf # import glob import numpy as np from tensorflow.contrib.layers.python.layers import batch_norm import argparse logdir='./output/' with tf.variable_scope('conv'): w1=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer) b1=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer) # with tf.variable_scope('conv2'): # w2=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer) # b2=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer) sess=tf.InteractiveSession() tf.global_variables_initializer().run() # 初始化所有变量 # 验证之前是否已经保存了检查点文件 ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: try: saver = tf.train.Saver() # 参数为空,默认保存所有变量,这里只有变量w1、b1 saver.restore(sess, ckpt.model_checkpoint_path) saver=None except: saver = tf.train.Saver([w1,b1]) # 参数为空,默认保存所有变量,这里只有变量w1、b1 saver.restore(sess, ckpt.model_checkpoint_path) saver = None # tf.variables_initializer([b1]) # 初始化变量b saver=tf.train.Saver() # 参数为空,默认保存所有变量,这里只有变量w1、b1 saver.save(sess,logdir+'model.ckpt') print('w',w1.eval()) print('-----------') print('b',b1.eval()) print('-----------') # print('w',w2.eval()) # print('-----------') # print('b',b2.eval()) sess.close()
第二次运行
#!/usr/bin/python3 # -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf # import glob import numpy as np from tensorflow.contrib.layers.python.layers import batch_norm import argparse logdir='./output/' with tf.variable_scope('conv'): w1=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer) b1=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer) with tf.variable_scope('conv2'): w2=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer) b2=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer) sess=tf.InteractiveSession() tf.global_variables_initializer().run() # 初始化所有变量 # 验证之前是否已经保存了检查点文件 ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: try: saver = tf.train.Saver() # 参数为空,默认提取所有变量, saver.restore(sess, ckpt.model_checkpoint_path) saver = None except: saver = tf.train.Saver([w1, b1]) # 参数为空,默认提取所有变量, # 此处提取变量w1、b1(因为上步保存的变量没有w2,b2,如果使用saver = tf.train.Saver()会报错) saver.restore(sess, ckpt.model_checkpoint_path) saver=None # tf.variables_initializer([b1]) # 初始化变量b saver=tf.train.Saver() # 参数为空,默认保存所有变量,此处只保存所有变量,包括w2,b2 saver.save(sess,logdir+'model.ckpt') print('w',w1.eval()) print('-----------') print('b',b1.eval()) print('-----------') print('w',w2.eval()) print('-----------') print('b',b2.eval()) sess.close()
第3次运行
#!/usr/bin/python3 # -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf # import glob import numpy as np from tensorflow.contrib.layers.python.layers import batch_norm import argparse logdir='./output/' with tf.variable_scope('conv'): w1=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer) b1=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer) with tf.variable_scope('conv2'): w2=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer) b2=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer) sess=tf.InteractiveSession() tf.global_variables_initializer().run() # 初始化所有变量 # 验证之前是否已经保存了检查点文件 ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: # try: # saver = tf.train.Saver() # 参数为空,默认提取所有变量, # saver.restore(sess, ckpt.model_checkpoint_path) # saver = None # except: saver = tf.train.Saver([w1, b1]) # 上一步保存的变量有w1,b1,w2,b2,这里只提取w1,b1 saver.restore(sess, ckpt.model_checkpoint_path) saver=None # tf.variables_initializer([w2,b2]) # 初始化变量w2,b2 saver=tf.train.Saver() # 参数为空,默认保存所有变量,此处只保存所有变量,包括w2,b2 saver.save(sess,logdir+'model.ckpt') print('w',w1.eval()) print('-----------') print('b',b1.eval()) print('-----------') print('w',w2.eval()) print('-----------') print('b',b2.eval()) sess.close()
3、总结
保存变量tf.global_variables_initializer().run() # 初始化所有变量 saver=tf.train.Saver() # 参数为空,默认保存所有变量 saver=tf.train.Saver([w,b]) # 保存部分变量 saver.save(sess,logdir+'model.ckpt')
导入变量
tf.global_variables_initializer().run() # 初始化所有变量 # 验证之前是否已经保存了检查点文件 ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: try: saver = tf.train.Saver() # 参数为空,默认导入所有变量, saver.restore(sess, ckpt.model_checkpoint_path) saver = None except: saver = tf.train.Saver([w1, b1]) # 导入部分变量 saver.restore(sess, ckpt.model_checkpoint_path) saver=None
如果保存的变量有w1,b1,w2,b2,但只导入w1,b1,对w2,b2重新初始化,训练等 使用
saver = tf.train.Saver([w1, b1])
如果保存的变量中只有w1,b1,现在新增变量w2,b2 则只能导入w1,b1
saver = tf.train.Saver([w1, b1])
通过这种方法可以实现,只导入模型的前n-1层参数,而对第n层参数重新初始化训练,这样就能很好的实现迁移学习
4、补充 结合variable_scope
#!/usr/bin/python3 # -*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import tensorflow as tf # import glob import numpy as np from tensorflow.contrib.layers.python.layers import batch_norm import argparse logdir='./output/' with tf.variable_scope('conv'): w=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer) b=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer) # with tf.variable_scope('conv2'): # w=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer) # b=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer) sess=tf.InteractiveSession() tf.global_variables_initializer().run() # 初始化所有变量 # 验证之前是否已经保存了检查点文件 ckpt = tf.train.get_checkpoint_state(logdir) if ckpt and ckpt.model_checkpoint_path: # try: # saver = tf.train.Saver() # 参数为空,默认提取所有变量, # saver.restore(sess, ckpt.model_checkpoint_path) # saver = None # except: saver = tf.train.Saver([tf.variable_op_scope(w,name_or_scope='conv/w:0').args[0], tf.variable_op_scope(b, name_or_scope='conv/b:0').args[0]]) # 上一步保存的变量有w1,b1,w2,b2,这里只提取w1,b1 saver.restore(sess, ckpt.model_checkpoint_path) saver=None # tf.variables_initializer([w2,b2]) # 初始化变量w2,b2 saver=tf.train.Saver() # 参数为空,默认保存所有变量,此处只保存所有变量,包括w2,b2 saver.save(sess,logdir+'model.ckpt') print('w',w.eval()) print('-----------') print('b',b.eval()) print('-----------') # print('w',w2.eval()) # print('-----------') # print('b',b2.eval()) sess.close()
说明:
with tf.variable_scope('conv'): w=tf.get_variable('w',[2,2],tf.float32,initializer=tf.random_normal_initializer) b=tf.get_variable('b',[2],tf.float32,initializer=tf.random_normal_initializer) with tf.variable_scope('conv2'): w=tf.get_variable('w',[1,2],tf.float32,initializer=tf.random_normal_initializer) b=tf.get_variable('b',[1],tf.float32,initializer=tf.random_normal_initializer)
print('w',w.eval()) # 打印的是'conv2' 中的w print('-----------') print('b',b.eval())# 打印的是'conv2' 中的b
如果要打印的是’conv1’ 中的w,b
print('w',tf.variable_op_scope(w,name_or_scope='conv/w:0').args[0].eval()) print('-----------') print('b',tf.variable_op_scope(b, name_or_scope='conv/b:0').args[0].eval())
打印’conv2’ 中的w,b也可以使用
print('w',tf.variable_op_scope(w,name_or_scope='conv2/w:0').args[0].eval()) print('-----------') print('b',tf.variable_op_scope(b, name_or_scope='conv2/b:0').args[0].eval())
相关文章推荐
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- Tensorflow学习: 保存变量和网络
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow入门(三)--变量:创建、初始化、保存和加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow保存部分变量
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- Tensorflow-pb保存与导入
- TensorFlow变量管理、保存和读取(持久化)
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- Tensorflow语法学习笔记(一):变量:创建、初始化、保存和加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow对训练变量checkpoint的保存与读取
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- TensorFLow 入门 - 用Saver保存和恢复变量
- TensorFlow 教程 --进阶指南--3.2变量:创建、初始化、保存和加载
- TensorFlow 深度学习框架(7)-- 变量管理及训练模型的保存与加载
- tensorflow学习笔记(五):变量保存与导入