From d58d6437bd5b043f12916ba9bed8f90286a4ea82 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 16 Jul 2022 16:06:41 +0200 Subject: [PATCH 1/3] Add PrefixLM eval --- .../run_bsevalharness_prefix.slurm | 122 ++++++++++++++++++ .../evalharness/run_evalharness_prefix.slurm | 121 +++++++++++++++++ tasks/eval_harness/evaluate.py | 39 +++++- tasks/eval_harness/evaluate_bsevalharness.py | 36 +++++- 4 files changed, 310 insertions(+), 8 deletions(-) create mode 100644 examples/evalharness/run_bsevalharness_prefix.slurm create mode 100644 examples/evalharness/run_evalharness_prefix.slurm diff --git a/examples/evalharness/run_bsevalharness_prefix.slurm b/examples/evalharness/run_bsevalharness_prefix.slurm new file mode 100644 index 000000000..6f20dbaf8 --- /dev/null +++ b/examples/evalharness/run_bsevalharness_prefix.slurm @@ -0,0 +1,122 @@ +#!/bin/bash +#SBATCH --job-name=run_evalharness-tr13f-6b3 +#SBATCH --partition=gpu_p5 +#SBATCH --constraint=a100 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! 
+#SBATCH --cpus-per-task=8 # number of cores per tasks +#SBATCH --hint=nomultithread # we get physical cores not logical +#SBATCH --gres=gpu:1 # number of gpus +#SBATCH --time 20:00:00 # maximum execution time (HH:MM:SS) +#SBATCH --output=%x-%j.out # output file name +#SBATCH --account=six@a100 +#SBATCH --reservation=hug + +set -x -e + +source $six_ALL_CCFRWORK/start-muennighofflmeval + +echo "START TIME: $(date)" + +# a unique identifier for the current eval ideally corresponding to the modelname +VARIANT="tr13f-prefix" + + +CHECKPOINT_PATH=$six_ALL_CCFRSCRATCH/checkpoints/tr11c-2B5-ml/checkpoints/main/global_step337250 +MEGATRON_DEEPSPEED_REPO=$six_ALL_CCFRSCRATCH/commun/experiments/muennighoff/megdsbslmeval/Megatron-DeepSpeed +export HF_DATASETS_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 + +export TRANSFORMERS_CACHE=$six_ALL_CCFRWORK/models +export HF_DATASETS_CACHE=$six_ALL_CCFRWORK/datasetseval +export HF_MODULES_CACHE=$six_ALL_CCFRWORK/modules +export HF_METRICS_CACHE=$six_ALL_CCFRWORK/metrics +export TOKENIZERS_PARALLELISM=false + +cd $MEGATRON_DEEPSPEED_REPO + +TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles + +PP_SIZE=1 +TP_SIZE=1 +SEQ_LEN=2048 + +# different from the training MICRO_BATCH_SIZE - no optim memory, so can do bigger BS +# make as big as it can fit into gpu w/o OOM, but not too close to 100% +EVAL_MICRO_BATCH_SIZE=1 + +#dummy arguments to make megatron happy. 
+MEGATRON_REQUIRED_ARGS=" \ + --num-layers -1 \ + --hidden-size -1 \ + --num-attention-heads -1 \ + --seq-length -1 \ + --max-position-embeddings -1 \ +" + + +ZERO_STAGE=0 + +config_json="./ds_config.json" + +# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size() +cat <<EOT > $config_json +{ + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": 1, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": $ZERO_STAGE + }, + "bf16": { + "enabled": false + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} +EOT + + +CMD="./tasks/eval_harness/evaluate_bsevalharness_prefix.py \ + --load $CHECKPOINT_PATH \ + --results_path $VARIANT-results.json \ + --tensor-model-parallel-size $TP_SIZE \ + --pipeline-model-parallel-size $PP_SIZE \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \ + --micro-batch-size $EVAL_MICRO_BATCH_SIZE \ + --no-load-optim \ + --no-load-rng \ + --eval_fp32 \ + --inference \ + --seq-length $SEQ_LEN \ + --task_list copa \ + --prefix \ + --deepspeed \ + --deepspeed_config ds_config.json \ + --intermed_results \ + --adaptive_seq_len \ + --micro_bs_multiplier 8 \ + $MEGATRON_REQUIRED_ARGS \ + " + +GPUS_PER_NODE=1 +NNODES=$SLURM_NNODES +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +export CUDA_LAUNCH_BLOCKING=1 + +echo $LAUNCHER $CMD + +export PYTHONPATH=$MEGATRON_DEEPSPEED_REPO + +$LAUNCHER $CMD 2>&1 | tee $VARIANT-eval-harness.log diff --git a/examples/evalharness/run_evalharness_prefix.slurm b/examples/evalharness/run_evalharness_prefix.slurm new file mode 100644 index 000000000..aaebfc7c6 --- /dev/null +++ b/examples/evalharness/run_evalharness_prefix.slurm @@ -0,0 +1,121 @@ +#!/bin/bash +#SBATCH 
--job-name=run_evalharness-tr13f-6B3-prefix +#SBATCH --partition=gpu_p5 +#SBATCH --constraint=a100 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node! +#SBATCH --cpus-per-task=8 # number of cores per tasks +#SBATCH --hint=nomultithread # we get physical cores not logical +#SBATCH --gres=gpu:1 # number of gpus +#SBATCH --time 20:00:00 # maximum execution time (HH:MM:SS) +#SBATCH --output=%x-%j.out # output file name +#SBATCH --account=six@a100 +#SBATCH --reservation=hug + +set -x -e + +source $six_ALL_CCFRWORK/start-py38-pt111 + +echo "START TIME: $(date)" + +# a unique identifier for the current eval ideally corresponding to the modelname +VARIANT="tr13f-6B3-prefix" + + +CHECKPOINT_PATH=/gpfsscratch/rech/six/commun/checkpoints/tr13f-6B3-ml-t0/checkpoints/prefix/global_step3100 +MEGATRON_DEEPSPEED_REPO=$six_ALL_CCFRSCRATCH/commun/experiments/muennighoff/megdsbslmeval/Megatron-DeepSpeed +export HF_DATASETS_OFFLINE=1 +export TRANSFORMERS_OFFLINE=1 + +export TRANSFORMERS_CACHE=$six_ALL_CCFRWORK/models +export HF_DATASETS_CACHE=$six_ALL_CCFRWORK/datasets +export HF_MODULES_CACHE=$six_ALL_CCFRWORK/modules +export HF_METRICS_CACHE=$six_ALL_CCFRWORK/metrics + +cd $MEGATRON_DEEPSPEED_REPO + +TOKENIZER_NAME_OR_PATH=bigscience-catalogue-data-dev/byte-level-bpe-tokenizer-no-norm-250k-whitespace-and-eos-regex-alpha-v3-dedup-lines-articles + +PP_SIZE=1 +TP_SIZE=1 +SEQ_LEN=2048 + +# different from the training MICRO_BATCH_SIZE - no optim memory, so can do bigger BS +# make as big as it can fit into gpu w/o OOM, but not too close to 100% +EVAL_MICRO_BATCH_SIZE=1 + +#dummy arguments to make megatron happy. 
+MEGATRON_REQUIRED_ARGS=" \ + --num-layers -1 \ + --hidden-size -1 \ + --num-attention-heads -1 \ + --seq-length -1 \ + --max-position-embeddings -1 \ +" + + +ZERO_STAGE=0 + +config_json="./ds_config.json" + +# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size() +cat <<EOT > $config_json +{ + "train_micro_batch_size_per_gpu": 1, + "train_batch_size": 1, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": $ZERO_STAGE + }, + "bf16": { + "enabled": false + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} +EOT + + +CMD="./tasks/eval_harness/evaluate_evalharness_prefix.py \ + --load $CHECKPOINT_PATH \ + --results_path $VARIANT-results.json \ + --tensor-model-parallel-size $TP_SIZE \ + --pipeline-model-parallel-size $PP_SIZE \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path $TOKENIZER_NAME_OR_PATH \ + --micro-batch-size $EVAL_MICRO_BATCH_SIZE \ + --no-load-optim \ + --no-load-rng \ + --eval_fp32 \ + --inference \ + --seq-length $SEQ_LEN \ + --task_list copa \ + --prefix \ + --deepspeed \ + --deepspeed_config ds_config.json \ + --intermed_results \ + --adaptive_seq_len \ + --micro_bs_multiplier 8 \ + $MEGATRON_REQUIRED_ARGS \ + " + +GPUS_PER_NODE=1 +NNODES=$SLURM_NNODES +MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) +MASTER_PORT=6000 +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $NNODES \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +export CUDA_LAUNCH_BLOCKING=1 + +echo $LAUNCHER $CMD + +export PYTHONPATH=$MEGATRON_DEEPSPEED_REPO + +$LAUNCHER $CMD 2>&1 | tee $VARIANT-eval-harness.log diff --git a/tasks/eval_harness/evaluate.py b/tasks/eval_harness/evaluate.py index 8c362f1a7..a36cd185a 100644 --- a/tasks/eval_harness/evaluate.py +++ b/tasks/eval_harness/evaluate.py @@ -13,7 +13,6 @@ import torch.nn.functional as F from lm_eval.tasks import ALL_TASKS -from 
pretrain_gpt import model_provider import numpy as np import torch @@ -154,7 +153,9 @@ def _collate(x): contlens.append(cont) inplens.append(inplen) - logits = self._model_call(torch.cat(inps, dim=0)) + # contlens stores contencs not contlens, but not changing the variable names for consistency + prefix_lens = torch.tensor([ilen - len(ctoks) for ilen, ctoks in zip(inplens, contlens)])[:, None] + logits = self._model_call(torch.cat(inps, dim=0), prefix_lens=prefix_lens) res_len += len(chunk) if logits is not None: if self.args.offloadearly: @@ -186,8 +187,14 @@ def _collate(x): return reord.get_original(res) def create_model_inputs(self, tokens): + args = get_args() + prefix_lens = None + if args.prefix: + assert len(tokens) == 2 + tokens, prefix_lens = tokens + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, self.EOT_TOKEN_ID, @@ -196,10 +203,19 @@ def create_model_inputs(self, tokens): args.eod_mask_loss, prefix_indices=None, loss_on_targets_only=False) + + if prefix_lens is not None: + assert len(prefix_lens) == attention_mask.shape[0] == tokens.shape[0] + for i, prefix_len in enumerate(prefix_lens): + assert prefix_len <= attention_mask.shape[-1] + # Attention is paid to False (True ones are masked out) + attention_mask[i, :, :prefix_len, :prefix_len] = False + + return (tokens, position_ids, attention_mask), (tokens, loss_mask) - def _model_call(self, inps): + def _model_call(self, inps, prefix_lens=None): args = get_args() if args.deepspeed: @@ -208,7 +224,15 @@ def _model_call(self, inps): new_size = ((len(inps) + args.micro_batch_size-1) // args.micro_batch_size) * args.micro_batch_size padded = F.pad(inps, (0, 0, 0, new_size-len(inps)), value = 0) # dummy data iterator for pipelining. 
- data_iterator = list((torch.stack(inp) for inp in utils.chunks(padded, args.micro_batch_size))) + if args.prefix: + assert prefix_lens.shape == (padded.shape[0], 1) + data_iterator = [(torch.stack(inp), torch.stack(pfx)) for inp, pfx in zip( + utils.chunks(padded, args.micro_batch_size), + utils.chunks(prefix_lens, args.micro_batch_size), + ) + ] + else: + data_iterator = list((torch.stack(inp) for inp in utils.chunks(padded, args.micro_batch_size))) self.model.micro_batches = len(data_iterator) if self.adaptive_seq_len: @@ -348,6 +372,12 @@ def load_ds_checkpoint_and_setup_megatron(args): # print final arguments. _print_args(args) + + if args.prefix: + from finetune_t0_non_causal_decoder import model_provider + else: + from pretrain_gpt import model_provider + if args.deepspeed: # Hack #3: @@ -393,6 +423,7 @@ def tasks_args(parser): group.add_argument('--adaptive_seq_len', default = False, action='store_true', help='Should the sequence length be adapted to the batch during evaluation, if in fp16 the results will be slightly different due to numerical errors but greatly speed up evaluation.') group.add_argument('--eval_fp32', default = False, action='store_true', help='Should the evaluation run in fp32') + group.add_argument('--prefix', default=False, action='store_true', help='Prefix LM - Bidirectional att over input') group.add_argument('--intermed_results', default = False, action='store_true', help='Whether to print & write intermediate results for each task') group.add_argument('--bootstrap_iters', type=int, default=100000, help='How many iterations to use for stderr estimation') group.add_argument('--micro_bs_multiplier', type=int, default=1, help='Increase the global batch size to remove bubble when pipeline parallel') diff --git a/tasks/eval_harness/evaluate_bsevalharness.py b/tasks/eval_harness/evaluate_bsevalharness.py index bf429e7d5..199f4ab55 100644 --- a/tasks/eval_harness/evaluate_bsevalharness.py +++ b/tasks/eval_harness/evaluate_bsevalharness.py 
@@ -25,7 +25,6 @@ import torch.nn.functional as F from lm_eval.tasks import ALL_TASKS -from pretrain_gpt import model_provider import numpy as np import torch @@ -183,7 +182,9 @@ def _collate(x): contlens.append(cont) inplens.append(inplen) - logits = self._model_call(torch.cat(inps, dim=0)) + # contlens stores contencs not contlens, but not changing the variable names for consistency + prefix_lens = torch.tensor([ilen - len(ctoks) for ilen, ctoks in zip(inplens, contlens)])[:, None] + logits = self._model_call(torch.cat(inps, dim=0), prefix_lens=prefix_lens) torch.distributed.barrier() res_len += len(chunk) if logits is not None: @@ -216,8 +217,14 @@ def _collate(x): return reord.get_original(res) def create_model_inputs(self, tokens): + args = get_args() + prefix_lens = None + if args.prefix: + assert len(tokens) == 2 + tokens, prefix_lens = tokens + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( tokens, self.EOT_TOKEN_ID, @@ -226,10 +233,16 @@ def create_model_inputs(self, tokens): args.eod_mask_loss, prefix_indices=None, loss_on_targets_only=False) + + if prefix_lens is not None: + assert len(prefix_lens) == attention_mask.shape[0] + for i, prefix_len in enumerate(prefix_lens): + attention_mask[i, :, :prefix_len, :prefix_len] = 1 + return (tokens, position_ids, attention_mask), (tokens, loss_mask) - def _model_call(self, inps): + def _model_call(self, inps, prefix_lens=None): args = get_args() if args.deepspeed: @@ -238,7 +251,15 @@ def _model_call(self, inps): new_size = ((len(inps) + args.micro_batch_size-1) // args.micro_batch_size) * args.micro_batch_size padded = F.pad(inps, (0, 0, 0, new_size-len(inps)), value = 0) # dummy data iterator for pipelining. 
- data_iterator = list((torch.stack(inp) for inp in utils.chunks(padded, args.micro_batch_size))) + if args.prefix: + assert prefix_lens.shape == (padded.shape[0], 1) + data_iterator = [(torch.stack(inp), torch.stack(pfx)) for inp, pfx in zip( + utils.chunks(padded, args.micro_batch_size), + utils.chunks(prefix_lens, args.micro_batch_size), + ) + ] + else: + data_iterator = list((torch.stack(inp) for inp in utils.chunks(padded, args.micro_batch_size))) self.model.micro_batches = len(data_iterator) if self.adaptive_seq_len: @@ -375,6 +396,12 @@ def load_ds_checkpoint_and_setup_megatron(args): # print final arguments. _print_args(args) + + if args.prefix: + from finetune_t0_non_causal_decoder import model_provider + else: + from pretrain_gpt import model_provider + if args.deepspeed: # Hack #3: @@ -420,6 +447,7 @@ def tasks_args(parser): group.add_argument('--adaptive_seq_len', default = False, action='store_true', help='Should the sequence length be adapted to the batch during evaluation, if in fp16 the results will be slightly different due to numerical errors but greatly speed up evaluation.') group.add_argument('--eval_fp32', default = False, action='store_true', help='Should the evaluation run in fp32') + group.add_argument('--prefix', default=False, action='store_true', help='Prefix LM - Bidirectional att over input') group.add_argument('--intermed_results', default = False, action='store_true', help='Whether to print & write intermediate results for each task') group.add_argument('--bootstrap_iters', type=int, default=100000, help='How many iterations to use for stderr estimation') group.add_argument('--micro_bs_multiplier', type=int, default=1, help='Increase the global batch size to remove bubble when pipeline parallel') From 1fcd41336596c7543840838a71cfe36e9fe5958f Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Sat, 16 Jul 2022 16:12:25 +0200 Subject: [PATCH 2/3] Use prefix arg --- tasks/eval_harness/evaluate.py | 3 +-- 
tasks/eval_harness/evaluate_bsevalharness.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tasks/eval_harness/evaluate.py b/tasks/eval_harness/evaluate.py index a36cd185a..1885beac3 100644 --- a/tasks/eval_harness/evaluate.py +++ b/tasks/eval_harness/evaluate.py @@ -190,7 +190,6 @@ def create_model_inputs(self, tokens): args = get_args() - prefix_lens = None if args.prefix: assert len(tokens) == 2 tokens, prefix_lens = tokens @@ -204,7 +203,7 @@ def create_model_inputs(self, tokens): prefix_indices=None, loss_on_targets_only=False) - if prefix_lens is not None: + if args.prefix: assert len(prefix_lens) == attention_mask.shape[0] == tokens.shape[0] for i, prefix_len in enumerate(prefix_lens): assert prefix_len <= attention_mask.shape[-1] diff --git a/tasks/eval_harness/evaluate_bsevalharness.py b/tasks/eval_harness/evaluate_bsevalharness.py index 199f4ab55..864d10294 100644 --- a/tasks/eval_harness/evaluate_bsevalharness.py +++ b/tasks/eval_harness/evaluate_bsevalharness.py @@ -220,7 +220,6 @@ def create_model_inputs(self, tokens): args = get_args() - prefix_lens = None if args.prefix: assert len(tokens) == 2 tokens, prefix_lens = tokens @@ -234,7 +233,7 @@ def create_model_inputs(self, tokens): prefix_indices=None, loss_on_targets_only=False) - if prefix_lens is not None: + if args.prefix: assert len(prefix_lens) == attention_mask.shape[0] for i, prefix_len in enumerate(prefix_lens): attention_mask[i, :, :prefix_len, :prefix_len] = 1 From 03b497bcd74cae7afc8b522fdf50b2decec1bae9 Mon Sep 17 00:00:00 2001 From: Muennighoff Date: Tue, 19 Jul 2022 16:23:32 +0200 Subject: [PATCH 3/3] Updates --- examples/evalharness/run_bsevalharness_prefix.slurm | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/evalharness/run_bsevalharness_prefix.slurm b/examples/evalharness/run_bsevalharness_prefix.slurm index 6f20dbaf8..8b7cdf001 100644 --- a/examples/evalharness/run_bsevalharness_prefix.slurm +++ 
b/examples/evalharness/run_bsevalharness_prefix.slurm @@ -22,7 +22,7 @@ echo "START TIME: $(date)" VARIANT="tr13f-prefix" -CHECKPOINT_PATH=$six_ALL_CCFRSCRATCH/checkpoints/tr11c-2B5-ml/checkpoints/main/global_step337250 +CHECKPOINT_PATH=/gpfsscratch/rech/six/commun/checkpoints/tr13f-6B3-ml-t0/checkpoints/prefix/global_step3100 MEGATRON_DEEPSPEED_REPO=$six_ALL_CCFRSCRATCH/commun/experiments/muennighoff/megdsbslmeval/Megatron-DeepSpeed export HF_DATASETS_OFFLINE=1 export TRANSFORMERS_OFFLINE=1