您的位置:首页 > 其它

【NLP】TensorFlow实现CNN用于中文文本分类

2018-02-04 17:00 776 查看
代码基于 dennybritz/cnn-text-classification-tfclayandgithub/zh_cnn_text_classify

参考文章 了解用于NLP的卷积神经网络(译)TensorFlow实现CNN用于文本分类(译)

本文完整代码 - Widiot/cnn-zh-text-classification

项目结构

以下是完整的目录结构示例,包括运行之后形成的目录和文件

cnn-zh-text-classification/
data/
maildata/
cleaned_ham_5000.utf8
cleaned_spam_5000.utf8
ham_5000.utf8
spam_5000.utf8
runs/
1517572900/
checkpoints/
...
summaries/
...
prediction.csv
vocab
.gitignore
README.md
data_helpers.py
eval.py
text_cnn.py
train.py


各个目录及文件的作用如下

data 目录用于存放数据

maildata 目录用于存放邮件文件,目前有四个文件,ham_5000.utf8 及 spam_5000.utf8 分别为正常邮件和垃圾邮件,带 cleaned 前缀的文件为清洗后的数据

runs 目录用于存放每次运行产生的数据,以时间戳为目录名

1517572900 目录用于存放每次运行产生的检查点、日志摘要、词汇文件及评估产生的结果

data_helpers.py 用于处理数据

eval.py 用于评估模型

text_cnn.py 是 CNN 模型类

train.py 用于训练模型

数据

数据格式

以分类正常邮件和垃圾邮件为例,如下是邮件数据的例子

# 正常邮件
他们自己也是刚到北京不久 跟在北京读书然后留在这里工作的还不一样 难免会觉得还有好多东西没有安顿下来 然后来了之后还要带着四处旅游甚么什么的 却是花费很大 你要不带着出去玩,还真不行 这次我小表弟来北京玩,花了好多钱 就因为本来预定的几个地方因为某种原因没去 舅妈似乎就很不开心 结果就是钱全白花了 人家也是牢骚一肚子 所以是自己找出来的困难 退一万步说 婆婆来几个月
发文时难免欠点理智 我不怎么灌水,没想到上了十大了,拍的还挺欢,呵呵 写这个贴子,是由于自己太郁闷了,其时,我最主要的目的,是觉得,水木上肯定有一些嫁农村GG但现在很幸福的JJMM.我目前遇到的问题,我的确不知道怎么解决,所以发上来,问一下成功解决这类问题的建议.因为没有相同的经历和体会,是不会理解的,我在我身边就找不到可行的建议. 结果,无心得罪了不少人.呵呵,可能我想了太多关于城乡差别的问题,意识的比较深刻,所以不经意写了出来.
所以那些贵族1就要找一些特定的东西来章显自己的与众不同 这个东西一定是穷人买不起的,所以好多奢侈品也就营运诞生了 想想也是,他们要表也没有啊, 我要是香paris hilton那么有钱,就每天一个牌子的表,一个牌子的时装,一个牌子的汽车,哈哈,。。。要得就是这个派 俺连表都不用, 带手上都累赘, 上课又不能开手机, 所以俺只好经常退一下ppt去看右下脚的时间. 其实 贵族又不用赶时间, 要知道精确时间做啥? 表走的

# 垃圾邮件
中信(国际)电子科技有限公司推出新产品: 升职步步高、做生意发大财、连找情人都用的上,详情进入 网  址:  http://www.usa5588.com/ccc 电话:020-33770208   服务热线:013650852999
以下不能正确显示请点此 IFRAME: http://www.ewzw.com/bbs/viewthread.php?tid=3809&fpage=1 尊敬的公司您好!打扰之处请见谅! 我深圳公司愿在互惠互利、诚信为本代开3厘---2点国税、地税等发票。增值税和海关缴款书就以2点---7点来代开。手机:13510631209       联系人:邝先生  邮箱:ao998@163.com     祥细资料合作告知,希望合作。谢谢!!


每个句子单独一行,正常邮件和垃圾邮件的数据分别存放在两个文件中

数据处理

数据处理 data_helpers.py 的代码如下,与所参考的代码不同的是

load_data_and_labels():将函数的参数修改为以逗号分隔的数据文件的路径字符串,比如
'./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8'
,这样可以读取多个类别的数据文件以实现多分类问题

read_and_clean_zh_file():将函数的 output_cleaned_file 修改为 boolean 类型,控制是否保存清洗后的数据,并在函数中判断,如果已经存在清洗后的数据文件则直接加载,否则进行清洗并选择保存

其他函数与所参考的代码相比变动不大

import numpy as np
import re
import os

def load_data_and_labels(data_files):
"""
1. 加载所有数据和标签
2. 可以进行多分类,每个类别的数据单独放在一个文件中
2. 保存处理后的数据
"""
data_files = data_files.split(',')
num_data_file = len(data_files)
assert num_data_file > 1
x_text = []
y = []
for i, data_file in enumerate(data_files):
# 将数据放在一起
data = read_and_clean_zh_file(data_file, True)
x_text += data
# 形成数据对应的标签
label = [0] * num_data_file
label[i] = 1
labels = [label for _ in data]
y += labels
return [x_text, np.array(y)]

def read_and_clean_zh_file(input_file, output_cleaned_file=False):
"""
1. 读取中文文件并清洗句子
2. 可以将清洗后的结果保存到文件
3. 如果已经存在经过清洗的数据文件则直接加载
"""
data_file_path, file_name = os.path.split(input_file)
output_file = os.path.join(data_file_path, 'cleaned_' + file_name)
if os.path.exists(output_file):
lines = list(open(output_file, 'r').readlines())
lines = [line.strip() for line in lines]
else:
lines = list(open(input_file, 'r').readlines())
lines = [clean_str(seperate_line(line)) for line in lines]
if output_cleaned_file:
with open(output_file, 'w') as f:
for line in lines:
f.write(line + '\n')
return lines

def clean_str(string):
"""
1. 将除汉字外的字符转为一个空格
2. 将连续的多个空格转为一个空格
3. 除去句子前后的空格字符
"""
string = re.sub(r'[^\u4e00-\u9fff]', ' ', string)
string = re.sub(r'\s{2,}', ' ', string)
return string.strip()

def seperate_line(line):
"""
将句子中的每个字用空格分隔开
"""
return ''.join([word + ' ' for word in line])

def batch_iter(data, batch_size, num_epochs, shuffle=True):
'''
生成一个batch迭代器
'''
data = np.array(data)
data_size = len(data)
num_batches_per_epoch = int((data_size - 1) / batch_size) + 1
for epoch in range(num_epochs):
if shuffle:
shuffle_indices = np.random.permutation(np.arange(data_size))
shuffled_data = data[shuffle_indices]
else:
shuffled_data = data
for batch_num in range(num_batches_per_epoch):
start_idx = batch_num * batch_size
end_idx = min((batch_num + 1) * batch_size, data_size)
yield shuffled_data[start_idx:end_idx]

if __name__ == '__main__':
data_files = './data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8'
x_text, y = load_data_and_labels(data_files)
print(x_text)


清洗标准

将原始数据进行清洗,仅保留汉字,并把每个汉字用一个空格分隔开,各个类别清洗后的数据分别存放在 cleaned 前缀的文件中,清洗后的数据格式如下

本 公 司 有 部 分 普 通 发 票 商 品 销 售 发 票 增 值 税 发 票 及 海 关 代 征 增 值 税 专 用 缴 款 书 及 其 它 服 务 行 业 发 票 公 路 内 河 运 输 发 票 可 以 以 低 税 率 为 贵 公 司 代 开 本 公 司 具 有 内 外 贸 生 意 实 力 保 证 我 司 开 具 的 票 据 的 真 实 性 希 望 可 以 合 作 共 同 发 展 敬 侯 您 的 来 电 洽 谈 咨 询 联 系 人 李 先 生 联 系 电 话 如 有 打 扰 望 谅 解 祝 商 琪


模型

CNN 模型类 text_cnn.py 的代码如下,修改的地方如下

将 concat 和 reshape 的操作结点放在 concat 命名空间下,这样在 TensorBoard 中的节点图更加清晰合理

将计算损失值的操作修改为通过 collection 进行,并只计算 W 的 L2 损失值,删去了计算 b 的 L2 损失值的代码

import tensorflow as tf
import numpy as np

class TextCNN(object):
"""
字符级CNN文本分类
词嵌入层->卷积层->池化层->softmax层
"""

def __init__(self,
sequence_length,
num_classes,
vocab_size,
embedding_size,
filter_sizes,
num_filters,
l2_reg_lambda=0.0):

# 输入,输出,dropout的占位符
self.input_x = tf.placeholder(
tf.int32, [None, sequence_length], name='input_x')
self.input_y = tf.placeholder(
tf.float32, [None, num_classes], name='input_y')
self.dropout_keep_prob = tf.placeholder(
tf.float32, name='dropout_keep_prob')

# l2正则化损失值(可选)
#l2_loss = tf.constant(0.0)

# 词嵌入层
# W为词汇表,大小为0~词汇总数,索引对应不同的字,每个字映射为128维的数组,比如[3800,128]
with tf.name_scope('embedding'):
self.W = tf.Variable(
tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
name='W')
self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
self.embedded_chars_expanded = tf.expand_dims(
self.embedded_chars, -1)

# 卷积层和池化层
# 为3,4,5分别创建128个过滤器,总共3×128个过滤器
# 过滤器形状为[3,128,1,128],表示一次能过滤三个字,最后形成188×128的特征向量
# 池化核形状为[1,188,1,1],128维中的每一维表示该句子的不同向量表示,池化即从每一维中提取最大值表示该维的特征
# 池化得到的特征向量为128维
pooled_outputs = []
for i, filter_size in enumerate(filter_sizes):
with tf.name_scope('conv-maxpool-%s' % filter_size):
# 卷积层
filter_shape = [filter_size, embedding_size, 1, num_filters]
W = tf.Variable(
tf.truncated_normal(filter_shape, stddev=0.1), name='W')
b = tf.Variable(
tf.constant(0.1, shape=[num_filters]), name='b')
conv = tf.nn.conv2d(
self.embedded_chars_expanded,
W,
strides=[1, 1, 1, 1],
padding='VALID',
name='conv')
# ReLU激活
h = tf.nn.relu(tf.nn.bias_add(conv, b), name='relu')
# 池化层
pooled = tf.nn.max_pool(
h,
ksize=[1, sequence_length - filter_size + 1, 1, 1],
strides=[1, 1, 1, 1],
padding='VALID',
name='pool')
pooled_outputs.append(pooled)

# 组合所有池化后的特征
# 将三个过滤器得到的特征向量组合成一个384维的特征向量
num_filters_total = num_filters * len(filter_sizes)
with tf.name_scope('concat'):
self.h_pool = tf.concat(pooled_outputs, 3)
self.h_pool_flat = tf.reshape(self.h_pool, [-1, num_filters_total])

# dropout
with tf.name_scope('dropout'):
self.h_drop = tf.nn.dropout(self.h_pool_flat,
self.dropout_keep_prob)

# 全连接层
# 分数和预测结果
with tf.name_scope('output'):
W = tf.Variable(
tf.truncated_normal(
[num_filters_total, num_classes], stddev=0.1),
name='W')
b = tf.Variable(tf.constant(0.1, shape=[num_classes]), name='b')
if l2_reg_lambda:
W_l2_loss = tf.contrib.layers.l2_regularizer(l2_reg_lambda)(W)
tf.add_to_collection('losses', W_l2_loss)
self.scores = tf.nn.xw_plus_b(self.h_drop, W, b, name='scores')
self.predictions = tf.argmax(self.scores, 1, name='predictions')

# 计算交叉损失熵
with tf.name_scope('loss'):
mse_loss = tf.reduce_mean(
tf.nn.softmax_cross_entropy_with_logits(
logits=self.scores, labels=self.input_y))
tf.add_to_collection('losses', mse_loss)
self.loss = tf.add_n(tf.get_collection('losses'))

# 正确率
with tf.name_scope('accuracy'):
correct_predictions = tf.equal(self.predictions,
tf.argmax(self.input_y, 1))
self.accuracy = tf.reduce_mean(
tf.cast(correct_predictions, 'float'), name='accuracy')


最终的神经网络结构图在 TensorBoard 中的样式如下



训练

训练模型的 train.py 代码如下,修改的地方如下

将数据文件的路径参数修改为一个用逗号分隔开的字符串,便于实现多分类问题

tf.flags 重命名为 flags,更加简洁

import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn

# 参数
# ==================================================

flags = tf.flags

# 数据加载参数
flags.DEFINE_float('dev_sample_percentage', 0.1,
'Percentage of the training data to use for validation')
flags.DEFINE_string(
'data_files',
'./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8',
'Comma-separated data source files')

# 模型超参数
flags.DEFINE_integer('embedding_dim', 128,
'Dimensionality of character embedding (default: 128)')
flags.DEFINE_string('filter_sizes', '3,4,5',
'Comma-separated filter sizes (default: "3,4,5")')
flags.DEFINE_integer('num_filters', 128,
'Number of filters per filter size (default: 128)')
flags.DEFINE_float('dropout_keep_prob', 0.5,
'Dropout keep probability (default: 0.5)')
flags.DEFINE_float('l2_reg_lambda', 0.0,
'L2 regularization lambda (default: 0.0)')

# 训练参数
flags.DEFINE_integer('batch_size', 64, 'Batch Size (default: 64)')
flags.DEFINE_integer('num_epochs', 10,
'Number of training epochs (default: 10)')
flags.DEFINE_integer(
'evaluate_every', 100,
'Evaluate model on dev set after this many steps (default: 100)')
flags.DEFINE_integer('checkpoint_every', 100,
'Save model after this many steps (default: 100)')
flags.DEFINE_integer('num_checkpoints', 5,
'Number of checkpoints to store (default: 5)')

# 其他参数
flags.DEFINE_boolean('allow_soft_placement', True,
'Allow device soft device placement')
flags.DEFINE_boolean('log_device_placement', False,
'Log placement of ops on devices')

FLAGS = flags.FLAGS
FLAGS._parse_flags()
print('\nParameters:')
for attr, value in sorted(FLAGS.__flags.items()):
print('{}={}'.format(attr.upper(), value))
print('')

# 数据准备
# ==================================================

# 加载数据
print('Loading data...')
x_text, y = data_helpers.load_data_and_labels(FLAGS.data_files)

# 建立词汇表
max_document_length = max([len(x.split(' ')) for x in x_text])
vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
x = np.array(list(vocab_processor.fit_transform(x_text)))

# 随机混淆数据
np.random.seed(10)
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
y_shuffled = y[shuffle_indices]

# 划分train/test数据集
# TODO: 这种做法比较暴力,应该用交叉验证
dev_sample_index = -1 * int(FLAGS.dev_sample_percentage * float(len(y)))
x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]

del x, y, x_shuffled, y_shuffled

print('Vocabulary Size: {:d}'.format(len(vocab_processor.vocabulary_)))
print('Train/Dev split: {:d}/{:d}'.format(len(y_train), len(y_dev)))
print('')

# 训练
# ==================================================

with tf.Graph().as_default():
session_conf = tf.ConfigProto(
allow_soft_placement=FLAGS.allow_soft_placement,
log_device_placement=FLAGS.log_device_placement)
sess = tf.Session(config=session_conf)
with sess.as_default():
cnn = TextCNN(
sequence_length=x_train.shape[1],
num_classes=y_train.shape[1],
vocab_size=len(vocab_processor.vocabulary_),
embedding_size=FLAGS.embedding_dim,
filter_sizes=list(map(int, FLAGS.filter_sizes.split(','))),
num_filters=FLAGS.num_filters,
l2_reg_lambda=FLAGS.l2_reg_lambda)

# 定义训练相关操作
global_step = tf.Variable(0, name='global_step', trainable=False)
optimizer = tf.train.AdamOptimizer(1e-3)
grads_and_vars = optimizer.compute_gradients(cnn.loss)
train_op = optimizer.apply_gradients(
grads_and_vars, global_step=global_step)

# 跟踪梯度值和稀疏性(可选)
grad_summaries = []
for g, v in grads_and_vars:
if g is not None:
grad_hist_summary = tf.summary.histogram(
'{}/grad/hist'.format(v.name), g)
sparsity_summary = tf.summary.scalar('{}/grad/sparsity'.format(
v.name), tf.nn.zero_fraction(g))
grad_summaries.append(grad_hist_summary)
grad_summaries.append(sparsity_summary)
grad_summaries_merged = tf.summary.merge(grad_summaries)

# 模型和摘要的保存目录
timestamp = str(int(time.time()))
out_dir = os.path.abspath(
os.path.join(os.path.curdir, 'runs', timestamp))
print('\nWriting to {}\n'.format(out_dir))

# 损失值和正确率的摘要
loss_summary = tf.summary.scalar('loss', cnn.loss)
acc_summary = tf.summary.scalar('accuracy', cnn.accuracy)

# 训练摘要
train_summary_op = tf.summary.merge(
[loss_summary, acc_summary, grad_summaries_merged])
train_summary_dir = os.path.join(out_dir, 'summaries', 'train')
train_summary_writer = tf.summary.FileWriter(train_summary_dir,
sess.graph)

# 开发摘要
dev_summary_op = tf.summary.merge([loss_summary, acc_summary])
dev_summary_dir = os.path.join(out_dir, 'summaries', 'dev')
dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

# 检查点目录,默认存在
checkpoint_dir = os.path.abspath(os.path.join(out_dir, 'checkpoints'))
checkpoint_prefix = os.path.join(checkpoint_dir, 'model')
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
saver = tf.train.Saver(
tf.global_variables(), max_to_keep=FLAGS.num_checkpoints)

# 写入词汇表文件
vocab_processor.save(os.path.join(out_dir, 'vocab'))

# 初始化变量
sess.run(tf.global_variables_initializer())

def train_step(x_batch, y_batch):
"""
一个训练步骤
"""
feed_dict = {
cnn.input_x: x_batch,
cnn.input_y: y_batch,
cnn.dropout_keep_prob: FLAGS.dropout_keep_prob
}
_, step, summaries, loss, accuracy = sess.run([
train_op, global_step, train_summary_op, cnn.loss, cnn.accuracy
], feed_dict)
time_str = datetime.datetime.now().isoformat()
print('{}: step {}, loss {:g}, acc {:g}'.format(
time_str, step, loss, accuracy))
train_summary_writer.add_summary(summaries, step)

def dev_step(x_batch, y_batch, writer=None):
"""
在开发集上验证模型
"""
feed_dict = {
cnn.input_x: x_batch,
cnn.input_y: y_batch,
cnn.dropout_keep_prob: 1.0
}
step, summaries, loss, accuracy = sess.run(
[global_step, dev_summary_op, cnn.loss, cnn.accuracy],
feed_dict)
time_str = datetime.datetime.now().isoformat()
print('{}: step {}, loss {:g}, acc {:g}'.format(
time_str, step, loss, accuracy))
if writer:
writer.add_summary(summaries, step)

# 生成batches
batches = data_helpers.batch_iter(
list(zip(x_train, y_train)), FLAGS.batch_size, FLAGS.num_epochs)
# 迭代训练每个batch
for batch in batches:
x_batch, y_batch = zip(*batch)
train_step(x_batch, y_batch)
current_step = tf.train.global_step(sess, global_step)
if current_step % FLAGS.evaluate_every == 0:
print('\nEvaluation:')
dev_step(x_dev, y_dev, writer=dev_summary_writer)
print('')
if current_step % FLAGS.checkpoint_every == 0:
path = saver.save(
sess, checkpoint_prefix, global_step=current_step)
print('Saved model checkpoint to {}\n'.format(path))


训练过程的输出如下

Parameters:
ALLOW_SOFT_PLACEMENT=True
BATCH_SIZE=64
CHECKPOINT_EVERY=100
DATA_FILES=./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8
DEV_SAMPLE_PERCENTAGE=0.1
DROPOUT_KEEP_PROB=0.5
EMBEDDING_DIM=128
EVALUATE_EVERY=100
FILTER_SIZES=3,4,5
L2_REG_LAMBDA=0.0
LOG_DEVICE_PLACEMENT=False
NUM_CHECKPOINTS=5
NUM_EPOCHS=10
NUM_FILTERS=128

Loading data...
Vocabulary Size: 3628
Train/Dev split: 9001/1000

Writing to /home/widiot/workspace/tensorflow-ws/tensorflow-gpu/text-classification/cnn-zh-text-classification/runs/1517734186

2018-02-04T16:50:03.709761: step 1, loss 5.36006, acc 0.46875
2018-02-04T16:50:03.786874: step 2, loss 4.61227, acc 0.390625
2018-02-04T16:50:03.857796: step 3, loss 2.50795, acc 0.5625
...
2018-02-04T16:50:10.819505: step 98, loss 0.622567, acc 0.90625
2018-02-04T16:50:10.899140: step 99, loss 1.10189, acc 0.875
2018-02-04T16:50:10.983192: step 100, loss 0.359102, acc 0.9375

Evaluation:
2018-02-04T16:50:11.848838: step 100, loss 0.132987, acc 0.961

Saved model checkpoint to /home/widiot/workspace/tensorflow-ws/tensorflow-gpu/text-classification/cnn-zh-text-classification/runs/1517734186/checkpoints/model-100

2018-02-04T16:50:12.019749: step 101, loss 0.512838, acc 0.890625
2018-02-04T16:50:12.100965: step 102, loss 0.164333, acc 0.96875
2018-02-04T16:50:12.184899: step 103, loss 0.145344, acc 0.921875
...


训练之后会在 runs 目录下生成对应的数据目录,包含检查点、日志摘要和词汇文件

训练时的正确率变化如下



评估

评估模型的 eval.py 代码如下,修改的地方如下

同 train.py 将数据文件路径参数修改为逗号分隔开的字符串,便于实现多分类问题

添加对自己未经处理的数据的清洗操作,便于直接分类评估数据

import tensorflow as tf
import numpy as np
import os
import time
import datetime
import data_helpers
import csv
import data_helpers
from text_cnn import TextCNN
from tensorflow.contrib import learn

# 参数
# ==================================================

flags = tf.flags

# 数据参数
flags.DEFINE_string(
'data_files',
'./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8',
'Comma-separated data source files')

# 评估参数
flags.DEFINE_integer('batch_size', 64, 'Batch Size (default: 64)')
flags.DEFINE_string('checkpoint_dir', './runs/1517572900/checkpoints',
'Checkpoint directory from training run')
flags.DEFINE_boolean('eval_train', False, 'Evaluate on all training data')

# 其他参数
flags.DEFINE_boolean('allow_soft_placement', True,
'Allow device soft device placement')
flags.DEFINE_boolean('log_device_placement', False,
'Log placement of ops on devices')

FLAGS = flags.FLAGS
FLAGS._parse_flags()
print('\nParameters:')
for attr, value in sorted(FLAGS.__flags.items()):
print('{}={}'.format(attr.upper(), value))
print('')

# 加载训练数据或者修改测试句子
if FLAGS.eval_train:
x_raw, y_test = data_helpers.load_data_and_labels(FLAGS.data_files)
y_test = np.argmax(y_test, axis=1)
else:
x_raw = [
'亲爱的CFer,您获得了英雄级道具。还有全新英雄级道具在等你来拿,立即登录游戏领取吧!',
'第一个build错误的解决方法能再说一下吗,我还是不懂怎么解决', '请联系张经理获取最新资讯'
]
y_test = [0, 1, 0]

# 对自己的数据的处理
x_raw_cleaned = [
data_helpers.clean_str(data_helpers.seperate_line(line)) for line in x_raw
]
print(x_raw_cleaned)

# 将数据转为词汇表的索引
vocab_path = os.path.join(FLAGS.checkpoint_dir, '..', 'vocab')
vocab_processor = learn.preprocessing.VocabularyProcessor.restore(vocab_path)
x_test = np.array(list(vocab_processor.transform(x_raw_cleaned)))

print('\nEvaluating...\n')

# 评估
# ==================================================

checkpoint_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
graph = tf.Graph()
with graph.as_default():
session_conf = tf.ConfigProto(
allow_soft_placement=FLAGS.allow_soft_placement,
log_device_placement=FLAGS.log_device_placement)
sess = tf.Session(config=session_conf)
with sess.as_default():
# 加载保存的元图和变量
saver = tf.train.import_meta_graph('{}.meta'.format(checkpoint_file))
saver.restore(sess, checkpoint_file)

# 通过名字从图中获取占位符
input_x = graph.get_operation_by_name('input_x').outputs[0]
# input_y = graph.get_operation_by_name('input_y').outputs[0]
dropout_keep_prob = graph.get_operation_by_name(
'dropout_keep_prob').outputs[0]

# 我们想要评估的tensors
predictions = graph.get_operation_by_name(
'output/predictions').outputs[0]

# 生成每个轮次的batches
batches = data_helpers.batch_iter(
list(x_test), FLAGS.batch_size, 1, shuffle=False)

# 收集预测值
all_predictions = []

for x_test_batch in batches:
batch_predictions = sess.run(predictions, {
input_x: x_test_batch,
dropout_keep_prob: 1.0
})
all_predictions = np.concatenate(
[all_predictions, batch_predictions])

# 如果提供了标签则打印正确率
if y_test is not None:
correct_predictions = float(sum(all_predictions == y_test))
print('\nTotal number of test examples: {}'.format(len(y_test)))
print('Accuracy: {:g}'.format(correct_predictions / float(len(y_test))))

# 保存评估为csv
predictions_human_readable = np.column_stack((np.array(x_raw),
all_predictions))
out_path = os.path.join(FLAGS.checkpoint_dir, '..', 'prediction.csv')
print('Saving evaluation to {0}'.format(out_path))
with open(out_path, 'w') as f:
csv.writer(f).writerows(predictions_human_readable)


评估过程中的输出如下

Parameters:
ALLOW_SOFT_PLACEMENT=True
BATCH_SIZE=64
CHECKPOINT_DIR=./runs/1517572900/checkpoints
DATA_FILES=./data/maildata/spam_5000.utf8,./data/maildata/ham_5000.utf8
EVAL_TRAIN=False
LOG_DEVICE_PLACEMENT=False

['亲 爱 的 您 获 得 了 英 雄 级 道 具 还 有 全 新 英 雄 级 道 具 在 等 你 来 拿 立 即 登 录 游 戏 领 取 吧', '第 一 个 错 误 的 解 决 方 法 能 再 说 一 下 吗 我 还 是 不 懂 怎 么 解 决', '请 联 系 张 经 理 获 取 最 新 资 讯']

Evaluating...

Total number of test examples: 3
Accuracy: 1
Saving evaluation to ./runs/1517572900/checkpoints/../prediction.csv


评估之后会在 runs 目录对应的文件夹下生成一个 prediction.csv 文件,如下所示

亲爱的CFer,您获得了英雄级道具。还有全新英雄级道具在等你来拿,立即登录游戏领取吧!,0.0
第一个build错误的解决方法能再说一下吗,我还是不懂怎么解决,1.0
请联系张经理获取最新资讯,0.0
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息