diff --git a/.gitignore b/.gitignore
index 168abd66..e6a880b0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,7 +12,8 @@ data/.DS_Store
 *.gz
 .spyproject/
 .vscode/*
-model.npz
 env/
 venv/
 .idea/
+test.py
+chat.txt
diff --git a/data/reddit_data/data.py b/data/reddit_data/data.py
new file mode 100644
index 00000000..7d354ef3
--- /dev/null
+++ b/data/reddit_data/data.py
@@ -0,0 +1,288 @@
+EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist
+EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\''
+
+FILENAME = 'chat.txt'
+
+limit = {
+        'maxq' : 20,
+        'minq' : 0,
+        'maxa' : 20,
+        'mina' : 3
+        }
+
+UNK = 'unk'
+VOCAB_SIZE = 6000
+
+import random
+import sys
+
+import nltk
+import itertools
+from collections import defaultdict
+
+import numpy as np
+
+import pickle
+
+
+def ddefault():
+    return 1
+
+'''
+ read lines from file
+    return [list of lines]
+'''
+
+def read_lines(filename):
+    return open(filename).read().split('\n')[:-1]
+
+
+'''
+ split sentences in one line
+  into multiple lines
+    return [list of lines]
+
+'''
+def split_line(line):
+    return line.split('.')
+
+
+'''
+ remove anything that isn't in the vocabulary
+    return str(pure ta/en)
+
+'''
+def filter_line(line, whitelist):
+    return ''.join([ ch for ch in line if ch in whitelist ])
+
+
+'''
+ read list of words, create index to word,
+  word to index dictionaries
+    return tuple( vocab->(word, count), idx2w, w2idx )
+
+'''
+def index_(tokenized_sentences, vocab_size):
+    # get frequency distribution
+    freq_dist = nltk.FreqDist(itertools.chain(*tokenized_sentences))
+    # get vocabulary of 'vocab_size' most used words
+    vocab = freq_dist.most_common(vocab_size)
+    # index2word
+    index2word = ['_'] + [UNK] + [ x[0] for x in vocab ]
+    # word2index
+    word2index = dict([(w,i) for i,w in enumerate(index2word)] )
+    return index2word, word2index, freq_dist
+
+
+'''
+ filter too long and too short sequences
+    return tuple( filtered_ta, filtered_en )
+
+'''
+def filter_data(sequences):
+    filtered_q, filtered_a = [], []
+    raw_data_len = len(sequences)//2
+
+    for i in range(0, len(sequences), 2):
+        qlen, alen = len(sequences[i].split(' ')), len(sequences[i+1].split(' '))
+        if qlen >= limit['minq'] and qlen <= limit['maxq']:
+            if alen >= limit['mina'] and alen <= limit['maxa']:
+                filtered_q.append(sequences[i])
+                filtered_a.append(sequences[i+1])
+
+    # print the fraction of the original data, filtered
+    filt_data_len = len(filtered_q)
+    filtered = int((raw_data_len - filt_data_len)*100/raw_data_len)
+    print(str(filtered) + '% filtered from original data')
+
+    return filtered_q, filtered_a
+
+
+
+
+
+'''
+ create the final dataset :
+  - convert list of items to arrays of indices
+  - add zero padding
+    return ( [array_en([indices]), array_ta([indices]) )
+
+'''
+def zero_pad(qtokenized, atokenized, w2idx):
+    # num of rows
+    data_len = len(qtokenized)
+
+    # numpy arrays to store indices
+    idx_q = np.zeros([data_len, limit['maxq']], dtype=np.int32)
+    idx_a = np.zeros([data_len, limit['maxa']], dtype=np.int32)
+
+    for i in range(data_len):
+        q_indices = pad_seq(qtokenized[i], w2idx, limit['maxq'])
+        a_indices = pad_seq(atokenized[i], w2idx, limit['maxa'])
+
+        #print(len(idx_q[i]), len(q_indices))
+        #print(len(idx_a[i]), len(a_indices))
+        idx_q[i] = np.array(q_indices)
+        idx_a[i] = np.array(a_indices)
+
+    return idx_q, idx_a
+
+
+'''
+ replace words with indices in a sequence
+  replace with unknown if word not in lookup
+    return [list of indices]
+
+'''
+def pad_seq(seq, lookup, maxlen):
+    indices = []
+    for word in seq:
+        if word in lookup:
+            indices.append(lookup[word])
+        else:
+            indices.append(lookup[UNK])
+    return indices + [0]*(maxlen - len(seq))
+
+
+def process_data():
+
+    print('\n>> Read lines from file')
+    lines = read_lines(filename=FILENAME)
+
+    # change to lower case (just for en)
+    lines = [ line.lower() for line in lines ]
+
+    print('\n:: Sample from read(p) lines')
+    print(lines[121:125])
+
+    # filter out unnecessary characters
+    print('\n>> Filter lines')
+    lines = [ filter_line(line, EN_WHITELIST) for line in lines ]
+    print(lines[121:125])
+
+    # filter out too long or too short sequences
+    print('\n>> 2nd layer of filtering')
+    qlines, alines = filter_data(lines)
+    print('\nq : {0} ; a : {1}'.format(qlines[60], alines[60]))
+    print('\nq : {0} ; a : {1}'.format(qlines[61], alines[61]))
+
+
+    # convert list of [lines of text] into list of [list of words ]
+    print('\n>> Segment lines into words')
+    qtokenized = [ wordlist.split(' ') for wordlist in qlines ]
+    atokenized = [ wordlist.split(' ') for wordlist in alines ]
+    print('\n:: Sample from segmented list of words')
+    print('\nq : {0} ; a : {1}'.format(qtokenized[60], atokenized[60]))
+    print('\nq : {0} ; a : {1}'.format(qtokenized[61], atokenized[61]))
+
+
+    # indexing -> idx2w, w2idx : en/ta
+    print('\n >> Index words')
+    idx2w, w2idx, freq_dist = index_( qtokenized + atokenized, vocab_size=VOCAB_SIZE)
+
+    print('\n >> Zero Padding')
+    idx_q, idx_a = zero_pad(qtokenized, atokenized, w2idx)
+
+    print('\n >> Save numpy arrays to disk')
+    # save them
+    np.save('idx_q.npy', idx_q)
+    np.save('idx_a.npy', idx_a)
+
+    # let us now save the necessary dictionaries
+    metadata = {
+        'w2idx' : w2idx,
+        'idx2w' : idx2w,
+        'limit' : limit,
+        'freq_dist' : freq_dist
+    }
+
+    # write to disk : data control dictionaries
+    with open('metadata.pkl', 'wb') as f:
+        pickle.dump(metadata, f)
+
+def load_data(PATH=''):
+    # read data control dictionaries
+    try:
+        with open(PATH + 'metadata.pkl', 'rb') as f:
+            metadata = pickle.load(f)
+    except:
+        metadata = None
+    # read numpy arrays
+    idx_q = np.load(PATH + 'idx_q.npy')
+    idx_a = np.load(PATH + 'idx_a.npy')
+    return metadata, idx_q, idx_a
+
+import numpy as np
+from random import sample
+
+'''
+ split data into train (70%), test (15%) and valid(15%)
+    return tuple( (trainX, trainY), (testX,testY), (validX,validY) )
+
+'''
+def split_dataset(x, y, ratio = [0.7, 0.15, 0.15] ):
+    # number of examples
+    data_len = len(x)
+    lens = [ int(data_len*item) for item in ratio ]
+
+    trainX, trainY = x[:lens[0]], y[:lens[0]]
+    testX, testY = x[lens[0]:lens[0]+lens[1]], y[lens[0]:lens[0]+lens[1]]
+    validX, validY = x[-lens[-1]:], y[-lens[-1]:]
+
+    return (trainX,trainY), (testX,testY), (validX,validY)
+
+
+'''
+ generate batches from dataset
+    yield (x_gen, y_gen)
+
+    TODO : fix needed
+
+'''
+def batch_gen(x, y, batch_size):
+    # infinite while
+    while True:
+        for i in range(0, len(x), batch_size):
+            if (i+1)*batch_size < len(x):
+                yield x[i : (i+1)*batch_size ].T, y[i : (i+1)*batch_size ].T
+
+'''
+ generate batches, by random sampling a bunch of items
+    yield (x_gen, y_gen)
+
+'''
+def rand_batch_gen(x, y, batch_size):
+    while True:
+        sample_idx = sample(list(np.arange(len(x))), batch_size)
+        yield x[sample_idx].T, y[sample_idx].T
+
+#'''
+# convert indices of alphabets into a string (word)
+#    return str(word)
+#
+#'''
+#def decode_word(alpha_seq, idx2alpha):
+#    return ''.join([ idx2alpha[alpha] for alpha in alpha_seq if alpha ])
+#
+#
+#'''
+# convert indices of phonemes into list of phonemes (as string)
+#    return str(phoneme_list)
+#
+#'''
+#def decode_phonemes(pho_seq, idx2pho):
+#    return ' '.join( [ idx2pho[pho] for pho in pho_seq if pho ])
+
+
+'''
+ a generic decode function
+    inputs : sequence, lookup
+
+'''
+def decode(sequence, lookup, separator=''): # 0 used for padding, is ignored
+    return separator.join([ lookup[element] for element in sequence if element ])
+
+
+
+if __name__ == '__main__':
+    process_data()
diff --git a/data/reddit_data/idx_a.npy b/data/reddit_data/idx_a.npy
new file mode 100644
index 00000000..2ea727af
Binary files /dev/null and b/data/reddit_data/idx_a.npy differ
diff --git a/data/reddit_data/idx_q.npy b/data/reddit_data/idx_q.npy
new file mode 100644
index 00000000..6ff9c462
Binary files /dev/null and b/data/reddit_data/idx_q.npy differ
diff --git a/data/reddit_data/metadata.pkl b/data/reddit_data/metadata.pkl
new file mode 100644
index 00000000..f3fe63f4
Binary files /dev/null and b/data/reddit_data/metadata.pkl differ
diff --git a/data/twitter/data.py b/data/twitter/data.py
index 22ca4be9..7d354ef3 100644
--- a/data/twitter/data.py
+++ b/data/twitter/data.py
@@ -1,7 +1,7 @@
 EN_WHITELIST = '0123456789abcdefghijklmnopqrstuvwxyz ' # space is included in whitelist
 EN_BLACKLIST = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~\''
 
-FILENAME = 'data/chat.txt'
+FILENAME = 'chat.txt'
 
 limit = {
         'maxq' : 20,
@@ -31,8 +31,8 @@ def ddefault():
 '''
  read lines from file
     return [list of lines]
-
 '''
+
 def read_lines(filename):
     return open(filename).read().split('\n')[:-1]
 
diff --git a/data/twitter/idx_a.npy b/data/twitter/idx_a.npy
index 33ecfb87..56c4b438 100644
Binary files a/data/twitter/idx_a.npy and b/data/twitter/idx_a.npy differ
diff --git a/data/twitter/idx_q.npy b/data/twitter/idx_q.npy
index 4b78249d..680363aa 100644
Binary files a/data/twitter/idx_q.npy and b/data/twitter/idx_q.npy differ
diff --git a/data/twitter/metadata.pkl b/data/twitter/metadata.pkl
index b1556fad..5a320d92 100644
Binary files a/data/twitter/metadata.pkl and b/data/twitter/metadata.pkl differ
diff --git a/infer.py b/infer.py
new file mode 100644
index 00000000..1ff245e3
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,21 @@
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+from main import * # import the main python file with model from the example
+import time
+import tensorlayer as tl
+
+load_weights = tl.files.load_npz(name='saved/model.npz')
+tl.files.assign_weights(load_weights, model_)
+
+top_n = 3
+
+def respond(input):
+    sentence = inference(input, top_n)
+    response = ' '.join(sentence)
+    return response
+
+while True:
+    userInput = input("Query > ")
+    for i in range(top_n):
+        print("bot# ", respond(userInput))
+
diff --git a/main.py b/main.py
index b5eab72f..0da146c2 100644
--- a/main.py
+++ b/main.py
@@ -7,7 +7,7 @@
 from tensorlayer.cost import cross_entropy_seq, cross_entropy_seq_with_mask
 from tqdm import tqdm
 from sklearn.utils import shuffle
-from data.twitter import data
+from data.reddit_data import data
 from tensorlayer.models.seq2seq import Seq2seq
 from tensorlayer.models.seq2seq_with_attention import Seq2seqLuongAttention
 import os
@@ -24,67 +24,64 @@ def initial_setup(data_corpus):
     validY = tl.prepro.remove_pad_sequences(validY.tolist())
     return metadata, trainX, trainY, testX, testY, validX, validY
 
+data_corpus = "reddit_data"
+#data preprocessing
+metadata, trainX, trainY, testX, testY, validX, validY = initial_setup(data_corpus)
 
-if __name__ == "__main__":
-    data_corpus = "twitter"
+# Parameters
+src_len = len(trainX)
+tgt_len = len(trainY)
 
-    #data preprocessing
-    metadata, trainX, trainY, testX, testY, validX, validY = initial_setup(data_corpus)
+assert src_len == tgt_len
 
-    # Parameters
-    src_len = len(trainX)
-    tgt_len = len(trainY)
+batch_size = 32
+n_step = src_len // batch_size
+src_vocab_size = len(metadata['idx2w']) # 8002 (0~8001)
+emb_dim = 1024
 
-    assert src_len == tgt_len
+word2idx = metadata['w2idx'] # dict word 2 index
+idx2word = metadata['idx2w'] # list index 2 word
 
-    batch_size = 32
-    n_step = src_len // batch_size
-    src_vocab_size = len(metadata['idx2w']) # 8002 (0~8001)
-    emb_dim = 1024
+unk_id = word2idx['unk'] # 1
+pad_id = word2idx['_'] # 0
 
-    word2idx = metadata['w2idx'] # dict word 2 index
-    idx2word = metadata['idx2w'] # list index 2 word
+start_id = src_vocab_size # 8002
+end_id = src_vocab_size + 1 # 8003
 
-    unk_id = word2idx['unk'] # 1
-    pad_id = word2idx['_'] # 0
+word2idx.update({'start_id': start_id})
+word2idx.update({'end_id': end_id})
+idx2word = idx2word + ['start_id', 'end_id']
 
-    start_id = src_vocab_size # 8002
-    end_id = src_vocab_size + 1 # 8003
+src_vocab_size = tgt_vocab_size = src_vocab_size + 2
 
-    word2idx.update({'start_id': start_id})
-    word2idx.update({'end_id': end_id})
-    idx2word = idx2word + ['start_id', 'end_id']
+num_epochs = 50
+vocabulary_size = src_vocab_size
 
-    src_vocab_size = tgt_vocab_size = src_vocab_size + 2
+def inference(seed, top_n):
+    model_.eval()
+    seed_id = [word2idx.get(w, unk_id) for w in seed.split(" ")]
+    sentence_id = model_(inputs=[[seed_id]], seq_length=20, start_token=start_id, top_n = top_n)
+    sentence = []
+    for w_id in sentence_id[0]:
+        w = idx2word[w_id]
+        if w == 'end_id':
+            break
+        sentence = sentence + [w]
+    return sentence
 
-    num_epochs = 50
-    vocabulary_size = src_vocab_size
-
+decoder_seq_length = 20
+model_ = Seq2seq(
+    decoder_seq_length = decoder_seq_length,
+    cell_enc=tf.keras.layers.GRUCell,
+    cell_dec=tf.keras.layers.GRUCell,
+    n_layer=3,
+    n_units=256,
+    embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=emb_dim),
+    )
 
-    def inference(seed, top_n):
-        model_.eval()
-        seed_id = [word2idx.get(w, unk_id) for w in seed.split(" ")]
-        sentence_id = model_(inputs=[[seed_id]], seq_length=20, start_token=start_id, top_n = top_n)
-        sentence = []
-        for w_id in sentence_id[0]:
-            w = idx2word[w_id]
-            if w == 'end_id':
-                break
-            sentence = sentence + [w]
-        return sentence
-
-    decoder_seq_length = 20
-    model_ = Seq2seq(
-        decoder_seq_length = decoder_seq_length,
-        cell_enc=tf.keras.layers.GRUCell,
-        cell_dec=tf.keras.layers.GRUCell,
-        n_layer=3,
-        n_units=256,
-        embedding_layer=tl.layers.Embedding(vocabulary_size=vocabulary_size, embedding_size=emb_dim),
-        )
-
+if __name__ == "__main__":
 
     # Uncomment below statements if you have already saved the model
diff --git a/saved/model.npz b/saved/model.npz
new file mode 100644
index 00000000..b54e31a9
Binary files /dev/null and b/saved/model.npz differ