tf.contrib.rnn.BasicLSTMCell and tf.contrib.rnn.MultiRNNCell in depth
2017-04-23 21:21
tf.contrib.rnn.BasicRNNCell
First, let's look at the source of BasicRNNCell:

```python
class BasicRNNCell(RNNCell):
  """The most basic RNN cell."""

  def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
    with _checked_scope(self, scope or "basic_rnn_cell", reuse=self._reuse):
      output = self._activation(
          _linear([inputs, state], self._num_units, True))
    return output, output
```
BasicRNNCell is the most basic RNN cell.

Constructor arguments:
- num_units: number of units (neurons) in the RNN layer
- input_size: deprecated and unused
- activation: activation function applied to the inner state
- reuse: Python boolean describing whether to reuse variables in an existing scope

From the source we can see that a cell instance created with BasicRNNCell exposes two properties, state_size and output_size, and both return num_units. `__call__` makes the instance a callable object: given an input and a state, it applies the update rule output = new_state = act(W * input + U * state + B) and returns the result. Note that the output and the new state are the same tensor.
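The update rule can be sketched in plain NumPy, without TensorFlow. The weight names `W`, `U`, `b` below are hypothetical, not the variables TensorFlow actually creates (internally `_linear` fuses them into one matrix):

```python
import numpy as np

def basic_rnn_step(x, h, W, U, b):
    """One BasicRNNCell step: output = new_state = tanh(W @ x + U @ h + b)."""
    new_h = np.tanh(W @ x + U @ h + b)
    return new_h, new_h  # output and new state are the same array

num_units, input_dim = 4, 3
rng = np.random.default_rng(0)
W = rng.normal(size=(num_units, input_dim))   # input-to-hidden weights
U = rng.normal(size=(num_units, num_units))   # hidden-to-hidden weights
b = np.zeros(num_units)

x = rng.normal(size=input_dim)
h = np.zeros(num_units)                       # initial state
output, new_h = basic_rnn_step(x, h, W, U, b)
print(output.shape)  # (4,)
```

This also makes the property values concrete: both state_size and output_size are num_units because the hidden state is the output.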
tf.contrib.rnn.BasicLSTMCell
The source is as follows:

```python
class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full LSTMCell that follows.
  """

  def __init__(self, num_units, forget_bias=1.0, input_size=None,
               state_is_tuple=True, activation=tanh, reuse=None):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  The latter behavior will soon be deprecated.
      activation: Activation function of the inner states.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already
        has the given variables, an error is raised.
    """
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse):
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
      concat = _linear([inputs, h], 4 * self._num_units, True)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      new_c = (c * sigmoid(f + self._forget_bias) +
               sigmoid(i) * self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
```
The source of LSTMStateTuple:

```python
class LSTMStateTuple(_LSTMStateTuple):
  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.

  Stores two elements: `(c, h)`, in that order.

  Only used when `state_is_tuple=True`.
  """
  __slots__ = ()

  @property
  def dtype(self):
    (c, h) = self
    if not c.dtype == h.dtype:
      raise TypeError("Inconsistent internal state: %s vs %s" %
                      (str(c.dtype), str(h.dtype)))
    return c.dtype
```
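Its behavior is easy to see in isolation: LSTMStateTuple is essentially a `namedtuple("LSTMStateTuple", ("c", "h"))` plus a dtype-consistency check. A minimal stand-alone sketch (not the TensorFlow class, just a replica of its logic):

```python
import collections
import numpy as np

class LSTMStateTuple(collections.namedtuple("LSTMStateTuple", ("c", "h"))):
    """Minimal stand-in: stores (c, h) and checks both share one dtype."""
    __slots__ = ()

    @property
    def dtype(self):
        c, h = self
        if c.dtype != h.dtype:
            raise TypeError("Inconsistent internal state: %s vs %s"
                            % (c.dtype, h.dtype))
        return c.dtype

state = LSTMStateTuple(np.zeros((2, 4), np.float32),
                       np.zeros((2, 4), np.float32))
print(state.dtype)  # float32
c, h = state        # unpacks like a plain tuple, in (c, h) order
```

Because it is a namedtuple, code that expects a plain `(c, h)` pair keeps working, while attribute access (`state.c`, `state.h`) stays readable.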
The BasicLSTMCell class is the most basic LSTM recurrent network cell.

Its constructor arguments are similar to BasicRNNCell's:
- num_units: number of units in the LSTM cell
- forget_bias: bias added to the forget gate
- state_is_tuple: keep this set to True; the state is then handled as the 2-tuple (c_state, m_state)
- activation: activation function of the inner states
- reuse: Python boolean describing whether to reuse variables in an existing scope

state_size property: if state_is_tuple is True, it returns the two-element state tuple LSTMStateTuple(num_units, num_units); otherwise it returns 2 * num_units.

output_size property: returns num_units, i.e. the number of units in the LSTM cell, as passed to the constructor.

`__call__()` makes the instance a callable object: given an input and a state, it applies the LSTM update equations and returns new_h together with the new state new_state, where new_state = (new_c, new_h). For the underlying theory, see https://arxiv.org/pdf/1409.2329.pdf.
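The gate arithmetic inside `__call__` can be sketched in plain NumPy, mirroring the fused matmul-then-split structure of the source. The names `W` and `b` are hypothetical stand-ins for what `_linear` creates; shapes are batch-major as in the TF code:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def basic_lstm_step(x, c, h, W, b, forget_bias=1.0):
    """One BasicLSTMCell step, following the source above.

    W has shape ((input_dim + num_units), 4 * num_units): one fused
    matmul whose result is split into the i, j, f, o blocks.
    """
    concat = np.concatenate([x, h], axis=1) @ W + b
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = np.split(concat, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)  # output, LSTMStateTuple-like state

batch, input_dim, num_units = 2, 3, 4
rng = np.random.default_rng(0)
W = rng.normal(size=(input_dim + num_units, 4 * num_units))
b = np.zeros(4 * num_units)

x = rng.normal(size=(batch, input_dim))
c = np.zeros((batch, num_units))  # initial cell state
h = np.zeros((batch, num_units))  # initial hidden state
out, (new_c, new_h) = basic_lstm_step(x, c, h, W, b)
print(out.shape)  # (2, 4)
```

The sketch also shows why output_size is num_units while state_size is a pair: the cell emits new_h as its output and carries both new_c and new_h forward as state.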