RNN代码解读之char-RNN with TensorFlow(sample.py)
2016-12-08 15:24
447 查看
此工程解读链接(建议按顺序阅读):
RNN代码解读之char-RNN with TensorFlow(model.py)
RNN代码解读之char-RNN with TensorFlow(train.py)
RNN代码解读之char-RNN with TensorFlow(util.py)
RNN代码解读之char-RNN with TensorFlow(sample.py)
终于到了最后,在这里我们用到了sample.py以及model.py里面的sample方法。
在采样过程中,要注意batch_size和sequence_length都是1了,我们只需要输入一个,根据这一个字符计算下一个就好了,因此在model中,某些张量的尺寸,比如说prob,就会改变,这一点在下面也有注明。
我在这里有一个问题,希望可以获得大家的指点。代码中sample设置了三种方法,其中用到了一种叫weighted_pick的方法,感觉像是在概率分布函数中随机插值取样,这里不太懂为什么要这么做,取最大不是更好吗?希望大家不吝赐教,非常感谢!
关于char-RNN with TensorFlow的个人解读到这里就结束了,我会陆续更新对于CNN和RNN的论文以及工程代码分析,仅代表个人看法,如果有任何问题欢迎指正。
参考资料:
http://blog.csdn.net/mydear_11000/article/details/52776295
https://github.com/sherjilozair/char-rnn-tensorflow
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal
RNN代码解读之char-RNN with TensorFlow(model.py)
RNN代码解读之char-RNN with TensorFlow(train.py)
RNN代码解读之char-RNN with TensorFlow(util.py)
RNN代码解读之char-RNN with TensorFlow(sample.py)
终于到了最后,在这里我们用到了sample.py以及model.py里面的sample方法。
在采样过程中,要注意batch_size和sequence_length都是1了,我们只需要输入一个,根据这一个字符计算下一个就好了,因此在model中,某些张量的尺寸,比如说prob,就会改变,这一点在下面也有注明。
我在这里有一个问题,希望可以获得大家的指点。代码中sample设置了三种方法,其中用到了一种叫weighted_pick的方法,感觉像是在概率分布函数中随机插值取样,这里不太懂为什么要这么做,取最大不是更好吗?希望大家不吝赐教,非常感谢!
#-*-coding:utf-8-*- from __future__ import print_function import numpy as np import tensorflow as tf import argparse import time import os from six.moves import cPickle from utils import TextLoader from model import Model from six import text_type def main(): #一看到这个是不是特别的熟悉?没错和train.py里面的一个意思 parser = argparse.ArgumentParser() #储存checkpoint,不太懂为什么sample的时候还有这个选项 parser.add_argument('--save_dir', type=str, default='save', help='model directory to store checkpointed models') #生成的字符个数 parser.add_argument('-n', type=int, default=500, help='number of characters to sample') #指定一个开头,如果有开头标志的话这里可以是其他的,默认设置时" " parser.add_argument('--prime', type=text_type, default=u' ', help='prime text') parser.add_argument('--sample', type=int, default=1, help='0 to use max at each timestep, 1 to sample at each timestep, 2 to sample on spaces') args = parser.parse_args() sample(args) def sample(args): #载入各种参数 with open(os.path.join(args.save_dir, 'config.pkl'), 'rb') as f: saved_args = cPickle.load(f) with open(os.path.join(args.save_dir, 'chars_vocab.pkl'), 'rb') as f: chars, vocab = cPickle.load(f) #使用模型 model = Model(saved_args, True) #let's roll with tf.Session() as sess: #初始化所有的变量 tf.initialize_all_variables().run() #创建一个saver,后面模型重载 saver = tf.train.Saver(tf.all_variables()) #载入checkpoint ckpt = tf.train.get_checkpoint_state(args.save_dir) if ckpt and ckpt.model_checkpoint_path: #官方说明:Restores previously saved variables saver.restore(sess, ckpt.model_checkpoint_path) #来我们再回到model.py看一下这sample方法 print(model.sample(sess, chars, vocab, args.n, args.prime, args.sample)) if __name__ == '__main__': main()
def sample(self, sess, chars, vocab, num=200, prime='The ', sampling_type=1): #let's go state = sess.run(self.cell.zero_state(1, tf.float32)) #先把开头自己预设的prime_txt送进模型,不计输出 #这一块程序段还是很好理解的 for char in prime[:-1]: x = np.zeros((1, 1)) #前面说过,vocab是个字典 x[0, 0] = vocab[char] feed = {self.input_data: x, self.initial_state:state} [state] = sess.run([self.final_state], feed) #weight = [0.1,0.2,0.3,0.4] #(分布函数)t = [0.1,0.3,0.6,1] #s = 1 #为什么这样pick还不是太懂 def weighted_pick(weights): t = np.cumsum(weights) s = np.sum(weights) return(int(np.searchsorted(t, np.random.rand(1)*s))) ret = prime char = prime[-1] for n in range(num): x = np.zeros((1, 1)) x[0, 0] = vocab[char] feed = {self.input_data: x, self.initial_state:state} [probs, state] abd5 = sess.run([self.probs, self.final_state], feed) #注意!!这里的probs是长度是1*65的,前面在训练的时候因为batch_size和seq_length都是50 # 所以是2500*65之后用了这2500组预测结果来求loss,再BPTT, # 这里只是根据一个输入求一个输出,batch_size和seq_length都是1,因此是1*65 # 所以p就是代表了由长度为65的一个数组,每一位代表着预测为该位的概率值 p = probs[0] if sampling_type == 0: #第一种方法,直接取最大的prob的索引值 sample = np.argmax(p) elif sampling_type == 2: #第二种方法,如果输入是空格,则wighted_pick #否则取最大prob的索引 if char == ' ': sample = weighted_pick(p) else: sample = np.argmax(p) else: #一直使用weighted_pick方法 # sampling_type == 1 default: sample = weighted_pick(p) pred = chars[sample] ret += pred char = pred return ret
关于char-RNN with TensorFlow的个人解读到这里就结束了,我会陆续更新对于CNN和RNN的论文以及工程代码分析,仅代表个人看法,如果有任何问题欢迎指正。
参考资料:
http://blog.csdn.net/mydear_11000/article/details/52776295
https://github.com/sherjilozair/char-rnn-tensorflow
http://www.tensorfly.cn/tfdoc/api_docs/python/constant_op.html#truncated_normal
相关文章推荐
- RNN代码解读之char-RNN with TensorFlow(model.py)
- tensorflow rnn 最简单实现代码
- seq2seq_model.py ValueError: Attempt to reuse RNNCell <tensorflow.contrib.rnn.python.ops.core_rnn_ce
- 生成对抗网络DCGAN+Tensorflow代码学习笔记(一)----main.py
- 生成对抗网络DCGAN+Tensorflow代码学习笔记(三)----ops.py
- Tensorflow RNN Regression代码示例
- TensorFlow RNN 教程和代码
- 解读tensorflow之rnn 的示例 ptb_word_lm.py 这两天想搞清楚用tensorflow来实现rnn/lstm如何做,但是google了半天,发现tf在rnn方面的实现代码或者教程
- 生成对抗网络DCGAN+Tensorflow代码学习笔记(二)----utils.py
- tensorflow rnn 最简单实现代码
- Tensorflow RNN 关于mnist 的代码示例
- ValueError: Attempt to reuse RNNCell <tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl.BasicLSTMCell object at 0x7f1a3c448390> with a different variable scope than its first use.解决方法
- 机器学习:DeepDreaming with TensorFlow (三)
- 机器学习: Tensor Flow with CNN 做表情识别
- google DQN tensorFlow框架实现 源码解读《一》tensorFlow基础学习
- 机器学习:DeepDreaming with TensorFlow (二)
- 机器学习: DeepDreaming with TensorFlow (一)
- debug tensorflow / 使用gdb调试tensorflow底层C++代码
- Google DQN tensorflow框架实现 源码解读《二》
- 资源推荐 | TensorFlow电子书《FIRST CONTACT WITH TENSORFLOW》