您的位置:首页 > 其它

Tensorflow-变量保存与导入

2017-09-20 11:39 453 查看

1、基本用法

#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np

logdir = './output/'

with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)

sess = tf.InteractiveSession()
# With an explicit var_list the Saver handles only `w`; with no argument
# it would save/restore every variable in the graph.
saver = tf.train.Saver([w])
tf.global_variables_initializer().run()  # initialize all variables

# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    # BUG FIX: the original built this init op but never executed it, so it
    # had no effect. Run it so `b` is re-initialized after the restore,
    # matching the stated intent ("initialize variable b").
    tf.variables_initializer([b]).run()

saver.save(sess, logdir + 'model.ckpt')
print('w', w.eval())
print('-----------')
print('b', b.eval())
sess.close()


2、变量从ckpt中提取,没有的需初始化

第一次运行

#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'

with tf.variable_scope('conv'):
    w1 = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b1 = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
# with tf.variable_scope('conv2'):
#     w2 = tf.get_variable('w', [1, 2], tf.float32, initializer=tf.random_normal_initializer)
#     b2 = tf.get_variable('b', [1], tf.float32, initializer=tf.random_normal_initializer)

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()  # initialize all variables

# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        # No var_list: restore every variable in the graph (here only w1, b1).
        saver = tf.train.Saver()
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
    except Exception:  # narrowed from a bare `except:` (don't swallow KeyboardInterrupt)
        # Checkpoint does not cover all graph variables: restore just w1, b1.
        saver = tf.train.Saver([w1, b1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
    # tf.variables_initializer([b1]).run()  # re-initialize b1 if desired

saver = tf.train.Saver()  # no var_list: save every variable (here only w1, b1)
saver.save(sess, logdir + 'model.ckpt')
print('w', w1.eval())
print('-----------')
print('b', b1.eval())
print('-----------')
# print('w', w2.eval())
# print('-----------')
# print('b', b2.eval())

sess.close()




第二次运行

#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'

with tf.variable_scope('conv'):
    w1 = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b1 = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    w2 = tf.get_variable('w', [1, 2], tf.float32, initializer=tf.random_normal_initializer)
    b2 = tf.get_variable('b', [1], tf.float32, initializer=tf.random_normal_initializer)

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()  # initialize all variables

# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        # No var_list: try to restore every variable in the graph.
        saver = tf.train.Saver()
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
    except Exception:  # narrowed from a bare `except:` (don't swallow KeyboardInterrupt)
        # The previous checkpoint holds only w1, b1 (no w2, b2), so a
        # full-graph tf.train.Saver() restore raises; fall back to a partial one.
        saver = tf.train.Saver([w1, b1])
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
    # tf.variables_initializer([b1]).run()  # re-initialize b1 if desired

saver = tf.train.Saver()  # no var_list: save every variable, including w2, b2
saver.save(sess, logdir + 'model.ckpt')
print('w', w1.eval())
print('-----------')
print('b', b1.eval())
print('-----------')
print('w', w2.eval())
print('-----------')
print('b', b2.eval())

sess.close()




第三次运行

#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'

with tf.variable_scope('conv'):
    w1 = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b1 = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    w2 = tf.get_variable('w', [1, 2], tf.float32, initializer=tf.random_normal_initializer)
    b2 = tf.get_variable('b', [1], tf.float32, initializer=tf.random_normal_initializer)

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()  # initialize all variables

# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    # try:
    #     saver = tf.train.Saver()  # no var_list: restore every variable
    #     saver.restore(sess, ckpt.model_checkpoint_path)
    #     saver = None
    # except Exception:
    # The previous checkpoint holds w1, b1, w2, b2; restore only w1, b1 here.
    saver = tf.train.Saver([w1, b1])
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver = None
    # tf.variables_initializer([w2, b2]).run()  # re-initialize w2, b2

saver = tf.train.Saver()  # no var_list: save every variable, including w2, b2
saver.save(sess, logdir + 'model.ckpt')
print('w', w1.eval())
print('-----------')
print('b', b1.eval())
print('-----------')
print('w', w2.eval())
print('-----------')
print('b', b2.eval())

sess.close()




3、总结

保存变量

# Saving variables: either save everything (no var_list) or a chosen subset.
tf.global_variables_initializer().run() # initialize all variables
saver=tf.train.Saver() # no var_list: save every variable
saver=tf.train.Saver([w,b]) # save only the listed variables
saver.save(sess,logdir+'model.ckpt')


导入变量

# Restoring variables: try a full restore first, fall back to a partial one.
tf.global_variables_initializer().run()  # initialize all variables

# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    try:
        saver = tf.train.Saver()  # no var_list: restore every variable
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None
    except Exception:  # narrowed from a bare `except:` (don't swallow KeyboardInterrupt)
        saver = tf.train.Saver([w1, b1])  # restore only the listed variables
        saver.restore(sess, ckpt.model_checkpoint_path)
        saver = None


如果保存的变量有w1,b1,w2,b2,但只导入w1,b1,对w2,b2重新初始化,训练等 使用

# Restore only w1, b1 from the checkpoint; w2, b2 are re-initialized/trained.
saver = tf.train.Saver([w1, b1])


如果保存的变量中只有w1,b1,现在新增变量w2,b2 则只能导入w1,b1

# The checkpoint only contains w1, b1, so only those can be restored.
saver = tf.train.Saver([w1, b1])


通过这种方法可以实现,只导入模型的前n-1层参数,而对第n层参数重新初始化训练,这样就能很好的实现迁移学习

4、补充 结合variable_scope

#!/usr/bin/python3
# -*- coding:utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
# import glob
import numpy as np
from tensorflow.contrib.layers.python.layers import batch_norm
import argparse

logdir = './output/'

with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
# with tf.variable_scope('conv2'):
#     w = tf.get_variable('w', [1, 2], tf.float32, initializer=tf.random_normal_initializer)
#     b = tf.get_variable('b', [1], tf.float32, initializer=tf.random_normal_initializer)

sess = tf.InteractiveSession()

tf.global_variables_initializer().run()  # initialize all variables

# Check whether a checkpoint file was saved previously.
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
    # try:
    #     saver = tf.train.Saver()  # no var_list: restore every variable
    #     saver.restore(sess, ckpt.model_checkpoint_path)
    #     saver = None
    # except Exception:
    # NOTE(review): tf.variable_op_scope is a (deprecated) context manager;
    # reading `.args[0]` leans on contextmanager internals to look up the
    # variable by name — confirm this works on the TF version in use.
    saver = tf.train.Saver([tf.variable_op_scope(w, name_or_scope='conv/w:0').args[0],
                            tf.variable_op_scope(b, name_or_scope='conv/b:0').args[0]])
    saver.restore(sess, ckpt.model_checkpoint_path)
    saver = None
    # tf.variables_initializer([w2, b2]).run()  # re-initialize w2, b2

saver = tf.train.Saver()  # no var_list: save every variable
saver.save(sess, logdir + 'model.ckpt')
print('w', w.eval())
print('-----------')
print('b', b.eval())
print('-----------')
# print('w', w2.eval())
# print('-----------')
# print('b', b2.eval())

sess.close()


说明:

with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32, initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32, initializer=tf.random_normal_initializer)
with tf.variable_scope('conv2'):
    # NOTE: these assignments rebind the Python names w and b, so after
    # this block they refer to the 'conv2' variables, not the 'conv' ones.
    w = tf.get_variable('w', [1, 2], tf.float32, initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [1], tf.float32, initializer=tf.random_normal_initializer)


print('w', w.eval())  # prints w from scope 'conv2'
print('-----------')
print('b', b.eval())  # prints b from scope 'conv2'


如果要打印的是 'conv' 中的w,b

# NOTE(review): tf.variable_op_scope is a (deprecated) context manager; this
# reads `.args[0]` from its internals to fetch the 'conv' variables by name —
# verify this trick works on the TF version in use.
print('w',tf.variable_op_scope(w,name_or_scope='conv/w:0').args[0].eval())
print('-----------')
print('b',tf.variable_op_scope(b, name_or_scope='conv/b:0').args[0].eval())


打印’conv2’ 中的w,b也可以使用

# Same internals-based lookup as above, but addressing the 'conv2' variables.
# NOTE(review): relies on tf.variable_op_scope contextmanager internals — verify.
print('w',tf.variable_op_scope(w,name_or_scope='conv2/w:0').args[0].eval())
print('-----------')
print('b',tf.variable_op_scope(b, name_or_scope='conv2/b:0').args[0].eval())
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: