您的位置:首页 > 编程语言 > PHP开发

rnn、lstm,gru中output信息说明

2017-12-08 17:01 483 查看

在一般的rnn模型中,rnn一般输出的形式如下[batch,seq_len,hidden_size],如果用做分类,一般是取最后一个状态[batch,hidden_size],如果用于做词性标注和分词则取全部的状态[batch,seq_len,hidden_size],下面介绍下用于文本分类取最后状态的两种方法,一种是直接transpose,取[-1]最后一个状态,大小变为[batch,hidden_size],另外一个是直接[:,-1,:]把中间的seq_len删除,直接变成[batch,hidden_size],直接看代码:

In [2]: import tensorflow  as tf


In [3]: output=tf.get_variable(name='out',shape=[10,30,128])


In [4]: output.get_shape()

Out[4]: TensorShape([Dimension(10), Dimension(30), Dimension(128)])



In [7]: output_trans = tf.transpose(output, [1, 0, 2])


In [8]: output_trans.get_shape()

Out[8]: TensorShape([Dimension(30), Dimension(10), Dimension(128)])


In [9]: output_trans[-1].get_shape()

Out[9]: TensorShape([Dimension(10), Dimension(128)])


In [10]: output[:,-1,:].get_shape()

Out[10]: TensorShape([Dimension(10), Dimension(128)])
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: