Reading through the preprocessing code for Facebook's bAbI dataset; it's worth writing down, since it saves a lot of work later.
2017-06-29 14:28
data_utils.py
```python
import os
import re

import numpy as np


def load_task(data_dir, task_id, only_supporting=False):
    '''Load the nth task. There are 20 tasks in total.

    Returns a tuple containing the training and testing data for the task.
    '''
    assert task_id > 0 and task_id < 21

    files = os.listdir(data_dir)
    files = [os.path.join(data_dir, f) for f in files]
    s = 'qa{}_'.format(task_id)
    train_file = [f for f in files if s in f and 'train' in f][0]
    test_file = [f for f in files if s in f and 'test' in f][0]
    train_data = get_stories(train_file, only_supporting)
    test_data = get_stories(test_file, only_supporting)
    return train_data, test_data


def tokenize(sent):
    '''Return the tokens of a sentence including punctuation.

    >>> tokenize('Bob dropped the apple. Where is the apple?')
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    '''
    return [x.strip() for x in re.split(r'(\W+)', sent) if x.strip()]


def parse_stories(lines, only_supporting=False):
    '''Parse stories provided in the bAbI tasks format.

    If only_supporting is true, only the sentences that support the answer are kept.
    '''
    data = []
    story = []
    for line in lines:
        line = str.lower(line)
        nid, line = line.split(' ', 1)
        nid = int(nid)
        if nid == 1:
            # the line id resets to 1 at the start of every new story
            story = []
        if '\t' in line:
            # question line: "<question>\t<answer>\t<supporting fact ids>"
            q, a, supporting = line.split('\t')
            q = tokenize(q)
            # a = tokenize(a)
            # answer is one vocab word even if it's actually multiple words
            a = [a]
            substory = None

            # remove question marks
            if q[-1] == "?":
                q = q[:-1]

            if only_supporting:
                # Only select the related substory
                supporting = map(int, supporting.split())
                substory = [story[i - 1] for i in supporting]
            else:
                # Provide all the substories
                substory = [x for x in story if x]

            data.append((substory, q, a))
            # placeholder keeps story indices aligned with the supporting fact ids
            story.append('')
        else:
            # regular sentence: remove periods
            sent = tokenize(line)
            if sent[-1] == ".":
                sent = sent[:-1]
            story.append(sent)
    return data


def get_stories(f, only_supporting=False):
    '''Given a file name, read the file, retrieve the stories,
    and then convert the sentences into a single story.

    If max_length is supplied, any stories longer than max_length tokens will be discarded.
    '''
    with open(f) as f:
        return parse_stories(f.readlines(), only_supporting=only_supporting)


def vectorize_data(data, word_idx, sentence_size, memory_size):
    """
    Vectorize stories and queries.

    If a sentence length < sentence_size, the sentence will be padded with 0's.

    If a story length < memory_size, the story will be padded with empty memories.
    Empty memories are 1-D arrays of length sentence_size filled with 0's.

    The answer array is returned as a one-hot encoding.
    """
    S = []
    Q = []
    A = []
    for story, query, answer in data:
        ss = []
        for i, sentence in enumerate(story, 1):
            ls = max(0, sentence_size - len(sentence))
            ss.append([word_idx[w] for w in sentence] + [0] * ls)

        # take only the most recent sentences that fit in memory
        ss = ss[::-1][:memory_size][::-1]

        # Make the last word of each sentence the time 'word' which
        # corresponds to vector of lookup table
        for i in range(len(ss)):
            ss[i][-1] = len(word_idx) - memory_size - i + len(ss)

        # pad to memory_size
        lm = max(0, memory_size - len(ss))
        for _ in range(lm):
            ss.append([0] * sentence_size)

        lq = max(0, sentence_size - len(query))
        q = [word_idx[w] for w in query] + [0] * lq

        y = np.zeros(len(word_idx) + 1)  # 0 is reserved for nil word
        for a in answer:
            y[word_idx[a]] = 1

        S.append(ss)
        Q.append(q)
        A.append(y)
    return np.array(S), np.array(Q), np.array(A)
```
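To make the `(substory, question, answer)` format returned by `parse_stories` concrete, here is a minimal sketch run on a hand-written three-line snippet in the bAbI format; the sample lines and the printed results are for illustration only and are not taken from the real task files:

```python
from data_utils import parse_stories

# Toy story in the bAbI format: "<id> <sentence>" for facts and
# "<id> <question>\t<answer>\t<supporting fact ids>" for questions.
sample = [
    "1 Mary moved to the bathroom.\n",
    "2 John went to the hallway.\n",
    "3 Where is Mary?\tbathroom\t1\n",
]

for substory, question, answer in parse_stories(sample):
    print(substory)   # [['mary', 'moved', 'to', 'the', 'bathroom'],
                      #  ['john', 'went', 'to', 'the', 'hallway']]
    print(question)   # ['where', 'is', 'mary']
    print(answer)     # ['bathroom']
```

Note that `parse_stories` lowercases each line before tokenizing and strips the trailing period and question mark, which is why the tokens above are all lowercase and punctuation-free.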
main.py
```python
from data_utils import load_task, vectorize_data
from sklearn import cross_validation, metrics  # note: sklearn >= 0.20 provides train_test_split in sklearn.model_selection instead
from itertools import chain
from six.moves import range, reduce

import tensorflow as tf
import numpy as np

tf.flags.DEFINE_float("learning_rate", 0.01, "Learning rate for SGD.")
tf.flags.DEFINE_float("anneal_rate", 25, "Number of epochs between halving the learning rate.")
tf.flags.DEFINE_float("anneal_stop_epoch", 100, "Epoch number to end annealed lr schedule.")
tf.flags.DEFINE_float("max_grad_norm", 40.0, "Clip gradients to this norm.")
tf.flags.DEFINE_integer("evaluation_interval", 10, "Evaluate and print results every x epochs")
tf.flags.DEFINE_integer("batch_size", 32, "Batch size for training.")
tf.flags.DEFINE_integer("hops", 3, "Number of hops in the Memory Network.")
tf.flags.DEFINE_integer("epochs", 100, "Number of epochs to train for.")
tf.flags.DEFINE_integer("embedding_size", 20, "Embedding size for embedding matrices.")
tf.flags.DEFINE_integer("memory_size", 50, "Maximum size of memory.")
tf.flags.DEFINE_integer("task_id", 1, "bAbI task id, 1 <= id <= 20")
tf.flags.DEFINE_integer("random_state", None, "Random state.")
tf.flags.DEFINE_string("data_dir", "/home/gt/Relation-Network-Tensorflow/tasks_1-20_v1-2/en/",
                       "Directory containing bAbI tasks")
FLAGS = tf.flags.FLAGS

print("Started Task:", FLAGS.task_id)

# task data
train, test = load_task(FLAGS.data_dir, FLAGS.task_id)
data = train + test

# vocabulary is built from every story sentence, question and answer word
vocab = sorted(reduce(lambda x, y: x | y,
                      (set(list(chain.from_iterable(s)) + q + a) for s, q, a in data)))
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))  # index 0 is reserved for the nil word

max_story_size = max(map(len, (s for s, _, _ in data)))
mean_story_size = int(np.mean([len(s) for s, _, _ in data]))
sentence_size = max(map(len, chain.from_iterable(s for s, _, _ in data)))
query_size = max(map(len, (q for _, q, _ in data)))
memory_size = min(FLAGS.memory_size, max_story_size)

# Add time words/indexes; they only enlarge len(word_idx), the raw text never contains them
for i in range(memory_size):
    word_idx['time{}'.format(i + 1)] = 'time{}'.format(i + 1)

vocab_size = len(word_idx) + 1  # +1 for nil word
sentence_size = max(query_size, sentence_size)  # for the position
sentence_size += 1  # +1 for time words

print("Longest sentence length", sentence_size)
print("Longest story length", max_story_size)
print("Average story length", mean_story_size)

# train/validation/test sets
S, Q, A = vectorize_data(train, word_idx, sentence_size, memory_size)
trainS, valS, trainQ, valQ, trainA, valA = cross_validation.train_test_split(
    S, Q, A, test_size=.1, random_state=FLAGS.random_state)
testS, testQ, testA = vectorize_data(test, word_idx, sentence_size, memory_size)
```
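For reference, a quick sketch of the array shapes that come out of `vectorize_data` and of how the splits could be sliced into mini-batches; the batching loop is a hypothetical continuation (the memory network model itself is not defined in the snippet above):

```python
# Shapes after vectorization (entries are indices into word_idx, 0 = nil word):
#   S: (num_examples, memory_size, sentence_size)
#   Q: (num_examples, sentence_size)
#   A: (num_examples, vocab_size)   -- one-hot answers
print(trainS.shape, trainQ.shape, trainA.shape)

# hypothetical mini-batch loop over the training split, using FLAGS.batch_size
n_train = trainS.shape[0]
batches = zip(range(0, n_train - FLAGS.batch_size, FLAGS.batch_size),
              range(FLAGS.batch_size, n_train, FLAGS.batch_size))
for start, end in batches:
    s, q, a = trainS[start:end], trainQ[start:end], trainA[start:end]
    # feed (s, q, a) to the memory network's training step here
```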