Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ Code and model for the paper "Improving Language Understanding by Generative Pre
Currently this code implements the ROCStories Cloze Test result reported in the paper by running:
`python train.py --dataset rocstories --desc rocstories --submit --analysis --data_dir [path to data here]`

For the RACE dataset results, run:
`python train.py --dataset [race/racem] --desc [race/racem] --submit --analysis --data_dir [path to data here]`

Use `racem` for the middle school set and `race` for the high school set.


Note: The code is currently non-deterministic due to various GPU ops. The median accuracy of 10 runs with this codebase (using default hyperparameters) is 85.8% - slightly lower than the reported single run of 86.5% from the paper.

The ROCStories dataset can be downloaded from the associated [website](http://cs.rochester.edu/nlp/rocstories/).
The RACE dataset can be downloaded from the associated [website](http://www.cs.cmu.edu/~glai1/data/race/).
14 changes: 13 additions & 1 deletion analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sklearn.metrics import accuracy_score

from datasets import _rocstories
from datasets import _rocstories, _race, NUM2CHOI

def rocstories(data_dir, pred_path, log_path):
preds = pd.read_csv(pred_path, delimiter='\t')['prediction'].values.tolist()
Expand All @@ -16,3 +16,15 @@ def rocstories(data_dir, pred_path, log_path):
valid_accuracy = logs[best_validation_index]['va_acc']
print('ROCStories Valid Accuracy: %.2f'%(valid_accuracy))
print('ROCStories Test Accuracy: %.2f'%(test_accuracy))

def race(data_dir, pred_path, log_path, subset="high"):
    """Report RACE validation and test accuracy for a finished run.

    Args:
        data_dir: root of the RACE dataset (contains test/<subset>/ JSON files).
        pred_path: TSV submission file with a 'prediction' column of letter choices.
        log_path: JSONL training log; the first line is skipped (assumed header/config).
        subset: "high" (RACE-H) or "middle" (RACE-M). Defaults to "high" for
            backward compatibility; pass "middle" when analyzing a `racem` run,
            otherwise predictions are scored against the wrong test labels.
    """
    preds = pd.read_csv(pred_path, delimiter='\t')['prediction'].values.tolist()
    # _race returns (..., labels) with labels as choice indices; map back to letters
    # so they are comparable with the letter predictions in the TSV.
    *_, labels = _race(os.path.join(data_dir, "test", subset))
    labels = [NUM2CHOI[l] for l in labels]
    test_accuracy = accuracy_score(labels, preds)*100.
    logs = [json.loads(line) for line in open(log_path)][1:]
    # Best epoch is chosen by validation accuracy.
    best_validation_index = np.argmax([log['va_acc'] for log in logs])
    valid_accuracy = logs[best_validation_index]['va_acc']
    print('RACE Valid Accuracy: %.2f'%(valid_accuracy))
    print('RACE Test Accuracy: %.2f'%(test_accuracy))

41 changes: 40 additions & 1 deletion datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import csv
import json
from string import ascii_uppercase
from collections import Counter

import numpy as np

from tqdm import tqdm
Expand Down Expand Up @@ -48,4 +52,39 @@ def rocstories(data_dir, n_train=1497, n_valid=374):
vaY.append(y)
trY = np.asarray(trY, dtype=np.int32)
vaY = np.asarray(vaY, dtype=np.int32)
return (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3)
return (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3, _)



CHOICE_NUMBER = {v:i for i,v in enumerate(ascii_uppercase)}
NUM2CHOI = {i:v for i,v in enumerate(ascii_uppercase)}

def _race(path):
files = os.listdir(path)
art, ques, c1, c2, c3, c4, y = [], [], [], [], [], [], []
for fn in files:
with open(os.path.join(path, fn)) as f:
j = json.load(f)
for q, cs, ans in zip(j["questions"], j["options"], j["answers"]):
art.append(j["article"])
ques.append(q)
y.append(CHOICE_NUMBER[ans])
c1.append(cs[0])
c2.append(cs[1])
c3.append(cs[2])
c4.append(cs[3])

return art, ques, c1, c2, c3, c4, y

def racem(data_dir):
    """Load the RACE-M (middle school) train/dev/test splits from data_dir."""
    splits = ("train", "dev", "test")
    return tuple(_race(os.path.join(data_dir, split, "middle")) for split in splits)

def race(data_dir):
    """Load the RACE-H (high school) train/dev/test splits from data_dir."""
    def load(split):
        # Each split lives under <data_dir>/<split>/high.
        return _race(os.path.join(data_dir, split, "high"))
    return load("train"), load("dev"), load("test")

117 changes: 102 additions & 15 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import joblib
import random
import argparse
from itertools import islice

import numpy as np
import tensorflow as tf

Expand All @@ -14,8 +16,9 @@
from sklearn.metrics import accuracy_score

from opt import adam, warmup_cosine, warmup_linear, warmup_constant
from datasets import rocstories
from datasets import rocstories, racem, race, NUM2CHOI
from analysis import rocstories as rocstories_analysis
from analysis import race as race_analysis
from text_utils import TextEncoder
from utils import encode_dataset, flatten, iter_data, find_trainable_variables, convert_gradient_to_tensor, shape_list, ResultLogger, assign_to_gpu, average_grads, make_path

Expand Down Expand Up @@ -181,14 +184,14 @@ def model(X, M, Y, train=False, reuse=False):
pool_idx = tf.cast(tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1), tf.int32)
clf_h = tf.gather(clf_h, tf.range(shape_list(X)[0], dtype=tf.int32)*n_ctx+pool_idx)

clf_h = tf.reshape(clf_h, [-1, 2, n_embd])
clf_h = tf.reshape(clf_h, [-1, cn, n_embd])
if train and clf_pdrop > 0:
shape = shape_list(clf_h)
shape[1] = 1
clf_h = tf.nn.dropout(clf_h, 1-clf_pdrop, shape)
clf_h = tf.reshape(clf_h, [-1, n_embd])
clf_logits = clf(clf_h, 1, train=train)
clf_logits = tf.reshape(clf_logits, [-1, 2])
clf_logits = tf.reshape(clf_logits, [-1, cn])

clf_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=clf_logits, labels=Y)
return clf_logits, clf_losses, lm_losses
Expand Down Expand Up @@ -226,6 +229,34 @@ def mgpu_predict(*xs):
ops = [tf.concat(op, 0) for op in zip(*gpu_ops)]
return ops


def transform_race(art, ques, c1, c2, c3, c4):
    """Encode RACE examples into model input and mask arrays.

    Each example yields 4 candidate sequences, one per answer option:
    [start] article question [delimiter] option [clf_token].

    Args:
        art, ques, c1..c4: parallel lists of BPE-encoded token id lists.

    Returns:
        xmb: (n_batch, 4, n_ctx, 2) int32; [..., 0] holds token ids,
             [..., 1] holds position ids offset past the vocab+specials.
        mmb: (n_batch, 4, n_ctx) float32 mask, 1 over real tokens.
    """
    # Truncation budgets: 469 + 23 + 17 + 3 special tokens == 512,
    # matching the n_ctx used for RACE — TODO confirm if n_ctx changes.
    max_art, max_q, max_opt = 469, 23, 17
    n_batch = len(art)
    xmb = np.zeros((n_batch, 4, n_ctx, 2), dtype=np.int32)
    mmb = np.zeros((n_batch, 4, n_ctx), dtype=np.float32)
    start = encoder['_start_']
    delimiter = encoder['_delimiter_']
    for i, (article, question, *options) in enumerate(zip(art, ques, c1, c2, c3, c4)):
        # Shared prefix for all four candidate sequences of this example.
        prefix = [start] + article[:max_art] + question[:max_q] + [delimiter]
        for j, option in enumerate(options):
            seq = prefix + option[:max_opt] + [clf_token]
            l = len(seq)
            xmb[i, j, :l, 0] = seq
            mmb[i, j, :l] = 1
    # Position ids live past the token vocabulary and special tokens.
    xmb[:, :, :, 1] = np.arange(n_vocab+n_special, n_vocab+n_special+n_ctx)
    return xmb, mmb


def transform_roc(X1, X2, X3):
n_batch = len(X1)
xmb = np.zeros((n_batch, 2, n_ctx, 2), dtype=np.int32)
Expand Down Expand Up @@ -293,16 +324,54 @@ def log():

# Per-dataset dispatch tables, keyed by the --dataset argument.

# How to turn classifier logits into a predicted class index.
pred_fns = {
'rocstories':argmax,
'racem':argmax,
'race':argmax,
}

# Name of the submission/prediction TSV written for each dataset.
filenames = {
'rocstories':'ROCStories.tsv',
'racem':'RACE-M.tsv',
'race':'RACE.tsv',
}

# NOTE(review): presumably maps a class index back to the answer letter for
# the submission file, with None meaning the raw index is written — confirm
# against predict().
label_decoders = {
'rocstories':None,
'racem':NUM2CHOI,
'race':NUM2CHOI,
}

# Dataset loader: data_dir -> (trainset, devset, testset).
load_dataset = {
'rocstories': rocstories,
'racem':racem,
'race': race,
}

# Number of answer choices per example (width of the classification softmax).
num_choice = {
'rocstories': 2,
'racem': 4,
'race': 4,
}

# Encodes raw example fields into (xmb, mmb) model input arrays.
transforms = {
'rocstories': transform_roc,
'racem': transform_race,
'race': transform_race,
}

# Post-run accuracy reporting function.
analyses = {
'rocstories': rocstories_analysis,
'racem': race_analysis,
'race': race_analysis,
}


def slice(a, n):
    """Take the first n items of a, or drop the last |n| items when n is negative.

    Used to strip the trailing labels element off a dataset tuple (n=-1).
    Generalized from the original n == -1 special case to any negative n.

    NOTE: shadows the builtin `slice`; the name is kept so existing
    call sites keep working.

    Args:
        a: a sized iterable (len() is required when n < 0).
        n: number of leading items to keep, or a negative count to drop
           from the end.

    Returns:
        An itertools.islice iterator over the selected prefix.
    """
    stop = len(a) + n if n < 0 else n
    return islice(a, stop)


def predict():
filename = filenames[dataset]
pred_fn = pred_fns[dataset]
Expand Down Expand Up @@ -367,29 +436,46 @@ def predict():
encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder)

(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder)
n_y = 2
raw_dataset = load_dataset[dataset](data_dir)
transform = transforms[dataset]

trainset, devset, testset = encode_dataset(raw_dataset, encoder=text_encoder)

encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
encoder['_classify_'] = len(encoder)
clf_token = encoder['_classify_']
n_special = 3
max_len = n_ctx//2-2
n_ctx = min(max([len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(trX1, trX2, trX3)]+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]+[len(x1[:max_len])+max(len(x2[:max_len]), len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)])+3, n_ctx)
trX, trM = transform_roc(trX1, trX2, trX3)
vaX, vaM = transform_roc(vaX1, vaX2, vaX3)


if dataset == "race" or dataset == "racem":
max_len = -1
n_ctx = 512
else:
max_len = n_ctx//2-2
n_ctx = min(max(min(len(x1), max_len) + max(map(lambda xi: min(len(xi), max_len), x)) for xset in (trainset, devset, testset) for x1, *x in zip(*slice(xset, -1)))+3, n_ctx)
print(n_ctx)

trX, trM = transform(*slice(trainset, -1))
vaX, vaM = transform(*slice(devset, -1))
trY = trainset[-1]
vaY = devset[-1]

if submit:
teX, teM = transform_roc(teX1, teX2, teX3)
teX, teM = transform(*slice(testset, -1))

n_train = len(trY)
n_valid = len(vaY)
n_batch_train = n_batch*n_gpu
n_updates_total = (n_train//n_batch_train)*n_iter

X_train = tf.placeholder(tf.int32, [n_batch_train, 2, n_ctx, 2])
M_train = tf.placeholder(tf.float32, [n_batch_train, 2, n_ctx])
X = tf.placeholder(tf.int32, [None, 2, n_ctx, 2])
M = tf.placeholder(tf.float32, [None, 2, n_ctx])

cn = num_choice[dataset]

X_train = tf.placeholder(tf.int32, [n_batch_train, cn, n_ctx, 2])
M_train = tf.placeholder(tf.float32, [n_batch_train, cn, n_ctx])
X = tf.placeholder(tf.int32, [None, cn, n_ctx, 2])
M = tf.placeholder(tf.float32, [None, cn, n_ctx])

Y_train = tf.placeholder(tf.int32, [n_batch_train])
Y = tf.placeholder(tf.int32, [None])
Expand Down Expand Up @@ -440,4 +526,5 @@ def predict():
sess.run([p.assign(ip) for p, ip in zip(params, joblib.load(os.path.join(save_dir, desc, 'best_params.jl')))])
predict()
if analysis:
rocstories_analysis(data_dir, os.path.join(submission_dir, 'ROCStories.tsv'), os.path.join(log_dir, 'rocstories.jsonl'))
analyzer = analyses[dataset]
analyzer(data_dir, os.path.join(submission_dir, filenames[dataset]), os.path.join(log_dir, f'{desc}.jsonl'))