用DNN对Iris数据分类的代码--tensorflow--logging/monitoring/earlystopping/visualizing
2017-05-09 16:23
330 查看
本博客是对 用深度神经网络对Iris数据集进行分类的程序–tensorflow
里面的代码进行修改,使其可以记录训练日志,监控训练指标,设置early stopping, 并在TensorBoard中进行可视化.
注意和原程序进行对比,看看增加了哪些code
在命令行输入
可以看到loss/accuracy/recall/precision/dnn/global_step等指标的可视化结果
里面的代码进行修改,使其可以记录训练日志,监控训练指标,设置early stopping, 并在TensorBoard中进行可视化.
注意和原程序进行对比,看看增加了哪些code
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import numpy as np import tensorflow as tf tf.logging.set_verbosity(tf.logging.INFO) # Data sets IRIS_TRAINING = os.path.join(os.path.dirname(__file__), "iris_training.csv") IRIS_TEST = os.path.join(os.path.dirname(__file__), "iris_test.csv") def main(unused_argv): # Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32) validation_metrics = { "accuracy": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_accuracy, prediction_key=tf.contrib.learn.PredictionKey. CLASSES), "precision": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_precision, prediction_key=tf.contrib.learn.PredictionKey. CLASSES), "recall": tf.contrib.learn.MetricSpec( metric_fn=tf.contrib.metrics.streaming_recall, prediction_key=tf.contrib.learn.PredictionKey. CLASSES) } validation_monitor = tf.contrib.learn.monitors.ValidationMonitor( test_set.data, test_set.target, every_n_steps=50, metrics=validation_metrics, early_stopping_metric="loss", early_stopping_metric_minimize=True, early_stopping_rounds=200) # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model", config=tf.contrib.learn.RunConfig(save_checkpoints_secs=1)) # Fit model. classifier.fit(x=training_set.data, y=training_set.target, steps=2000, monitors=[validation_monitor]) # Evaluate accuracy. accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"] print('Accuracy: {0:f}'.format(accuracy_score)) # Classify two new flower samples. new_samples = np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float) y = list(classifier.predict(new_samples, as_iterable=True)) print('Predictions: {}'.format(str(y))) if __name__ == "__main__": tf.app.run()
在命令行输入
$ tensorboard --logdir=/tmp/iris_model/
可以看到loss/accuracy/recall/precision/dnn/global_step等指标的可视化结果
相关文章推荐
- 9行代码,不用递归实现无限分类数据的树形格式化
- 5行代码足矣,不用递归实现无限分类数据的树形格式化
- 数据挖掘-oneR算法-Iris数据集分析-使用oneR算法进行分类预测(五)
- 10 张图详解 TensorFlow 数据读取机制(附代码)
- TensorFlow学习笔记10----Logging and Monitoring Basics with tf.contrib.learn
- 无限级分类数据结构和读取分类的php代码
- python机器学习——数据的分类(knn,决策树,贝叶斯)代码笔记
- 用深度神经网络对Iris数据集进行分类的程序--tensorflow
- 数据挖掘-K-近邻分类器-Iris数据集分析-使用K-近邻分类器进行分类预测(四)
- TensorFlow里面mnist导入手写数据代码分析
- PHP不用递归实现无限分类数据的树形格式化 5行9行代码修改
- php 无限分类 树形数据格式化代码
- Tensorflow二分类处理dense或者sparse(文本分类)的输入数据
- tensorflow学习笔记十五:tensorflow官方文档学习 Logging and Monitoring Basics with tf.contrib.learn
- 数据挖掘-K-近邻分类器-Iris数据集分析-根据花萼长宽分类-以散点图显示(一)
- 5行代码 不用递归实现无限分类数据的树形格式化
- Tensorflow 实现稠密输入数据的逻辑回归二分类
- 用Hive+Hadoop集群实现《飞机票购买人群分类案例》思路+代码 (实验数据待整理)
- [TensorFlow实战练习]2-对推特数据的情绪分析分类
- 十图详解TensorFlow数据读取机制(附代码)