CIFAR10 代码分析详解——cifar10_train.py
2017-03-24 08:30
771 查看
先在这里种个草,开篇后慢慢补完
引入各种库,并定义参数
下面是训练函数主体
该部分代码比较简单,在主体函数 train() 中先通过 cifar10.distorted_input() 读取图像和标签,然后通过cifar10.inference() 进行 logits 的估计,通过cifar10.loss() 来计算损失,再创建一个 train_op=cifar10.train()
来进行模型训练参数更新,直到满足 stop criterion。调用的函数参见相应的文章。
引入各种库,并定义参数
from __future__ import absolute_import from __future__ import division from __future__ import print_function from datetime import datetime import time import tensorflow as tf import cifar10 FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train', """Directory where to write event logs """ """and checkpoint.""") tf.app.flags.DEFINE_integer('max_steps', 1000000, """Number of batches to run.""") tf.app.flags.DEFINE_boolean('log_device_placement', False, """Whether to log device placement.""") tf.app.flags.DEFINE_integer('log_frequency', 10, """How often to log results to the console.""")
下面是训练函数主体
def train(): """Train CIFAR-10 for a number of steps.""" #定义一个图,关于Graph的用法查链接 with tf.Graph().as_default(): #获取global_step,至于为什么这么用有待考证。 tf.contrib.framework.get_or_create_global_step(Graph) #若无输入图,则为默认图 global_step = tf.contrib.framework.get_or_create_global_step() # Get images and labels for CIFAR-10. images, labels = cifar10.distorted_inputs() # Build a Graph that computes the logits predictions from the # inference model. logits = cifar10.inference(images) # Calculate loss. loss = cifar10.loss(logits, labels) # Build a Graph that trains the model with one batch of examples and # updates the model parameters. train_op = cifar10.train(loss, global_step) #log部分以后再补充???? class _LoggerHook(tf.train.SessionRunHook): """Logs loss and runtime.""" def begin(self): self._step = -1 self._start_time = time.time() def before_run(self, run_context): self._step += 1 return tf.train.SessionRunArgs(loss) # Asks for loss value. def after_run(self, run_context, run_values): if self._step % FLAGS.log_frequency == 0: current_time = time.time() duration = current_time - self._start_time self._start_time = current_time loss_value = run_values.results examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration sec_per_batch = float(duration / FLAGS.log_frequency) format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 'sec/batch)') print (format_str % (datetime.now(), self._step, loss_value, examples_per_sec, sec_per_batch)) #这里要找到stop criterion???? with tf.train.MonitoredTrainingSession( checkpoint_dir=FLAGS.train_dir, hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), tf.train.NanTensorHook(loss), _LoggerHook()], config=tf.ConfigProto( log_device_placement=FLAGS.log_device_placement)) as mon_sess: while not mon_sess.should_stop(): mon_sess.run(train_op) def main(argv=None): # pylint: disable=unused-argument cifar10.maybe_download_and_extract() if tf.gfile.Exists(FLAGS.train_dir): tf.gfile.DeleteRecursively(FLAGS.train_dir) tf.gfile.MakeDirs(FLAGS.train_dir) train() if __name__ == '__main__': tf.app.run()
该部分代码比较简单,在主体函数 train() 中先通过 cifar10.distorted_input() 读取图像和标签,然后通过cifar10.inference() 进行 logits 的估计,通过cifar10.loss() 来计算损失,再创建一个 train_op=cifar10.train()
来进行模型训练参数更新,直到满足 stop criterion。调用的函数参见相应的文章。
相关文章推荐
- CIFAR10 代码分析详解——cifar10_input.py
- CIFAR10 代码分析详解——cifar10.py
- theano-xnor-net代码注释 cifar10_train.py
- 关于DS18B20温度传感器的时序详解及代码分析
- libsvm 学习笔记(四)--- grid.py 关键代码详解
- 深入XPath的详解以及Java示例代码分析
- x264 代码重点详解 详细分析
- 四极管:2410启动代码分析之 vector.s详解一
- NSQ系列之nsqlookupd代码分析四(详解nsqlookupd中的RegitrationDB)
- NSQ系列之nsqlookupd代码分析三(详解tcpServer 中的IOLoop方法)
- web.py 直接使用示例代码,web.application报错, 'module' object has no attribute 'application',问题原因分析
- Windows C++代码heap分析详解
- php代码出现错误分析详解第1/2页
- tableView中cell的删除、插入、移动、复制粘贴问题详解代码分析
- Hadoop RCFile存储格式详解(源码分析、代码示例)
- jQuery选择器代码详解(六)——Sizzle选择器匹配逻辑分析
- wav文件格式分析详解和解析代码
- wav文件格式分析详解和解析代码
- libsvm代码阅读:关于svm_train函数分析
- SLIC超像素分割详解(二):关键代码分析