This repository was archived by the owner on Feb 25, 2022. It is now read-only.

improve inputs.py #117

Open
wants to merge 1 commit into master
78 changes: 53 additions & 25 deletions in inputs.py
@@ -210,8 +210,7 @@ def mlm_sample_text(params, x, random_documents = False):
return masked_features, labels


def pred_input(params, logger, enc=None,
path_to_prompt=""):
def pred_input(params, logger, enc=None, path_to_prompt=""):

unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
"previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
@@ -294,12 +293,14 @@ def _get_skip_index(all_files, n_batches):
break
return skip_idx, remainder
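
For context, a rough sketch of the contract the call site further down implies for _get_skip_index (the repo's actual implementation is collapsed above; _count_examples is a hypothetical helper, not something from this file): skip whole tfrecord files first, and report how many examples remain to be skipped inside the next file.

import tensorflow as tf

def _count_examples(tfrecord_path):
    # hypothetical helper for the sketch: count the records in one tfrecord file (slow; illustration only)
    return sum(1 for _ in tf.python_io.tf_record_iterator(tfrecord_path))

def _get_skip_index_sketch(all_files, n_examples):
    # walk the sorted file list until n_examples are accounted for;
    # skip_idx counts whole files to skip, remainder is what is left to skip inside the next file
    skip_idx, remainder = 0, n_examples
    for f in all_files:
        n = _count_examples(f)
        if remainder < n:
            break
        remainder -= n
        skip_idx += 1
    return skip_idx, remainder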

def _parse_function(example_proto):
features = {
"text": tf.VarLenFeature(tf.int64)
}
parsed_features = tf.parse_single_example(example_proto, features)
return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])

# parse one serialized tfrecord example into a dense token tensor, wrapped as a per-token dataset (consumed via flat_map below)
def _parse_function(proto):
x = tf.parse_single_example(proto, {'text': tf.VarLenFeature(tf.int64)})
x = x['text']
x = tf.sparse.to_dense(x)
x = tf.data.Dataset.from_tensor_slices(x)
return x
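
For intuition, a minimal sketch of what the new _parse_function produces (not part of the diff; the three-token example is made up, and the TF 1.x API names match the ones already used in this file). Each serialized example becomes a dense token tensor and then a per-token dataset, which is what lets flat_map splice all examples into one continuous token stream:

import tensorflow as tf

# a made-up serialized example holding the three tokens [10, 11, 12]
example = tf.train.Example(features=tf.train.Features(feature={
    "text": tf.train.Feature(int64_list=tf.train.Int64List(value=[10, 11, 12]))
})).SerializeToString()

parsed = tf.parse_single_example(example, {"text": tf.VarLenFeature(tf.int64)})
dense = tf.sparse.to_dense(parsed["text"])           # tensor of shape (3,): [10, 11, 12]
tokens = tf.data.Dataset.from_tensor_slices(dense)   # a dataset yielding 10, then 11, then 12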


def sequential_input(params, global_step=None, eval=False):
"""
@@ -321,45 +322,72 @@ def sequential_input(params, global_step=None, eval=False):
If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model
performance, as it results in less repeated data.
"""

if not eval:
assert global_step is not None
logging.warning("Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")

logging.warning("Changing batch size with sequential_input() will result in some data being skipped or repeated."
"Please ensure your batch size stays constant throughout training.")

batch_size = params['eval_batch_size' if eval else 'train_batch_size']

filenames = []
for dataset_config in params['dataset_configs'].values(): # iterate through each dataset and read params

# iterate through each dataset and read params
for dataset_config in params['dataset_configs'].values():
path_key = 'path' if not eval else 'eval_path'
path = dataset_config[path_key]
filenames.extend(tf.io.gfile.glob(path)) # then glob all files that fit the pattern specified in dataset_configs

# then glob all files that fit the pattern specified in dataset_configs
filenames.extend(tf.io.gfile.glob(path))

filenames = natural_sort(filenames)
shuffle_filenames = params.get("shuffle_input_filenames", True)
if shuffle_filenames:
seed = params.get('seed', 1) # shuffle deterministically

# shuffle deterministically
seed = params.get('seed', 1)
random.seed(seed)
random.shuffle(filenames)

dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat() # repeat filenames to infinity

# repeat filenames to infinity
dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()

if not eval:

# skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params["train_batch_size"]) # TODO: fix for > 1 epoch
dataset = dataset.skip(skip_idx) # skip to skip idx

# read tfrecord examples and skip remainder
dataset = dataset.apply(tf.data.TFRecordDataset)
#dataset = dataset.apply(tf.data.TFRecordDataset)
dataset = dataset.skip(remainder)
else:

# shuffle filenames if in eval mode
dataset = dataset.shuffle(len(filenames))
dataset = dataset.apply(tf.data.TFRecordDataset)
#dataset = dataset.apply(tf.data.TFRecordDataset)


# parse the tokenized data from the tfrecord files and shuffle
dataset = dataset.map(_parse_function, num_parallel_calls=1)
dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)

# batch data and repeat to infinity
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
return dataset.repeat()

# split each batch of (n_ctx + 1)-token windows into input tokens and next-token labels
def memory_func(x):
x = tf.reshape(x, (batch_size, params["n_ctx"] + 1))
x = tf.cast(x, tf.int32)

vals1 = x[:, :params["n_ctx"]]
vals2 = x[:, 1:params["n_ctx"] + 1]

return vals1, vals2

# turn each filename into a TFRecordDataset of serialized examples
dataset = dataset.map(tf.data.TFRecordDataset)

# parse every serialized example and flatten the result into one continuous stream of tokens
dataset = dataset.flat_map(lambda x: x.flat_map(_parse_function))

# cut the token stream into windows of n_ctx + 1 tokens; shifting by n_ctx makes consecutive windows overlap by one token
dataset = dataset.window(size=params["n_ctx"] + 1, shift=params["n_ctx"], stride=1, drop_remainder=True)
# turn each window (a nested dataset) back into a single tensor of n_ctx + 1 tokens
dataset = dataset.flat_map(lambda x: x.batch(params["n_ctx"] + 1))

# shuffle windows; read the seed from params directly so this also works when filename shuffling is disabled
dataset = dataset.shuffle(512, seed=params.get('seed', 1))

# batch, split each window into (inputs, labels), and prefetch
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.map(memory_func)
dataset = dataset.prefetch(params["iterations"] * 2)


return dataset
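
As a rough usage illustration of the revised pipeline (the config values and the GCS path below are made up, not taken from the repo's config files), sequential_input returns an endless dataset of (inputs, labels) batches in which labels is inputs shifted left by one token:

# hypothetical minimal config; real runs read these values from the repo's JSON configs
params = {
    "n_ctx": 4,
    "train_batch_size": 2,
    "iterations": 1,
    "dataset_configs": {"my_data": {"path": "gs://my-bucket/tokenized/*.tfrecords"}},  # made-up path
}

dataset = sequential_input(params, global_step=0)
inputs, labels = dataset.make_one_shot_iterator().get_next()  # TF 1.x style iteration, just for this sketch
# inputs has shape (train_batch_size, n_ctx) == (2, 4)
# labels[:, :-1] == inputs[:, 1:], i.e. each label is the token that follows its input position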