您的位置:首页 > 编程语言

CIFAR10 代码分析详解——cifar10_train.py

2017-03-24 08:30 771 查看
先在这里占个位,后续再慢慢补充完整。

引入各种库,并定义参数

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import time

import tensorflow as tf

import cifar10

# Global flag registry; values are parsed from the command line by tf.app.run().
FLAGS = tf.app.flags.FLAGS

_flags = tf.app.flags

# Where event logs and checkpoints are written (wiped by main() on startup).
_flags.DEFINE_string(
    'train_dir', '/tmp/cifar10_train',
    """Directory where to write event logs and checkpoint.""")
# Upper bound on the number of training steps (used by StopAtStepHook).
_flags.DEFINE_integer(
    'max_steps', 1000000,
    """Number of batches to run.""")
# Forwarded to tf.ConfigProto(log_device_placement=...).
_flags.DEFINE_boolean(
    'log_device_placement', False,
    """Whether to log device placement.""")
# The logger hook prints one line every `log_frequency` steps.
_flags.DEFINE_integer(
    'log_frequency', 10,
    """How often to log results to the console.""")




下面是训练函数主体

def train():
    """Train CIFAR-10 for a number of steps.

    Builds the input pipeline, the inference model, the loss and the
    training op inside a fresh graph, then repeatedly runs ``train_op`` in a
    ``tf.train.MonitoredTrainingSession`` until one of its hooks asks the
    session to stop:

    * ``StopAtStepHook`` stops once the global step reaches
      ``FLAGS.max_steps``;
    * ``NanTensorHook`` watches ``loss`` for NaN values.

    Checkpoints and summaries go to ``FLAGS.train_dir``.
    """
    # Build everything in a new graph made the default for this block.
    with tf.Graph().as_default():
        # Fetch (or create) the global step of the default graph. It is
        # handed to cifar10.train() (presumably so the optimizer increments
        # it) and is what StopAtStepHook compares against FLAGS.max_steps.
        # (The original also called this with an undefined name `Graph`,
        # which raised NameError; with no argument the default graph is used.)
        global_step = tf.contrib.framework.get_or_create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step)

        class _LoggerHook(tf.train.SessionRunHook):
            """Logs loss and throughput every FLAGS.log_frequency steps."""

            def begin(self):
                # Local run() counter (not read from global_step); the first
                # before_run() call bumps it to 0.
                self._step = -1
                self._start_time = time.time()

            def before_run(self, run_context):
                self._step += 1
                # Ask the session to also fetch `loss` in this run() call.
                return tf.train.SessionRunArgs(loss)

            def after_run(self, run_context, run_values):
                if self._step % FLAGS.log_frequency == 0:
                    current_time = time.time()
                    duration = current_time - self._start_time
                    self._start_time = current_time

                    # The value of `loss` requested in before_run().
                    loss_value = run_values.results
                    # NOTE(review): FLAGS.batch_size is not defined in this
                    # file; presumably registered by the cifar10 module.
                    examples_per_sec = (FLAGS.log_frequency * FLAGS.batch_size
                                        / duration)
                    sec_per_batch = float(duration / FLAGS.log_frequency)

                    format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                                  'sec/batch)')
                    print(format_str % (datetime.now(), self._step, loss_value,
                                        examples_per_sec, sec_per_batch))

        # Stop criteria live in the hooks: StopAtStepHook ends training at
        # FLAGS.max_steps, NanTensorHook monitors `loss` for NaN.
        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=FLAGS.train_dir,
                hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
                       tf.train.NanTensorHook(loss),
                       _LoggerHook()],
                config=tf.ConfigProto(
                    log_device_placement=FLAGS.log_device_placement)) as mon_sess:
            while not mon_sess.should_stop():
                mon_sess.run(train_op)

def main(argv=None):  # pylint: disable=unused-argument
    """Script entry point called by tf.app.run(); `argv` is ignored.

    Calls cifar10.maybe_download_and_extract() (presumably fetching the
    dataset when missing), removes any existing FLAGS.train_dir so training
    starts from scratch, recreates that directory, and runs train().
    """
    cifar10.maybe_download_and_extract()

    train_dir = FLAGS.train_dir
    if tf.gfile.Exists(train_dir):
        # Discard checkpoints/event files from a previous run.
        tf.gfile.DeleteRecursively(train_dir)
    tf.gfile.MakeDirs(train_dir)

    train()

if __name__ == '__main__':
    # tf.app.run() parses the flags defined above and then calls main();
    # passing it explicitly is equivalent to the default lookup of
    # __main__.main when this file is executed as a script.
    tf.app.run(main=main)


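该部分代码比较简单:在主体函数 train() 中,先通过 cifar10.distorted_inputs() 读取图像和标签,然后通过 cifar10.inference() 计算 logits,通过 cifar10.loss() 计算损失,再通过 train_op = cifar10.train(loss, global_step) 创建训练操作来更新模型参数。
训练循环在 MonitoredTrainingSession 中不断执行 train_op,直到满足停止条件(stop criterion):StopAtStepHook 检测到全局步数达到 FLAGS.max_steps,或 NanTensorHook 检测到损失出现 NaN。调用的函数参见相应的文章。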
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  机器学习