tensorflow tf.train.Supervisor作用
2017-06-28 08:24
543 查看
tf.train.Supervisor可以简化编程,避免显示地实现restore操作.通过一个例子看.
这段代码是对tensorflow官网上的demo做一个微小的改动.如果模型已经存在,就先读取模型接着训练.tf.train.Supervisor可以简化这个步骤.看下面的代码.
sv = tf.train.Supervisor(logdir=log_path, init_op=init)会判断模型是否存在.如果存在,会自动读取模型.不用显式地调用restore.
tensorflow学习笔记(二十二):Supervisor
import tensorflow as tf import numpy as np import os log_path = r"D:\Source\model\linear" log_name = "linear.ckpt" # Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3 x_data = np.random.rand(100).astype(np.float32) y_data = x_data * 0.1 + 0.3 # Try to find values for W and b that compute y_data = W * x_data + b # (We know that W should be 0.1 and b 0.3, but TensorFlow will # figure that out for us.) W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b # Minimize the mean squared errors. loss = tf.reduce_mean(tf.square(y - y_data)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) # Before starting, initialize the variables. We will 'run' this first. saver = tf.train.Saver() init = tf.global_variables_initializer() # Launch the graph. sess = tf.Session() sess.run(init) if len(os.listdir(log_path)) != 0: # 已经有模型直接读取 saver.restore(sess, os.path.join(log_path, log_name)) for step in range(201): sess.run(train) if step % 20 == 0: print(step, sess.run(W), sess.run(b)) saver.save(sess, os.path.join(log_path, log_name))
这段代码是对tensorflow官网上的demo做一个微小的改动.如果模型已经存在,就先读取模型接着训练.tf.train.Supervisor可以简化这个步骤.看下面的代码.
import tensorflow as tf import numpy as np import os log_path = r"D:\Source\model\supervisor" log_name = "linear.ckpt" x_data = np.random.rand(100).astype(np.float32) y_data = x_data * 0.1 + 0.3 W = tf.Variable(tf.random_uniform([1], -1.0, 1.0)) b = tf.Variable(tf.zeros([1])) y = W * x_data + b loss = tf.reduce_mean(tf.square(y - y_data)) optimizer = tf.train.GradientDescentOptimizer(0.5) train = optimizer.minimize(loss) saver = tf.train.Saver() init = tf.global_variables_initializer() sv = tf.train.Supervisor(logdir=log_path, init_op=init) # logdir用来保存checkpoint和summary saver = sv.saver # 创建saver with sv.managed_session() as sess: # 会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化 for i in range(201): sess.run(train) if i % 20 == 0: print(i, sess.run(W), sess.run(b)) saver.save(sess, os.path.join(log_path, log_name))
sv = tf.train.Supervisor(logdir=log_path, init_op=init)会判断模型是否存在.如果存在,会自动读取模型.不用显式地调用restore.
参考资料
tensorflow官方文档tensorflow学习笔记(二十二):Supervisor
相关文章推荐
- tensorflow tf.nn.embedding_lookup(embeddings, train_inputs)解释
- tensorflow tf.train.batch之数据批量读取
- 【tensorflow】打印Tensorflow graph中的所有变量--tf.trainable_variables()
- 解决PyCharm [import tensorflow as tf]报错
- tensorflow API: tf.clip_by_value
- 学习 train.py ( TensorFlow Object Detection API)
- tensorflow tf.tile 实例
- tensorflow.nn.relu的作用
- 学习笔记TF067:TensorFlow Serving、Flod、计算加速,机器学习评测体系,公开数据集
- python3下使用TensorFlow Object Detection打包TFRecord
- tensorflow tf.group
- tensorflow slim【TF-Slim】
- 【Tensorflow】辅助工具篇——tensorflow slim(TF-Slim)介绍
- WARNING:tensorflow:From tf_should_use.py:107 initialize_all_variables(from tensorflow.python.ops.var
- tensorflow API:tf.set_random_seed
- 解决Jupyter notebook[import tensorflow as tf]报错
- 第五课 Tensorflow TFRecord读取数据
- TensorFlow resize_images函数导致TFRecord产生形状不匹配
- tensorflow tf.layers.dense 实例