TensorFlow-4: tf.contrib.learn 快速入门
2017-04-26 10:39
501 查看
学习资料:
https://www.tensorflow.org/get_started/tflearn
相应的中文翻译:
http://studyai.site/2017/03/05/%E3%80%90Tensorflow%20r1.0%20%E6%96%87%E6%A1%A3%E7%BF%BB%E8%AF%91%E3%80%91%E3%80%90tf.contrib.learn%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8%E3%80%91/
今天学习用 tf.contrib.learn 来建立 DNN 对 Iris 数据集进行分类.
问题:
我们有 Iris 数据集,它包含150个样本数据,分别来自三个品种,每个品种有50个样本,每个样本具有四个特征,以及它属于哪一类,分别由 0,1,2 代表三个品种。
我们将这150个样本分为两份,一份是训练集具有120个样本,另一份是测试集具有30个样本。
我们要做的就是建立一个神经网络分类模型对每个样本进行分类,识别它是哪个品种。
一共有 5 步:
导入 CSV 格式的数据集
建立神经网络分类模型
用训练数据集训练模型
评价模型的准确率
对新样本数据进行分类
代码:
从代码可以看出很简短的几行就可以完成之前学过的很长的代码所做的事情,用起来和用 sklearn 相似。
关于
https://www.tensorflow.org/api_guides/python/contrib.learn
可以看到里面也有
在上面的代码中:
用
分类器模型只需要一行代码,就可以设置这个模型具有多少隐藏层,每个隐藏层有多少神经元,以及最后分为几类。
模型的训练也是只需要一行代码,输入指定的数据,包括特征和标签,再指定迭代的次数,就可以进行训练。
获得准确率也同样很简单,只需要输入测试集,调用 evaluate。
预测新的数据集,只需要把新的样本数据传递给 predict。
关于代码里几个新的方法:
1.
用于导入 CSV,需要三个必需的参数:
filename,CSV文件的路径
target_dtype,数据集的目标值的numpy数据类型。
features_dtype,数据集的特征值的numpy数据类型。
在这里,target 是花的品种,它是一个从 0-2 的整数,所以对应的numpy数据类型是np.int
2.
所有的特征数据都是连续的,因此用 tf.contrib.layers.real_valued_column,数据集中有四个特征(萼片宽度,萼片高度,花瓣宽度和花瓣高度),因此 dimension=4 。
3.
在后面会学到关于 TensorFlow 的 logging and monitoring 的章节,可以 track 一下训练中的模型: “Logging and Monitoring Basics with tf.contrib.learn”。
推荐阅读
历史技术博文链接汇总
也许可以找到你想要的
https://www.tensorflow.org/get_started/tflearn
相应的中文翻译:
http://studyai.site/2017/03/05/%E3%80%90Tensorflow%20r1.0%20%E6%96%87%E6%A1%A3%E7%BF%BB%E8%AF%91%E3%80%91%E3%80%90tf.contrib.learn%E5%BF%AB%E9%80%9F%E5%85%A5%E9%97%A8%E3%80%91/
今天学习用 tf.contrib.learn 来建立 DNN 对 Iris 数据集进行分类.
问题:
我们有 Iris 数据集,它包含150个样本数据,分别来自三个品种,每个品种有50个样本,每个样本具有四个特征,以及它属于哪一类,分别由 0,1,2 代表三个品种。
我们将这150个样本分为两份,一份是训练集具有120个样本,另一份是测试集具有30个样本。
我们要做的就是建立一个神经网络分类模型对每个样本进行分类,识别它是哪个品种。
一共有 5 步:
导入 CSV 格式的数据集
建立神经网络分类模型
用训练数据集训练模型
评价模型的准确率
对新样本数据进行分类
代码:
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import urllib import numpy as np import tensorflow as tf # Data sets IRIS_TRAINING = "iris_training.csv" IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv" IRIS_TEST = "iris_test.csv" IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv" def main(): # If the training and test sets aren't stored locally, download them. if not os.path.exists(IRIS_TRAINING): raw = urllib.urlopen(IRIS_TRAINING_URL).read() with open(IRIS_TRAINING, "w") as f: f.write(raw) if not os.path.exists(IRIS_TEST): raw = urllib.urlopen(IRIS_TEST_URL).read() with open(IRIS_TEST, "w") as f: f.write(raw) # Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32) # Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3, model_dir="/tmp/iris_model") # Define the training inputs def get_train_inputs(): x = tf.constant(training_set.data) y = tf.constant(training_set.target) return x, y # Fit model. classifier.fit(input_fn=get_train_inputs, steps=2000) # Define the test inputs def get_test_inputs(): x = tf.constant(test_set.data) y = tf.constant(test_set.target) return x, y # Evaluate accuracy. accuracy_score = classifier.evaluate(input_fn=get_test_inputs, steps=1)["accuracy"] print("\nTest Accuracy: {0:f}\n".format(accuracy_score)) # Classify two new flower samples. def new_samples(): return np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32) predictions = list(classifier.predict(input_fn=new_samples)) print( "New Samples, Class Predictions: {}\n" .format(predictions)) if __name__ == "__main__": main()
从代码可以看出很简短的几行就可以完成之前学过的很长的代码所做的事情,用起来和用 sklearn 相似。
关于
tf.contrib.learn可以查看:
https://www.tensorflow.org/api_guides/python/contrib.learn
可以看到里面也有
kmeans,logistic,linear等模型:
在上面的代码中:
用
tf.contrib.learn.datasets.base.load_csv_with_header可以导入 CSV 数据集。
分类器模型只需要一行代码,就可以设置这个模型具有多少隐藏层,每个隐藏层有多少神经元,以及最后分为几类。
模型的训练也是只需要一行代码,输入指定的数据,包括特征和标签,再指定迭代的次数,就可以进行训练。
获得准确率也同样很简单,只需要输入测试集,调用 evaluate。
预测新的数据集,只需要把新的样本数据传递给 predict。
关于代码里几个新的方法:
1.
load_csv_with_header():
用于导入 CSV,需要三个必需的参数:
filename,CSV文件的路径
target_dtype,数据集的目标值的numpy数据类型。
features_dtype,数据集的特征值的numpy数据类型。
在这里,target 是花的品种,它是一个从 0-2 的整数,所以对应的numpy数据类型是np.int
2.
tf.contrib.layers.real_valued_column:
所有的特征数据都是连续的,因此用 tf.contrib.layers.real_valued_column,数据集中有四个特征(萼片宽度,萼片高度,花瓣宽度和花瓣高度),因此 dimension=4 。
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
3.
DNNClassifier:
feature_columns=feature_columns, 上面定义的一组特征
hidden_units=[10, 20, 10],三个隐藏层分别包含10,20,10个神经元。
n_classes=3,三个目标类,代表三个 Iris 品种。
model_dir=/tmp/iris_model,TensorFlow在模型训练期间将保存 checkpoint data。
在后面会学到关于 TensorFlow 的 logging and monitoring 的章节,可以 track 一下训练中的模型: “Logging and Monitoring Basics with tf.contrib.learn”。
推荐阅读
历史技术博文链接汇总
也许可以找到你想要的
相关文章推荐
- tf.contrib.learn快速入门
- 深度学习笔记——深度学习框架TensorFlow(十)[Creating Estimators in tf.contrib.learn]
- Tensorflow 利用tf.contrib.learn建立输入函数的方法
- TensorFlow学习笔记6----tf.contrib.learn Quickstart
- TensorFlow学习笔记12----Creating Estimators in tf.contrib.learn
- 05:Tensorflow高级API的进阶--利用tf.contrib.learn建立输入函数
- TensorFlow学习笔记10----Logging and Monitoring Basics with tf.contrib.learn
- TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用
- tensorflow学习笔记十五:tensorflow官方文档学习 Logging and Monitoring Basics with tf.contrib.learn
- 深度学习笔记——深度学习框架TensorFlow(八)[Logging and Monitoring Basics with tf.contrib.learn]
- Tensorflow高级API的进阶--利用tf.contrib.learn建立输入函数
- tensorflow学习笔记十四:TF官方教程学习 tf.contrib.learn Quickstart
- TensorFlow高级API(tf.contrib.learn)及可视化工具TensorBoard的使用
- 深度学习笔记——深度学习框架TensorFlow(四)[高级API tf.contrib.learn]
- TensorFlow入门深度学习–09.tf.contrib.slim用法详解
- [TensorFlow实战练习]3-高层API-tf.contrib.learn练习
- TensorFlow高层次机器学习API (tf.contrib.learn)
- 深度学习笔记——深度学习框架TensorFlow(九)[Building Input Functions with tf.contrib.learn]
- 学习笔记:Creating Estimators in tf.contrib.learn
- tensorflow入门4 windows环境下安装tf