
TensorFlow 1.0 LSTM Implementation

2017-06-20 17:21
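In the code below, num_steps is the sequence length the LSTM is unrolled over and can be set to whatever length is needed. It is often set to a small value so that gradients only have to be backpropagated through a few steps.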
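To make the role of num_steps concrete, here is a minimal sketch in plain Python of cutting one long sequence into windows of num_steps elements, so that gradients only need to flow back through one window at a time. The helper name `split_into_windows` and the choice to drop an incomplete trailing window are assumptions made for illustration; they are not part of the code in this post.

```python
def split_into_windows(sequence, num_steps):
    """Cut a sequence into consecutive, non-overlapping windows of num_steps items.

    A trailing remainder shorter than num_steps is dropped (an assumption made
    here to keep every window the same length).
    """
    if num_steps <= 0:
        raise ValueError("num_steps must be positive")
    num_windows = len(sequence) // num_steps
    return [sequence[k * num_steps:(k + 1) * num_steps] for k in range(num_windows)]


long_sequence = list(range(10))
windows = split_into_windows(long_sequence, num_steps=3)
print(windows)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
```

A smaller num_steps gives more, shorter windows; each window would then be one training example of shape (num_steps, num_input).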

    import tensorflow as tf

    # Method of a class that defines num_steps, batch_size and num_input
    def LSTM_single(self, name, _X, _istate, _weights, _biases):
        # _X has shape (batch_size, num_steps, num_input);
        # name, _weights and _biases are accepted but not used here
        # Permute batch_size and num_steps: (num_steps, batch_size, num_input)
        _X = tf.transpose(_X, [1, 0, 2])
        # Flatten to (num_steps * batch_size, num_input)
        _X = tf.reshape(_X, [self.num_steps * self.batch_size, self.num_input])
        # Split into a list of num_steps tensors of shape (batch_size, num_input),
        # because static_rnn expects a list of per-step inputs
        _X = tf.split(axis=0, num_or_size_splits=self.num_steps, value=_X)

        # With state_is_tuple=False the state is a single tensor of shape
        # (batch_size, 2 * num_input)
        cell = tf.contrib.rnn.LSTMCell(self.num_input, state_is_tuple=False)
        state = _istate
        outputs_seq = []
        for step in range(self.num_steps):
            # Run one time step; static_rnn returns a one-element list of outputs
            outputs, state = tf.contrib.rnn.static_rnn(cell, [_X[step]], state, dtype=tf.float32)
            # Collect the output of every step
            outputs_seq.append(outputs[0])
            # Share the LSTM weights across time steps
            tf.get_variable_scope().reuse_variables()

        # Restore the input layout: (batch_size, num_steps, num_input)
        outputs_seq = tf.reshape(outputs_seq, [self.num_steps, self.batch_size, self.num_input])
        outputs_seq = tf.transpose(outputs_seq, [1, 0, 2])
        return outputs_seq, state
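
The transpose, reshape and split steps at the top of LSTM_single amount to regrouping a batch-major input into one entry per time step. The sketch below shows the same regrouping on nested Python lists instead of tensors; `to_time_major` is a hypothetical helper introduced for illustration, not a TensorFlow function.

```python
def to_time_major(batch):
    """Regroup batch[b][t] into steps[t][b].

    batch is a list of batch_size sequences, each a list of num_steps feature
    vectors. The result is a list of num_steps entries, each holding the
    batch_size feature vectors for that time step.
    """
    return [list(step) for step in zip(*batch)]


# batch_size = 2, num_steps = 3, num_input = 1
batch = [
    [[1.0], [2.0], [3.0]],     # sequence 0
    [[10.0], [20.0], [30.0]],  # sequence 1
]
steps = to_time_major(batch)
print(len(steps), len(steps[0]))  # 3 2
print(steps[0])                   # [[1.0], [10.0]]

# Applying the same regrouping again restores the batch-major layout,
# which mirrors the final reshape and transpose in LSTM_single.
restored = to_time_major(steps)
print(restored == batch)  # True
```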

Calling the LSTM for training:

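    import numpy as np

    # x, y and istate are placeholders defined elsewhere; this code is assumed
    # to run inside the same class as LSTM_single
    outputs, final_state = self.LSTM_single('lstm_train', x, istate, weights, biases)

    # Despite the names, this is the mean squared error (scaled by 100), used as the loss
    correct_prediction = tf.square(outputs - y)
    accuracy = tf.reduce_mean(correct_prediction) * 100

    learning_rate = 0.00001
    # LSTM_single left the current scope in reuse mode; reuse=False lets Adam create its own variables
    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(accuracy)  # Adam optimizer

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Assumption: start from a zero state of shape (batch_size, 2 * num_input)
        state = np.zeros((self.batch_size, 2 * self.num_input), dtype=np.float32)
        i = 0
        while i < iterations:
            # batch_xs and batch_ys hold the current training batch (loading not shown)
            feed = {x: batch_xs, y: batch_ys, istate: state}
            outputs_val = sess.run(outputs, feed_dict=feed)
            sess.run(optimizer, feed_dict=feed)
            # Carry the final LSTM state into the next iteration
            state = sess.run(final_state, feed_dict=feed)
            i += 1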
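
The training loop feeds the state returned by one run back in as istate for the next. The toy example below, in plain Python with no TensorFlow, illustrates why that matters: processing a sequence window by window while carrying the state gives the same final state as processing it in a single pass. The recurrence h = 0.5 * h + x is a stand-in chosen for illustration and is far simpler than an LSTM cell; `run_window` is a hypothetical name.

```python
def run_window(inputs, state):
    """Apply the toy recurrence h = 0.5 * h + x over one window.

    Returns the per-step outputs and the final state, loosely mirroring the
    (outputs, state) pair returned by LSTM_single.
    """
    outputs = []
    for x in inputs:
        state = 0.5 * state + x
        outputs.append(state)
    return outputs, state


sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
num_steps = 2

# One pass over the full sequence, starting from a zero state
_, full_state = run_window(sequence, 0.0)

# Window by window, feeding each final state into the next window
state = 0.0
for start in range(0, len(sequence), num_steps):
    _, state = run_window(sequence[start:start + num_steps], state)

print(state == full_state)  # True
```

If the state were reset to zero at the start of every window instead, each window would be treated as an independent sequence.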
Tags: tensorflow lstm