diff --git a/inputs.py b/inputs.py
index 6191eec3..41056c76 100644
--- a/inputs.py
+++ b/inputs.py
@@ -210,8 +210,7 @@ def mlm_sample_text(params, x, random_documents = False):
 
     return masked_features, labels
 
 
-def pred_input(params, logger, enc=None,
-               path_to_prompt=""):
+def pred_input(params, logger, enc=None, path_to_prompt=""):
     unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
                "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
@@ -294,12 +293,14 @@ def _get_skip_index(all_files, n_batches):
             break
     return skip_idx, remainder
 
-def _parse_function(example_proto):
-    features = {
-        "text": tf.VarLenFeature(tf.int64)
-    }
-    parsed_features = tf.parse_single_example(example_proto, features)
-    return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0])
+
+def _parse_function(proto):
+    x = tf.parse_single_example(proto, {'text': tf.VarLenFeature(tf.int64)})
+    x = x['text']
+    x = tf.sparse.to_dense(x)
+    x = tf.data.Dataset.from_tensor_slices(x)
+    return x
+
 
 def sequential_input(params, global_step=None, eval=False):
     """
@@ -321,45 +322,72 @@ def sequential_input(params, global_step=None, eval=False):
     If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model performance, as it results in less repeated data.
     """
+
     if not eval:
         assert global_step is not None
-    logging.warning("Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.")
+
+    logging.warning("Changing batch size with sequential_input() will result in some data being skipped or repeated. "
+                    "Please ensure your batch size stays constant throughout training.")
+
     batch_size = params['eval_batch_size' if eval else 'train_batch_size']
 
     filenames = []
-    for dataset_config in params['dataset_configs'].values():  # iterate through each dataset and read params
+
+    # iterate through each dataset and read params
+    for dataset_config in params['dataset_configs'].values():
         path_key = 'path' if not eval else 'eval_path'
         path = dataset_config[path_key]
-        filenames.extend(tf.io.gfile.glob(path))  # then glob all files that fit the pattern specified in dataset_configs
+
+        # then glob all files that fit the pattern specified in dataset_configs
+        filenames.extend(tf.io.gfile.glob(path))
 
     filenames = natural_sort(filenames)
     shuffle_filenames = params.get("shuffle_input_filenames", True)
     if shuffle_filenames:
-        seed = params.get('seed', 1)  # shuffle deterministically
+
+        # shuffle deterministically
+        seed = params.get('seed', 1)
         random.seed(seed)
         random.shuffle(filenames)
-
-    dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()  # repeat filenames to infinity
+
+    # repeat filenames to infinity
+    dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat()
 
     if not eval:
+        # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files
         skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params["train_batch_size"])  # TODO: fix for > 1 epoch
         dataset = dataset.skip(skip_idx)  # skip to skip idx
 
         # read tfrecord examples and skip remainder
-        dataset = dataset.apply(tf.data.TFRecordDataset)
+        #dataset = dataset.apply(tf.data.TFRecordDataset)
         dataset = dataset.skip(remainder)
     else:
+        # shuffle filenames if in eval mode
         dataset = dataset.shuffle(len(filenames))
-        dataset = dataset.apply(tf.data.TFRecordDataset)
+        #dataset = dataset.apply(tf.data.TFRecordDataset)
 
-
-    # parse the tokenized data from the tfrecord files and shuffle
-    dataset = dataset.map(_parse_function, num_parallel_calls=1)
-    dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1)
-
-    # batch data and repeat to infinity
-    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2)
-    return dataset.repeat()
-
+    def memory_func(x):
+        x = tf.reshape(x, (batch_size, params["n_ctx"] + 1))
+        x = tf.cast(x, tf.int32)
+
+        vals1 = x[:, :params["n_ctx"]]
+        vals2 = x[:, 1:params["n_ctx"] + 1]
+
+        return vals1, vals2
+
+    dataset = dataset.map(tf.data.TFRecordDataset)
+
+    dataset = dataset.flat_map(lambda x: x.flat_map(_parse_function))
+    dataset = dataset.window(size=params["n_ctx"] + 1, shift=params["n_ctx"], stride=1, drop_remainder=True)
+    dataset = dataset.flat_map(lambda x: x.batch(params["n_ctx"] + 1))
+
+    dataset = dataset.shuffle(512, seed=params.get('seed', 1))  # seed is defined even when filename shuffling is off
+
+    dataset = dataset.batch(batch_size, drop_remainder=True)
+    dataset = dataset.map(memory_func)
+    dataset = dataset.prefetch(params["iterations"] * 2)
+
+
+    return dataset
\ No newline at end of file
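
Note on the reworked _parse_function: instead of returning one dense tensor per document, it now wraps the parsed tokens in tf.data.Dataset.from_tensor_slices, which is what lets the pipeline flat_map every tfrecord file into one continuous token stream. Below is a minimal round-trip sketch of that behavior, not part of the commit: it uses the equivalent tf.io parsing API and a hypothetical write_example helper to fabricate a one-document tfrecord.

import tensorflow as tf

def write_example(path, token_ids):
    # hypothetical helper: serialize one tokenized document the way the
    # pipeline expects it, as a variable-length int64 "text" feature
    feature = {"text": tf.train.Feature(int64_list=tf.train.Int64List(value=token_ids))}
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    with tf.io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())

def parse_function(proto):
    # same shape as the diff's _parse_function, written against tf.io
    x = tf.io.parse_single_example(proto, {"text": tf.io.VarLenFeature(tf.int64)})
    x = tf.sparse.to_dense(x["text"])
    return tf.data.Dataset.from_tensor_slices(x)  # one element per token

write_example("/tmp/demo.tfrecords", list(range(10)))

# flat_map flattens the per-document token datasets into one token stream
stream = tf.data.TFRecordDataset("/tmp/demo.tfrecords").flat_map(parse_function)
print(list(stream.as_numpy_iterator()))  # [0, 1, ..., 9]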
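
Note on the new chunking logic: window(size=n_ctx + 1, shift=n_ctx) cuts the token stream into examples that overlap by exactly one token, and memory_func then splits each batch into inputs and targets shifted one position to the left. A self-contained sketch of just that transformation follows; n_ctx, batch_size, and the toy token values are illustrative stand-ins for the params used in the commit.

import tensorflow as tf

n_ctx = 8       # context length; the real code reads params["n_ctx"]
batch_size = 2  # the real code reads params["train_batch_size"]

# stand-in for the token stream produced by flat_map(_parse_function)
tokens = tf.data.Dataset.from_tensor_slices(tf.range(64, dtype=tf.int64))

# shift=n_ctx makes consecutive examples overlap by one token, so every
# token (except the very last) can serve as a prediction target
dataset = tokens.window(size=n_ctx + 1, shift=n_ctx, drop_remainder=True)
dataset = dataset.flat_map(lambda w: w.batch(n_ctx + 1))
dataset = dataset.batch(batch_size, drop_remainder=True)

def split_inputs_and_targets(x):
    # mirrors memory_func(): targets are the inputs shifted left by one
    x = tf.cast(x, tf.int32)
    return x[:, :n_ctx], x[:, 1:n_ctx + 1]

dataset = dataset.map(split_inputs_and_targets)

for inputs, targets in dataset.take(1):
    print(inputs.numpy())   # [[0 1 2 3 4 5 6 7], [ 8  9 10 11 12 13 14 15]]
    print(targets.numpy())  # [[1 2 3 4 5 6 7 8], [ 9 10 11 12 13 14 15 16]]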