From 44a482825c0d550e13ca9f63fc16a8af2ce16387 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 2 Dec 2022 15:20:05 -0800 Subject: [PATCH 01/37] add test cases on inference layer construction --- benchmark/alpa/benchmark.py | 1 + benchmark/alpa/suite_inference_gpt.py | 28 ++++++++++++++++++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/benchmark/alpa/benchmark.py b/benchmark/alpa/benchmark.py index 66d17ccf3..b1d675e8c 100644 --- a/benchmark/alpa/benchmark.py +++ b/benchmark/alpa/benchmark.py @@ -27,6 +27,7 @@ "gpt.grid_search_auto": suite_auto_gpt.grid_search_suite, "gpt.correctness_test_auto": suite_auto_gpt.correctness_test_suite, "gpt_inference.profile": suite_inference_gpt.profile_suite, + "gpt_inference.test": suite_inference_gpt.test_suite, "gpt_no_embedding_inference.profile": suite_inference_gpt.profile_suite, "moe.tmp": suite_manual_moe.tmp_suite, "moe.tmp_auto": suite_auto_moe.tmp_suite, diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index 3d1d345a8..708b4032a 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -1,6 +1,7 @@ """Benchmark suites for gpt with auto parallelization.""" from suite_manual_gpt import gpt_specs -from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) +from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs, + LoadSolutionParallelArgs) prefer_reduce_scatter = True force_batch_dim_mapping = True @@ -60,3 +61,28 @@ def get_config(model_config, [1, 2, 4, 8, 16]) get_config(gpt_specs["15B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) + +test_suite = { + 8: [ + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "uniform", + UniformParallelArgs( + prefer_reduce_scatter, + use_remat, + dp=1, + op=1, + pp=8, + force_batch_dim_mapping=force_batch_dim_mapping)), + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "load_solution", + LoadSolutionParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=8, + forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6], + [7]], + submesh_physical_shapes=[(1, 1)] * 8, + submesh_logical_shapes=[(1, 1)] * 8, + submesh_autosharding_option_dicts=[force_dp_dict])) + ] +} From e5ddd939820946e835d687a87ab0a87bac26a313 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 2 Dec 2022 16:44:08 -0800 Subject: [PATCH 02/37] fix bugs --- .../benchmark_one_case_gpt_bert_inference.py | 41 +++++++++++++------ benchmark/alpa/suite_inference_gpt.py | 2 +- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py index 18f305787..e007f7236 100644 --- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py +++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py @@ -191,19 +191,34 @@ def benchmark_gpt_inference_internal(model_type, avg_stage_latencies = compute_avg_stage_latencies(timelines) assert len(avg_stage_latencies) == num_manual_pipeline_stages parallel_args = benchmark_case.parallel_args - dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp - heads = [ - "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", - "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", - "StagePeakMem(B)", "StageLatencies(s)" - ] - values = [ - model_name, benchmark_case.batch_size, - benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, - f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", - f"{tflops:.2f}", f"{per_stage_weight_mem}", 
f"{per_stage_peak_mem}", - avg_stage_latencies - ] + if benchmark_case.parallel_mode == "uniform": + dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp + heads = [ + "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", + "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", + "StagePeakMem(B)", "StageLatencies(s)" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}", + avg_stage_latencies + ] + else: + heads = [ + "ModelName", "BS", "#Microbatch", "ParallelArgs", + "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", + "StagePeakMem(B)", "StageLatencies(s)" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, parallel_args, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{per_stage_weight_mem}", + f"{per_stage_peak_mem}", + avg_stage_latencies + ] write_tsv(heads, values, f"benchmark_results.tsv") metadata = { diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index 708b4032a..609b65afc 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -83,6 +83,6 @@ def get_config(model_config, [7]], submesh_physical_shapes=[(1, 1)] * 8, submesh_logical_shapes=[(1, 1)] * 8, - submesh_autosharding_option_dicts=[force_dp_dict])) + submesh_autosharding_option_dicts=[force_dp_dict] * 8)) ] } From 1893ab1524663ceadb053eaaeb724d6ee7a963e4 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 2 Dec 2022 16:55:43 -0800 Subject: [PATCH 03/37] fix bugs --- .../benchmark_one_case_gpt_bert_inference.py | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py index e007f7236..d1b148b27 100644 --- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py +++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py @@ -1,5 +1,6 @@ """Benchmark one case of inter-op + intra-op parallelism.""" import os +import datetime import jax import jax.numpy as jnp @@ -179,10 +180,6 @@ def benchmark_gpt_inference_internal(model_type, # Log per-stage execution information if needed if profile_stage_execution_time: model_name = f"bert-{parameter_count/1e9:.1f}b" - # dump chrome trace - executable.dump_stage_execution_trace( - f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json" - ) # compute and log per-stage latency/memory statistics exec_info = executable.get_stage_execution_info() timelines = list(zip(*exec_info)) @@ -193,6 +190,11 @@ def benchmark_gpt_inference_internal(model_type, parallel_args = benchmark_case.parallel_args if benchmark_case.parallel_mode == "uniform": dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp + # dump chrome trace + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"op={op},pp={pp}.json") heads = [ "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", @@ -202,22 +204,26 @@ def benchmark_gpt_inference_internal(model_type, model_name, benchmark_case.batch_size, benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, f"{np.mean(latencies):.3f}", 
f"{np.std(latencies):.3f}", - f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}", - avg_stage_latencies + f"{tflops:.2f}", f"{per_stage_weight_mem}", + f"{per_stage_peak_mem}", avg_stage_latencies ] else: + timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"{timestamp}.json") heads = [ - "ModelName", "BS", "#Microbatch", "ParallelArgs", - "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", - "StagePeakMem(B)", "StageLatencies(s)" + "ModelName", "BS", "#Microbatch", "ParallelArgs", "MeanTime(s)", + "StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)", + "StageLatencies(s)", "TimeStamp" ] values = [ model_name, benchmark_case.batch_size, benchmark_case.num_micro_batches, parallel_args, f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", f"{tflops:.2f}", f"{per_stage_weight_mem}", - f"{per_stage_peak_mem}", - avg_stage_latencies + f"{per_stage_peak_mem}", avg_stage_latencies, timestamp ] write_tsv(heads, values, f"benchmark_results.tsv") From e3e90b899ffce2695626fc19ebed16aa0dbd26b1 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 2 Dec 2022 17:04:36 -0800 Subject: [PATCH 04/37] fix --- benchmark/alpa/benchmark_one_case_gpt_bert_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py index d1b148b27..535acf097 100644 --- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py +++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py @@ -186,7 +186,8 @@ def benchmark_gpt_inference_internal(model_type, # drop warmup case timelines = timelines[1:] avg_stage_latencies = compute_avg_stage_latencies(timelines) - assert len(avg_stage_latencies) == num_manual_pipeline_stages + if num_manual_pipeline_stages is not None: + assert len(avg_stage_latencies) == num_manual_pipeline_stages parallel_args = benchmark_case.parallel_args if benchmark_case.parallel_mode == "uniform": dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp From f20fd9e7df7184fb6c89ce8ca54d80ac8c586a92 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 01:20:54 +0000 Subject: [PATCH 05/37] fix --- .../benchmark_one_case_gpt_bert_inference.py | 2 +- benchmark/alpa/suite_inference_gpt.py | 18 +++++++++--------- third_party/tensorflow-alpa | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py index 535acf097..32072ed69 100644 --- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py +++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py @@ -1,6 +1,6 @@ """Benchmark one case of inter-op + intra-op parallelism.""" import os -import datetime +from datetime import datetime import jax import jax.numpy as jnp diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index 609b65afc..ff05466c9 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -64,15 +64,15 @@ def get_config(model_config, test_suite = { 8: [ - BenchmarkCase( - 1, gpt_specs["1.3B"], 1, "uniform", - UniformParallelArgs( - prefer_reduce_scatter, - use_remat, - dp=1, - op=1, - pp=8, - force_batch_dim_mapping=force_batch_dim_mapping)), + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "uniform", + # UniformParallelArgs( + # 
prefer_reduce_scatter, + # use_remat, + # dp=1, + # op=1, + # pp=8, + # force_batch_dim_mapping=force_batch_dim_mapping)), BenchmarkCase( 1, gpt_specs["1.3B"], 1, "load_solution", LoadSolutionParallelArgs( diff --git a/third_party/tensorflow-alpa b/third_party/tensorflow-alpa index 721260d12..cd865615b 160000 --- a/third_party/tensorflow-alpa +++ b/third_party/tensorflow-alpa @@ -1 +1 @@ -Subproject commit 721260d122f096040762b2d226b37e8ab23f74b8 +Subproject commit cd865615b9b518bc507fbdc71dc44c7cc76618ac From 4f6f1a47bea0ee8e394ee41bdaef52d23fba9c93 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 2 Dec 2022 17:27:08 -0800 Subject: [PATCH 06/37] test search --- benchmark/alpa/suite_inference_gpt.py | 30 +++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index ff05466c9..bcb856ac6 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -1,7 +1,9 @@ """Benchmark suites for gpt with auto parallelization.""" +from alpa import AutoStageOption from suite_manual_gpt import gpt_specs from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs, - LoadSolutionParallelArgs) + LoadSolutionParallelArgs, + SearchParallelArgs) prefer_reduce_scatter = True force_batch_dim_mapping = True @@ -73,16 +75,26 @@ def get_config(model_config, # op=1, # pp=8, # force_batch_dim_mapping=force_batch_dim_mapping)), + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "load_solution", + # LoadSolutionParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # num_auto_layers=8, + # forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6], + # [7]], + # submesh_physical_shapes=[(1, 1)] * 8, + # submesh_logical_shapes=[(1, 1)] * 8, + # submesh_autosharding_option_dicts=[force_dp_dict] * 8)), BenchmarkCase( - 1, gpt_specs["1.3B"], 1, "load_solution", - LoadSolutionParallelArgs( + 1, gpt_specs["1.3B"], 1, "search", + SearchParallelArgs( prefer_reduce_scatter, use_remat, - num_auto_layers=8, - forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6], - [7]], - submesh_physical_shapes=[(1, 1)] * 8, - submesh_logical_shapes=[(1, 1)] * 8, - submesh_autosharding_option_dicts=[force_dp_dict] * 8)) + num_auto_layers=24, + auto_stage_option=AutoStageOption( + submesh_physical_shape_space="manual", + manually_specified_submeshes=((1, 1),), + submesh_logical_shape_space="model_parallel_only"))), ] } From a7f5e2da2b95fb7a5766ac7705cb8d86cf758e37 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 2 Dec 2022 17:34:38 -0800 Subject: [PATCH 07/37] fix --- benchmark/alpa/suite_inference_gpt.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index bcb856ac6..c342c5a30 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -92,9 +92,10 @@ def get_config(model_config, prefer_reduce_scatter, use_remat, num_auto_layers=24, - auto_stage_option=AutoStageOption( - submesh_physical_shape_space="manual", - manually_specified_submeshes=((1, 1),), - submesh_logical_shape_space="model_parallel_only"))), + auto_stage_option={ + "submesh_physical_shape_space": "manual", + "manually_specified_submeshes": ((1, 1),), + "submesh_logical_shape_space": "model_parallel_only", + })), ] } From 413bb343b7543b7c802800a4f5e688f37add469b Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 00:26:49 -0800 Subject: [PATCH 08/37] add 1d 
inference stage construction --- alpa/pipeline_parallel/stage_profiling.py | 129 +++++++++++++++--- .../pipeline_parallel/test_inference_auto.py | 101 +++----------- .../pipeline_parallel/test_inference_only.py | 29 ++-- .../test_stage_construction_util.py | 24 ++-- 4 files changed, 152 insertions(+), 131 deletions(-) diff --git a/alpa/pipeline_parallel/stage_profiling.py b/alpa/pipeline_parallel/stage_profiling.py index 15e1348cc..e485c7a40 100644 --- a/alpa/pipeline_parallel/stage_profiling.py +++ b/alpa/pipeline_parallel/stage_profiling.py @@ -756,7 +756,9 @@ def generate_inference_stages_2d(layers, return stages -def get_max_n_succ_stages(profile_results: Sequence[StageProfileResult]): +def get_merged_stages_memory_stats( + profile_results: Sequence[StageProfileResult], + inference_mode: bool = False): initial_var_sizes_dict = {} for stage_result in profile_results: for name, size in zip(stage_result.initial_var_names, @@ -775,13 +777,18 @@ def get_max_n_succ_stages(profile_results: Sequence[StageProfileResult]): for result in profile_results) n_stages = len(profile_results) n_modules = profile_results[0].n_modules - assert n_modules == 2, "Only support forward and backward module" + if inference_mode: + assert n_modules == 1, "Inference mode should only have 1 module." + module_execution_orders = [list(range(n_stages))] + else: + assert n_modules == 2, ("Only support forward and backward modules in " + "training mode.") + module_execution_orders = [ + list(range(n_stages)), + list(range(n_stages - 1, -1, -1)) + ] assert all(result.n_modules == n_modules for result in profile_results) - module_execution_orders = [ - list(range(n_stages)), - list(range(n_stages - 1, -1, -1)), - ] # eliminate_time[var] = k means that the variable can be eliminated after # stage k. last_used_stage_no = {} @@ -888,7 +895,7 @@ def get_max_n_succ_stages(profile_results: Sequence[StageProfileResult]): # Record the variables that are not eliminated at the end of the # last forward module. - if module_id == 0: + if module_id == 0 and not inference_mode: intermediate_size = sum(env.values()) for var in acc_grad_invars: @@ -900,12 +907,15 @@ def get_max_n_succ_stages(profile_results: Sequence[StageProfileResult]): assert len(env) == 0, f"Variables {env.keys()} are not eliminated." 
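    # The bound computed just below answers: after this stage's parameters
    # (initial_size) and its transient peak are resident, how many extra
    # copies of the per-microbatch intermediates still fit in device memory?
    # A minimal sketch of the same arithmetic, with illustrative byte counts
    # (all four values here are assumptions, not profiled numbers):
    #
    #   available, peak, initial, intermediate = 16e9, 4e9, 2e9, 1e9
    #   max_stage = int((available - peak - initial)
    #                   // max(intermediate, 1e-8) - 1)
    #   # -> 9: nine succeeding stages' intermediates can be stashed here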
- max_stage = int((available_memory - peak_memory - initial_size) // - max(intermediate_size, 1e-8) - 1) - max_stage = min(max(-1, max_stage), INFINITY_N_STAGES) + if inference_mode: + max_stage = None + else: + max_stage = int((available_memory - peak_memory - initial_size) // + max(intermediate_size, 1e-8) - 1) + max_stage = min(max(-1, max_stage), INFINITY_N_STAGES) - return max_stage, (available_memory, peak_memory, initial_size, - intermediate_size) + return (available_memory, peak_memory, initial_size, intermediate_size, + max_stage) def interpret_profile_result_training_2d( @@ -929,8 +939,8 @@ def interpret_profile_result_training_2d( all_compute_cost[index] = sum( result.compute_cost for result in profile_result.module_profile_results) - all_max_n_succ_stages[index], _ = get_max_n_succ_stages( - [profile_result]) + _, _, _, _, all_max_n_succ_stages[index] = ( + get_merged_stages_memory_stats([profile_result])) return all_compute_cost, all_max_n_succ_stages @@ -971,14 +981,42 @@ def generate_training_stages_1d(layers, accumulator_mapping, acc_grad_invars, num_layers = len(layers) // 2 stages = [] for l in tqdm.tqdm(range(0, num_layers)): - selected_apply_grad_layers = ([] if apply_grad_layers[l] is None else - [apply_grad_layers[l]]) + selected_apply_grad_layers = [ + [] if apply_grad_layers[l] is None else [apply_grad_layers[l]] + ] stage_name = f"stage_{l}" stage_config = generate_stage_info(layers, [(l,), (2 * num_layers - l - 1,)], accumulator_mapping, acc_grad_invars, acc_grad_outvars, stage_name, - list(selected_apply_grad_layers), + selected_apply_grad_layers, + apply_grad_global_info) + for config_idx, autosharding_config in enumerate(autosharding_configs): + if autosharding_config is not None: + stage_indices = (l, mesh_id, config_idx) + stages.append( + (stage_indices, stage_config, autosharding_config)) + return stages + + +def generate_inference_stages_1d(layers, accumulator_mapping, acc_grad_invars, + acc_grad_outvars, apply_grad_layers, + apply_grad_global_info, mesh_id, + autosharding_configs): + print("- Generate all stage infos (Jaxpr -> HLO)") + num_layers = len(layers) + stages = [] + for l in tqdm.tqdm(range(0, num_layers)): + selected_apply_grad_layers = [ + [] if apply_grad_layers[l] is None else [apply_grad_layers[l]] + ] + assert len(selected_apply_grad_layers) == 0, ( + "Inference stage should not have apply_grad_layers") + stage_name = f"stage_{l}" + stage_config = generate_stage_info(layers, [(l,)], accumulator_mapping, + acc_grad_invars, acc_grad_outvars, + stage_name, + selected_apply_grad_layers, apply_grad_global_info) for config_idx, autosharding_config in enumerate(autosharding_configs): if autosharding_config is not None: @@ -1018,12 +1056,53 @@ def interpret_profile_result_training_1d( result.compute_cost for profile_result in selected_profile_results for result in profile_result.module_profile_results) - (all_max_n_succ_stages[start, end, submesh_choice, - config_idx], - _) = get_max_n_succ_stages(selected_profile_results) + (_, _, _, _, all_max_n_succ_stages[start, end, + submesh_choice, + config_idx] + ) = get_merged_stages_memory_stats(selected_profile_results) return all_compute_cost, all_max_n_succ_stages +def interpret_profile_result_inference_1d( + profile_results: Dict[Tuple[int, ...], + StageProfileResult], num_layers: int, + num_submesh_choices: int, num_autosharding_configs: int): + all_compute_cost = np.full( + (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), + np.inf, + dtype=np.float64) + all_peak_memory = 
np.full( + (num_layers, num_layers, num_submesh_choices, num_autosharding_configs), + np.inf, + dtype=np.float64) + + for start in range(num_layers): + for end in range(start, num_layers): + for submesh_choice in range(num_submesh_choices): + for config_idx in range(num_autosharding_configs): + if any( + (l, submesh_choice, config_idx) not in profile_results + for l in range(start, end + 1)): + continue + selected_profile_results = [ + profile_results[(l, submesh_choice, config_idx)] + for l in range(start, end + 1) + ] + for result in selected_profile_results: + assert len(result.module_profile_results) == 1 + all_compute_cost[ + start, end, submesh_choice, config_idx] = sum( + profile_result.module_profile_results[0]. + compute_cost + for profile_result in selected_profile_results) + (available_memory, peak_memory, _, _, _ + ) = get_merged_stages_memory_stats(selected_profile_results) + if peak_memory > available_memory: + all_compute_cost[start, end, submesh_choice, + config_idx] = np.inf + return all_compute_cost, all_peak_memory + + def distributed_profile_on_mesh(stages, meshes: Sequence[VirtualPhysicalMesh], num_micro_batches, default_as_option, auto_stage_option, profile_results): @@ -1192,7 +1271,10 @@ def get_compute_cost( auto_stage_option.stage_imbalance_tolerance) elif auto_stage_option.layer_profile_mode == "individual": if inference_mode: - raise NotImplementedError() + stages = generate_inference_stages_1d( + layers, accumulator_mapping, acc_grad_invars, + acc_grad_outvars, apply_grad_layers, apply_grad_global_info, + mesh_id, autosharding_configs[mesh_id]) else: stages = generate_training_stages_1d( layers, accumulator_mapping, acc_grad_invars, @@ -1234,7 +1316,10 @@ def get_compute_cost( num_autosharding_configs) elif auto_stage_option.layer_profile_mode == "individual": if inference_mode: - raise NotImplementedError() + compute_cost, _ = interpret_profile_result_inference_1d( + profile_results, num_layers, num_submesh_choices, + num_autosharding_configs) + max_n_succ_stages = None else: (compute_cost, max_n_succ_stages) = interpret_profile_result_training_1d( diff --git a/tests/pipeline_parallel/test_inference_auto.py b/tests/pipeline_parallel/test_inference_auto.py index 9a637550f..6143c4535 100644 --- a/tests/pipeline_parallel/test_inference_auto.py +++ b/tests/pipeline_parallel/test_inference_auto.py @@ -1,26 +1,14 @@ import unittest +from alpa import init, PipeshardParallel, AutoStageOption +from .test_inference_only import PipelineInferenceTest -import jax -import jax.numpy as jnp -import numpy as np -from alpa import (init, shutdown, parallelize, PipeshardParallel, - mark_pipeline_boundary, AutoStageOption) -from alpa.model.bert_model import BertConfig, FlaxBertLayerCollection -from alpa.testing import (MLPModel, create_train_state, mlp_inference_step, - bert_layer_collection_inference_step, assert_allclose) - - -class PipelineInferenceAutoTest(unittest.TestCase): +class PipelineInferenceAutoTest(PipelineInferenceTest): def setUp(self): init(cluster="ray", num_nodes=1, num_devices_per_node=4) - # pylint: disable=no-self-use - def tearDown(self): - shutdown() - - def run_mlp_inference(self, manual_pipeline_layer): + def test_mlp(self): stage_option = AutoStageOption( submesh_physical_shape_space="manual", manually_specified_submeshes=((1, 2),), @@ -29,34 +17,9 @@ def run_mlp_inference(self, manual_pipeline_layer): pipeline_schedule="inference", layer_option="manual", stage_option=stage_option) + self.run_mlp_inference(True, method) - # Init model and optimizer - 
batch_size = 64 - hidden_size = 16 - - model = MLPModel(hidden_size=hidden_size, - num_layers=4, - add_manual_pipeline_marker=manual_pipeline_layer) - rngkey = jax.random.PRNGKey(0) - x = jax.random.normal(rngkey, (batch_size, hidden_size)) - y = jax.random.normal(rngkey, (batch_size, hidden_size)) - batch = {'x': x, 'y': y} - state = create_train_state(rngkey, model, [x]) - - # Compile - serial_inference_step = mlp_inference_step - - parallel_inference_step = parallelize(mlp_inference_step, - method=method, - donate_argnums=()) - executable = parallel_inference_step.get_executable(state, batch) - - # Run correctnesss test - serial_out = serial_inference_step(state, batch) - parallel_out = parallel_inference_step(state, batch) - assert_allclose(serial_out, parallel_out, 1e-3, 1e-3) - - def run_bert_layer_collection_inference(self, manual_pipeline_layer): + def test_bert(self): stage_option = AutoStageOption( submesh_physical_shape_space="manual", manually_specified_submeshes=((1, 2),), @@ -65,52 +28,26 @@ def run_bert_layer_collection_inference(self, manual_pipeline_layer): pipeline_schedule="inference", layer_option="manual", stage_option=stage_option) + self.run_bert_layer_collection_inference(True, method) - # Init model and optimizer - batch_size = 16 - seq_len = 256 - hidden_size = 512 - num_heads = 512 // 64 - n_layers = 4 - - model = FlaxBertLayerCollection( - config=BertConfig(hidden_size=hidden_size, - intermediate_size=hidden_size * 4, - num_attention_heads=num_heads, - num_hidden_layers=n_layers, - add_manual_pipeline_markers=manual_pipeline_layer, - pipeline_mp_size=n_layers)) - rngkey = jax.random.PRNGKey(0) - x = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size)) - y = jax.random.normal(rngkey, (batch_size, seq_len, hidden_size)) - attention_mask = jnp.ones((batch_size, seq_len), dtype=jnp.int8) - batch = {"x": x, "y": y, "attention_mask": attention_mask} - state = create_train_state(rngkey, model, [x, attention_mask]) - - # Compile - serial_inference_step = bert_layer_collection_inference_step - parallel_inference_step = parallelize( - bert_layer_collection_inference_step, - method=method, - donate_argnums=()) - executable = parallel_inference_step.get_executable(state, batch) - - # Run correctnesss test - serial_out = serial_inference_step(state, batch) - parallel_out = parallel_inference_step(state, batch) - assert_allclose(serial_out, parallel_out, 1e-3, 1e-3) - - def test_mlp(self): - self.run_mlp_inference(True) - - def test_bert(self): - self.run_bert_layer_collection_inference(True) + def test_mlp_1d(self): + stage_option = AutoStageOption( + submesh_physical_shape_space="manual", + manually_specified_submeshes=((1, 2),), + submesh_logical_shape_space="model_parallel_only", + layer_profile_mode="individual") + method = PipeshardParallel(num_micro_batches=1, + pipeline_schedule="inference", + layer_option="manual", + stage_option=stage_option) + self.run_mlp_inference(True, method) def suite(): suite = unittest.TestSuite() suite.addTest(PipelineInferenceAutoTest("test_mlp")) suite.addTest(PipelineInferenceAutoTest("test_bert")) + suite.addTest(PipelineInferenceAutoTest("test_mlp_1d")) return suite diff --git a/tests/pipeline_parallel/test_inference_only.py b/tests/pipeline_parallel/test_inference_only.py index 96f9fb009..c83384438 100644 --- a/tests/pipeline_parallel/test_inference_only.py +++ b/tests/pipeline_parallel/test_inference_only.py @@ -20,11 +20,7 @@ def setUp(self): def tearDown(self): shutdown() - def run_mlp_inference(self, 
manual_pipeline_layer): - method = PipeshardParallel(num_micro_batches=4, - pipeline_schedule="inference", - layer_option="manual") - + def run_mlp_inference(self, manual_pipeline_layer, parallel_method): # Init model and optimizer batch_size = 64 hidden_size = 16 @@ -42,7 +38,7 @@ def run_mlp_inference(self, manual_pipeline_layer): serial_inference_step = mlp_inference_step parallel_inference_step = parallelize(mlp_inference_step, - method=method, + method=parallel_method, donate_argnums=()) executable = parallel_inference_step.get_executable(state, batch) @@ -51,17 +47,14 @@ def run_mlp_inference(self, manual_pipeline_layer): parallel_out = parallel_inference_step(state, batch) assert_allclose(serial_out, parallel_out, 1e-3, 1e-3) - def run_bert_layer_collection_inference(self, manual_pipeline_layer): - method = PipeshardParallel(num_micro_batches=4, - pipeline_schedule="inference", - layer_option="manual") - + def run_bert_layer_collection_inference(self, manual_pipeline_layer, + parallel_method): # Init model and optimizer batch_size = 16 seq_len = 256 hidden_size = 512 num_heads = 512 // 64 - n_layers = 2 + n_layers = 4 model = FlaxBertLayerCollection( config=BertConfig(hidden_size=hidden_size, @@ -81,7 +74,7 @@ def run_bert_layer_collection_inference(self, manual_pipeline_layer): serial_inference_step = bert_layer_collection_inference_step parallel_inference_step = parallelize( bert_layer_collection_inference_step, - method=method, + method=parallel_method, donate_argnums=()) executable = parallel_inference_step.get_executable(state, batch) @@ -91,10 +84,16 @@ def run_bert_layer_collection_inference(self, manual_pipeline_layer): assert_allclose(serial_out, parallel_out, 1e-3, 1e-3) def test_mlp(self): - self.run_mlp_inference(True) + method = PipeshardParallel(num_micro_batches=4, + pipeline_schedule="inference", + layer_option="manual") + self.run_mlp_inference(True, method) def test_bert(self): - self.run_bert_layer_collection_inference(True) + method = PipeshardParallel(num_micro_batches=4, + pipeline_schedule="inference", + layer_option="manual") + self.run_bert_layer_collection_inference(True, method) def test_output(self): method = PipeshardParallel(num_micro_batches=2, diff --git a/tests/pipeline_parallel/test_stage_construction_util.py b/tests/pipeline_parallel/test_stage_construction_util.py index 55e20830e..dc7b5ba10 100644 --- a/tests/pipeline_parallel/test_stage_construction_util.py +++ b/tests/pipeline_parallel/test_stage_construction_util.py @@ -12,9 +12,9 @@ from alpa.pipeline_parallel.compile_executable import ( split_and_process_layers, slice_apply_grad_for_stage_construction) from alpa.pipeline_parallel.layer_construction import ManualLayerOption -from alpa.pipeline_parallel.stage_profiling import (generate_stage_info, - distributed_profile_on_mesh, - get_max_n_succ_stages) +from alpa.pipeline_parallel.stage_profiling import ( + generate_stage_info, distributed_profile_on_mesh, + get_merged_stages_memory_stats) from alpa.shard_parallel.auto_sharding import AutoShardingOption from alpa.testing import (get_bert_layer_train_state_and_step, get_mlp_train_state_and_step) @@ -206,12 +206,12 @@ def check_1d_2d_results_the_same(self, train_step, state, batch, profile_results_1d.append(result) # Compare - max_stage_2d, (available_memory_2d, peak_memory_2d, initial_size_2d, - intermediate_size_2d) = get_max_n_succ_stages( - [profile_results_2d]) - max_stage_1d, ( - available_memory_1d, peak_memory_1d, initial_size_1d, - intermediate_size_1d) = 
get_max_n_succ_stages(profile_results_1d) + (available_memory_2d, peak_memory_2d, initial_size_2d, + intermediate_size_2d, + max_stage_2d) = get_merged_stages_memory_stats([profile_results_2d]) + (available_memory_1d, peak_memory_1d, initial_size_1d, + intermediate_size_1d, + max_stage_1d) = get_merged_stages_memory_stats(profile_results_1d) assert available_memory_1d == available_memory_2d, ( f"available_memory_1d: {available_memory_1d}, " @@ -262,9 +262,9 @@ def check_2d_real_the_same(self): jax_pipeline_layers, accumulator_mapping, acc_grad_invars, acc_grad_outvars, jax_apply_layers, apply_grad_global_info, num_microbatch, 0, num_layers - 1) - max_stage_2d, (available_memory_2d, peak_memory_2d, initial_size_2d, - intermediate_size_2d) = get_max_n_succ_stages( - [profile_results_2d]) + (available_memory_2d, peak_memory_2d, initial_size_2d, + intermediate_size_2d, + max_stage_2d) = get_merged_stages_memory_stats([profile_results_2d]) # Real pipeshard_method = PipeshardParallel( From 7fa3cfc4782faa1baa9d47fb61122168ebd673db Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 00:59:56 -0800 Subject: [PATCH 09/37] fix --- alpa/pipeline_parallel/stage_profiling.py | 8 ++++---- tests/pipeline_parallel/test_inference_auto.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/alpa/pipeline_parallel/stage_profiling.py b/alpa/pipeline_parallel/stage_profiling.py index e485c7a40..703780c8c 100644 --- a/alpa/pipeline_parallel/stage_profiling.py +++ b/alpa/pipeline_parallel/stage_profiling.py @@ -981,9 +981,9 @@ def generate_training_stages_1d(layers, accumulator_mapping, acc_grad_invars, num_layers = len(layers) // 2 stages = [] for l in tqdm.tqdm(range(0, num_layers)): - selected_apply_grad_layers = [ + selected_apply_grad_layers = ( [] if apply_grad_layers[l] is None else [apply_grad_layers[l]] - ] + ) stage_name = f"stage_{l}" stage_config = generate_stage_info(layers, [(l,), (2 * num_layers - l - 1,)], @@ -1007,9 +1007,9 @@ def generate_inference_stages_1d(layers, accumulator_mapping, acc_grad_invars, num_layers = len(layers) stages = [] for l in tqdm.tqdm(range(0, num_layers)): - selected_apply_grad_layers = [ + selected_apply_grad_layers = ( [] if apply_grad_layers[l] is None else [apply_grad_layers[l]] - ] + ) assert len(selected_apply_grad_layers) == 0, ( "Inference stage should not have apply_grad_layers") stage_name = f"stage_{l}" diff --git a/tests/pipeline_parallel/test_inference_auto.py b/tests/pipeline_parallel/test_inference_auto.py index 6143c4535..aa169ec96 100644 --- a/tests/pipeline_parallel/test_inference_auto.py +++ b/tests/pipeline_parallel/test_inference_auto.py @@ -1,6 +1,6 @@ import unittest from alpa import init, PipeshardParallel, AutoStageOption -from .test_inference_only import PipelineInferenceTest +from test_inference_only import PipelineInferenceTest class PipelineInferenceAutoTest(PipelineInferenceTest): From bf6c932fd7de5779107d61fe32775d0b74c81386 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 01:02:54 -0800 Subject: [PATCH 10/37] fix --- alpa/pipeline_parallel/stage_profiling.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/alpa/pipeline_parallel/stage_profiling.py b/alpa/pipeline_parallel/stage_profiling.py index 703780c8c..4eb5a7b2b 100644 --- a/alpa/pipeline_parallel/stage_profiling.py +++ b/alpa/pipeline_parallel/stage_profiling.py @@ -981,9 +981,8 @@ def generate_training_stages_1d(layers, accumulator_mapping, acc_grad_invars, num_layers = len(layers) // 2 stages = [] for l in 
tqdm.tqdm(range(0, num_layers)): - selected_apply_grad_layers = ( - [] if apply_grad_layers[l] is None else [apply_grad_layers[l]] - ) + selected_apply_grad_layers = ([] if apply_grad_layers[l] is None else + [apply_grad_layers[l]]) stage_name = f"stage_{l}" stage_config = generate_stage_info(layers, [(l,), (2 * num_layers - l - 1,)], @@ -1007,9 +1006,8 @@ def generate_inference_stages_1d(layers, accumulator_mapping, acc_grad_invars, num_layers = len(layers) stages = [] for l in tqdm.tqdm(range(0, num_layers)): - selected_apply_grad_layers = ( - [] if apply_grad_layers[l] is None else [apply_grad_layers[l]] - ) + selected_apply_grad_layers = ([] if apply_grad_layers[l] is None else + [apply_grad_layers[l]]) assert len(selected_apply_grad_layers) == 0, ( "Inference stage should not have apply_grad_layers") stage_name = f"stage_{l}" @@ -1095,8 +1093,9 @@ def interpret_profile_result_inference_1d( profile_result.module_profile_results[0]. compute_cost for profile_result in selected_profile_results) - (available_memory, peak_memory, _, _, _ - ) = get_merged_stages_memory_stats(selected_profile_results) + (available_memory, peak_memory, _, _, + _) = get_merged_stages_memory_stats( + selected_profile_results, inference_mode=True) if peak_memory > available_memory: all_compute_cost[start, end, submesh_choice, config_idx] = np.inf From bf8cf0c97258c19c491c579ad8f1d5993dfe25dd Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 01:07:42 -0800 Subject: [PATCH 11/37] add bert tests --- tests/pipeline_parallel/test_inference_auto.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/pipeline_parallel/test_inference_auto.py b/tests/pipeline_parallel/test_inference_auto.py index aa169ec96..7e0a1fee8 100644 --- a/tests/pipeline_parallel/test_inference_auto.py +++ b/tests/pipeline_parallel/test_inference_auto.py @@ -42,12 +42,25 @@ def test_mlp_1d(self): stage_option=stage_option) self.run_mlp_inference(True, method) + def test_bert_1d(self): + stage_option = AutoStageOption( + submesh_physical_shape_space="manual", + manually_specified_submeshes=((1, 2),), + submesh_logical_shape_space="model_parallel_only", + layer_profile_mode="individual") + method = PipeshardParallel(num_micro_batches=1, + pipeline_schedule="inference", + layer_option="manual", + stage_option=stage_option) + self.run_bert_layer_collection_inference(True, method) + def suite(): suite = unittest.TestSuite() suite.addTest(PipelineInferenceAutoTest("test_mlp")) suite.addTest(PipelineInferenceAutoTest("test_bert")) suite.addTest(PipelineInferenceAutoTest("test_mlp_1d")) + suite.addTest(PipelineInferenceAutoTest("test_bert_1d")) return suite From 6c23c7c898583fe52e62ac874cc622a945b0df2d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 16:25:13 -0800 Subject: [PATCH 12/37] fix --- tests/pipeline_parallel/test_inference_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipeline_parallel/test_inference_auto.py b/tests/pipeline_parallel/test_inference_auto.py index 7e0a1fee8..263f0c693 100644 --- a/tests/pipeline_parallel/test_inference_auto.py +++ b/tests/pipeline_parallel/test_inference_auto.py @@ -1,6 +1,6 @@ import unittest from alpa import init, PipeshardParallel, AutoStageOption -from test_inference_only import PipelineInferenceTest +from tests.pipeline_parallel.test_inference_only import PipelineInferenceTest class PipelineInferenceAutoTest(PipelineInferenceTest): From 8d41a5fad210403d85f0726b37689774f4c782c7 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 
3 Dec 2022 17:16:05 -0800 Subject: [PATCH 13/37] test 1d stage construction --- benchmark/alpa/suite_inference_gpt.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index c342c5a30..1c0071175 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -86,16 +86,28 @@ def get_config(model_config, # submesh_physical_shapes=[(1, 1)] * 8, # submesh_logical_shapes=[(1, 1)] * 8, # submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "search", + # SearchParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # num_auto_layers=50, + # auto_stage_option={ + # "submesh_physical_shape_space": "manual", + # "manually_specified_submeshes": ((1, 1),), + # "submesh_logical_shape_space": "model_parallel_only", + # })), BenchmarkCase( 1, gpt_specs["1.3B"], 1, "search", SearchParallelArgs( prefer_reduce_scatter, use_remat, - num_auto_layers=24, + num_auto_layers=50, auto_stage_option={ "submesh_physical_shape_space": "manual", "manually_specified_submeshes": ((1, 1),), "submesh_logical_shape_space": "model_parallel_only", + "layer_profile_mode": "individual", })), ] } From e7b67dbb63d368b525d6fe139b339095ebb22153 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 17:53:26 -0800 Subject: [PATCH 14/37] add layer construction solution logging --- alpa/pipeline_parallel/layer_construction.py | 48 +++++++++++++------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 7eeda03ba..86aceca95 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -434,22 +434,7 @@ def dp(input_sizes, blocked): assert r == 0, "No solution for layer construction." 
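    # The backtracking loop above (elided from this hunk's context) walked
    # the DP argmin table backwards, appending equation slices from last
    # layer to first; the reversal on the next line restores forward order,
    # so `solution` is one list of jaxpr equations per constructed layer --
    # exactly the structure the log_solution() helper added below iterates.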
solution = list(reversed(reversed_sliced_eqns)) - # print("dp solution") - # for i, eqns in enumerate(solution): - # invars = OrderedSet() - # for eqn in eqns: - # invars.update([var for var in eqn.invars if isinstance(var, Var)]) - # invars.intersection_update(jaxpr.jaxpr.invars) - # print(f"mesh: {i}, set_shapes: " - # f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") - # - # invars = [] - # for eqn in eqns: - # tmp_set = set([var for var in eqn.invars if isinstance(var, Var)]) - # tmp_set.intersection_update(jaxpr.jaxpr.invars) - # invars.extend(list(tmp_set)) - # print(f"mesh: {i}, list_shapes: " - # f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") + log_solution(solution, jaxpr) solution_info = { "total_cost": value, @@ -457,6 +442,37 @@ def dp(input_sizes, blocked): return solution, solution_info +def log_solution(solution, jaxpr): + print("-" * 80) + print(f"Layer construction solution ({len(solution)} layers):") + for i, eqns in enumerate(solution): + print("-" * 40) + print(f"Layer {i}:") + + total_flops = 0 + for j, eqn in enumerate(eqns): + flops = eqn_flops(eqn) + print(f"Eqn {j}: {eqn}, Flops: {flops}") + total_flops += flops + print(f"Total flops: {total_flops}") + invars = OrderedSet() + for eqn in eqns: + invars.update([var for var in eqn.invars if isinstance(var, Var)]) + invars.intersection_update(jaxpr.jaxpr.invars) + + print(f"set_invar_shapes: " + f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") + + invars = [] + for eqn in eqns: + tmp_set = set([var for var in eqn.invars if isinstance(var, Var)]) + tmp_set.intersection_update(jaxpr.jaxpr.invars) + invars.extend(list(tmp_set)) + print(f"list_invar_shapes: " + f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") + print("-" * 80) + + def search_layer_num(jaxpr, eps, layer_eps=0, From 27af75b89053cdd1e65effcc7daf0f797bf59c54 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 3 Dec 2022 23:12:38 -0800 Subject: [PATCH 15/37] add some searched results --- alpa/pipeline_parallel/layer_construction.py | 2 +- benchmark/alpa/suite_inference_gpt.py | 108 +++++++++++++------ 2 files changed, 77 insertions(+), 33 deletions(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 86aceca95..9a1c33e09 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -465,7 +465,7 @@ def log_solution(solution, jaxpr): invars = [] for eqn in eqns: - tmp_set = set([var for var in eqn.invars if isinstance(var, Var)]) + tmp_set = {var for var in eqn.invars if isinstance(var, Var)} tmp_set.intersection_update(jaxpr.jaxpr.invars) invars.extend(list(tmp_set)) print(f"list_invar_shapes: " diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index 1c0071175..e58c01a53 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -66,26 +66,79 @@ def get_config(model_config, test_suite = { 8: [ - # BenchmarkCase( - # 1, gpt_specs["1.3B"], 1, "uniform", - # UniformParallelArgs( - # prefer_reduce_scatter, - # use_remat, - # dp=1, - # op=1, - # pp=8, - # force_batch_dim_mapping=force_batch_dim_mapping)), - # BenchmarkCase( - # 1, gpt_specs["1.3B"], 1, "load_solution", - # LoadSolutionParallelArgs( - # prefer_reduce_scatter, - # use_remat, - # num_auto_layers=8, - # forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6], - # [7]], - # submesh_physical_shapes=[(1, 1)] * 8, - # submesh_logical_shapes=[(1, 1)] * 8, - # 
submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "uniform", + UniformParallelArgs( + prefer_reduce_scatter, + use_remat, + dp=1, + op=1, + pp=8, + force_batch_dim_mapping=force_batch_dim_mapping)), + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "load_solution", + LoadSolutionParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=8, + forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6], + [7]], + submesh_physical_shapes=[(1, 1)] * 8, + submesh_logical_shapes=[(1, 1)] * 8, + submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # 2D + Profile + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "load_solution", + LoadSolutionParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=50, + forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5], + [6, 7, 8, 9, 10, 11], + [12, 13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24, 25], + [26, 27, 28, 29, 30, 31], + [32, 33, 34, 35, 36, 37], + [38, 39, 40, 41, 42, 43, 44], + [45, 46, 47, 48, 49]], + submesh_physical_shapes=[(1, 1)] * 8, + submesh_logical_shapes=[(1, 1)] * 8, + submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # 1D + Profile + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "load_solution", + LoadSolutionParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=50, + forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5], + [6, 7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + [19, 20, 21, 22, 23, 24], + [25, 26, 27, 28, 29, 30, 31], + [32, 33, 34, 35, 36, 37], + [38, 39, 40, 41, 42, 43, 44], + [45, 46, 47, 48, 49]], + submesh_physical_shapes=[(1, 1)] * 8, + submesh_logical_shapes=[(1, 1)] * 8, + submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # 1D + Cost model + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "load_solution", + LoadSolutionParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=50, + forward_stage_layer_ids=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10], + [11, 12, 13, 14, 15, 16], + [17, 18, 19, 20, 21, 22, 23], + [24, 25, 26, 27, 28, 29, 30], + [31, 32, 33, 34, 35, 36, 37], + [38, 39, 40, 41, 42, 43, 44], + [45, 46, 47, 48, 49]], + submesh_physical_shapes=[(1, 1)] * 8, + submesh_logical_shapes=[(1, 1)] * 8, + submesh_autosharding_option_dicts=[force_dp_dict] * 8)), # BenchmarkCase( # 1, gpt_specs["1.3B"], 1, "search", # SearchParallelArgs( @@ -96,18 +149,9 @@ def get_config(model_config, # "submesh_physical_shape_space": "manual", # "manually_specified_submeshes": ((1, 1),), # "submesh_logical_shape_space": "model_parallel_only", + # "layer_profile_mode": "individual", + # "use_hlo_cost_model": True, + # "profiling_database_filename": "prof_database.pkl", # })), - BenchmarkCase( - 1, gpt_specs["1.3B"], 1, "search", - SearchParallelArgs( - prefer_reduce_scatter, - use_remat, - num_auto_layers=50, - auto_stage_option={ - "submesh_physical_shape_space": "manual", - "manually_specified_submeshes": ((1, 1),), - "submesh_logical_shape_space": "model_parallel_only", - "layer_profile_mode": "individual", - })), ] } From 8efe119f8475ba885a9592d8647033994f348f5c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 4 Dec 2022 02:26:58 -0800 Subject: [PATCH 16/37] test new layer construction --- alpa/pipeline_parallel/layer_construction.py | 77 ++++++++++++++++++-- 1 file changed, 72 insertions(+), 5 deletions(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 9a1c33e09..480ebb62f 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ 
b/alpa/pipeline_parallel/layer_construction.py @@ -442,6 +442,76 @@ def dp(input_sizes, blocked): return solution, solution_info +def cluster_jaxpr_by_cost_optimized(jaxpr: Jaxpr, layer_num: int, costs, + cost_criteria): + """Clusters the jaxpr by cost.""" + layer_num = int(layer_num) + length = len(jaxpr.eqns) + _, input_sizes, compute_costs = costs + assert cost_criteria == "flops" + FLOPS_NORMALIZER = 30 * 1e12 # 30 TFLOPS + NETWORK_NORMALIZER = 2 * 1e9 # 2 GB / s + + @maybe_numba_jit + def init_layer_costs(): + layer_costs = np.full((length + 1, layer_num + 1), + np.inf, + dtype=np.float32) + for l in range(1, length + 1): + layer_flops = 0 + for r in range(l, length + 1): + layer_flops += compute_costs[r - 1] + layer_costs[l, r] = (layer_flops / FLOPS_NORMALIZER + + input_sizes[l - 1, r] / NETWORK_NORMALIZER) + return layer_costs + + @maybe_numba_jit + def dp(layer_costs): + max_cost = np.full((length + 1, layer_num + 1), + np.inf, + dtype=np.float32) + sum_cost_under_max = np.full((length + 1, layer_num + 1), + np.inf, + dtype=np.float32) + max_cost_argmin = np.full((length + 1, layer_num + 1), + -1, + dtype=np.int32) + max_cost[0, 0] = 0 + sum_cost_under_max[0, 0] = 0 + for q in range(1, layer_num + 1): + for r in range(1, length + 1): + for k in range(0, r): + new_value = max(max_cost[k, q - 1], layer_costs[k, r]) + new_sum = (sum_cost_under_max[k, q - 1] + layer_costs[k, r]) + if (new_value < max_cost[r, q] or + (new_value <= max_cost[r, q] * + (1 + 1e-4) and new_sum < sum_cost_under_max[r, q])): + max_cost[r, q] = new_value + sum_cost_under_max[r, q] = new_sum + max_cost_argmin[r, q] = k + return max_cost_argmin, max_cost[length, layer_num] + + layer_costs = init_layer_costs() + a_argmin, value = dp(layer_costs) + + reversed_sliced_eqns = [] + + r = length + for q in range(layer_num, 0, -1): + k = a_argmin[r, q] + reversed_sliced_eqns.append(jaxpr.eqns[k:r]) + r = k + assert r == 0, "No solution for layer construction." 
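    # Objective recap: dp() above minimizes the maximum per-layer cost and
    # breaks ties by the smallest running sum of layer costs. A candidate
    # layer spanning eqns [l, r] is charged
    #     flops / FLOPS_NORMALIZER + input_bytes / NETWORK_NORMALIZER,
    # i.e. flops / 30 TFLOPS + input bytes / (2 GB/s); the two normalizers
    # are assumed hardware constants that make compute and input transfer
    # commensurable before the layers are balanced.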
+ solution = list(reversed(reversed_sliced_eqns)) + + log_solution(solution, jaxpr) + + solution_info = { + "total_cost": value, + } + return solution, solution_info + + def log_solution(solution, jaxpr): print("-" * 80) print(f"Layer construction solution ({len(solution)} layers):") @@ -527,11 +597,8 @@ def wrapped(*args): layer_num = search_layer_num(jaxpr, eps, layer_eps) costs = get_layer_construction_costs(jaxpr, cost_criteria=cost_criteria) - sliced_eqns, _ = cluster_jaxpr_by_cost(jaxpr, - layer_num, - eps, - costs, - cost_criteria=cost_criteria) + sliced_eqns, _ = cluster_jaxpr_by_cost_optimized( + jaxpr, layer_num, costs, cost_criteria=cost_criteria) else: sliced_eqns = slice_eqns_by_layer_boundary(jaxpr) From 7bc75836f99749be473bab96dacacbce55304d67 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 4 Dec 2022 15:07:14 -0800 Subject: [PATCH 17/37] fix indexing error --- alpa/pipeline_parallel/layer_construction.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 480ebb62f..20dbca7fc 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -454,15 +454,15 @@ def cluster_jaxpr_by_cost_optimized(jaxpr: Jaxpr, layer_num: int, costs, @maybe_numba_jit def init_layer_costs(): - layer_costs = np.full((length + 1, layer_num + 1), + layer_costs = np.full((length, length), np.inf, dtype=np.float32) - for l in range(1, length + 1): + for l in range(0, length): layer_flops = 0 - for r in range(l, length + 1): - layer_flops += compute_costs[r - 1] + for r in range(l, length): + layer_flops += compute_costs[r] layer_costs[l, r] = (layer_flops / FLOPS_NORMALIZER + - input_sizes[l - 1, r] / NETWORK_NORMALIZER) + input_sizes[l, r + 1] / NETWORK_NORMALIZER) return layer_costs @maybe_numba_jit @@ -481,8 +481,8 @@ def dp(layer_costs): for q in range(1, layer_num + 1): for r in range(1, length + 1): for k in range(0, r): - new_value = max(max_cost[k, q - 1], layer_costs[k, r]) - new_sum = (sum_cost_under_max[k, q - 1] + layer_costs[k, r]) + new_value = max(max_cost[k, q - 1], layer_costs[k, r - 1]) + new_sum = (sum_cost_under_max[k, q - 1] + layer_costs[k, r - 1]) if (new_value < max_cost[r, q] or (new_value <= max_cost[r, q] * (1 + 1e-4) and new_sum < sum_cost_under_max[r, q])): From a6cc06e0406e2448d8218b9042b9e4bd909f2f25 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 4 Dec 2022 16:31:36 -0800 Subject: [PATCH 18/37] add squared cost --- alpa/pipeline_parallel/layer_construction.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 20dbca7fc..9115a0d21 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -454,9 +454,7 @@ def cluster_jaxpr_by_cost_optimized(jaxpr: Jaxpr, layer_num: int, costs, @maybe_numba_jit def init_layer_costs(): - layer_costs = np.full((length, length), - np.inf, - dtype=np.float32) + layer_costs = np.full((length, length), np.inf, dtype=np.float32) for l in range(0, length): layer_flops = 0 for r in range(l, length): @@ -473,21 +471,33 @@ def dp(layer_costs): sum_cost_under_max = np.full((length + 1, layer_num + 1), np.inf, dtype=np.float32) + squared_cost_under_max = np.full((length + 1, layer_num + 1), + np.inf, + dtype=np.float32) max_cost_argmin = np.full((length + 1, layer_num + 1), -1, 
dtype=np.int32) max_cost[0, 0] = 0 sum_cost_under_max[0, 0] = 0 + squared_cost_under_max[0, 0] = 0 for q in range(1, layer_num + 1): for r in range(1, length + 1): for k in range(0, r): new_value = max(max_cost[k, q - 1], layer_costs[k, r - 1]) - new_sum = (sum_cost_under_max[k, q - 1] + layer_costs[k, r - 1]) + new_sum = (sum_cost_under_max[k, q - 1] + + layer_costs[k, r - 1]) + new_squared_sum = (sum_cost_under_max[k, q - 1] + + layer_costs[k, r - 1]**2) if (new_value < max_cost[r, q] or (new_value <= max_cost[r, q] * - (1 + 1e-4) and new_sum < sum_cost_under_max[r, q])): + (1 + 1e-4) and new_sum < sum_cost_under_max[r, q]) or + (new_value <= max_cost[r, q] * + (1 + 1e-4) and new_sum <= sum_cost_under_max[r, q] * + (1 + 1e-4) and + new_squared_sum < squared_cost_under_max[r, q])): max_cost[r, q] = new_value sum_cost_under_max[r, q] = new_sum + squared_cost_under_max[r, q] = new_squared_sum max_cost_argmin[r, q] = k return max_cost_argmin, max_cost[length, layer_num] From 006c7817adee745c18e1765845ac1446d83300e0 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 4 Dec 2022 21:41:21 -0800 Subject: [PATCH 19/37] fix bug --- alpa/pipeline_parallel/layer_construction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 9115a0d21..3d5565f72 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -486,7 +486,7 @@ def dp(layer_costs): new_value = max(max_cost[k, q - 1], layer_costs[k, r - 1]) new_sum = (sum_cost_under_max[k, q - 1] + layer_costs[k, r - 1]) - new_squared_sum = (sum_cost_under_max[k, q - 1] + + new_squared_sum = (squared_cost_under_max[k, q - 1] + layer_costs[k, r - 1]**2) if (new_value < max_cost[r, q] or (new_value <= max_cost[r, q] * From 20c3ddbb64a11ddfd7e8b1e287ab7a2f61ed9913 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sun, 4 Dec 2022 22:10:20 -0800 Subject: [PATCH 20/37] do not use sum, only use squared sum --- alpa/pipeline_parallel/layer_construction.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 3d5565f72..2464b4b55 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -484,19 +484,12 @@ def dp(layer_costs): for r in range(1, length + 1): for k in range(0, r): new_value = max(max_cost[k, q - 1], layer_costs[k, r - 1]) - new_sum = (sum_cost_under_max[k, q - 1] + - layer_costs[k, r - 1]) new_squared_sum = (squared_cost_under_max[k, q - 1] + layer_costs[k, r - 1]**2) if (new_value < max_cost[r, q] or - (new_value <= max_cost[r, q] * - (1 + 1e-4) and new_sum < sum_cost_under_max[r, q]) or - (new_value <= max_cost[r, q] * - (1 + 1e-4) and new_sum <= sum_cost_under_max[r, q] * - (1 + 1e-4) and + (new_value <= max_cost[r, q] * (1 + 1e-4) and new_squared_sum < squared_cost_under_max[r, q])): max_cost[r, q] = new_value - sum_cost_under_max[r, q] = new_sum squared_cost_under_max[r, q] = new_squared_sum max_cost_argmin[r, q] = k return max_cost_argmin, max_cost[length, layer_num] From 04177f8e77d0caec0ddd6f76c2a28dcefb20728d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 5 Dec 2022 00:01:00 -0800 Subject: [PATCH 21/37] add search suite --- benchmark/alpa/benchmark.py | 1 + benchmark/alpa/run_exp.py | 4 ++ benchmark/alpa/suite_inference_gpt.py | 56 ++++++++++++++++++++------- 3 files changed, 47 insertions(+), 14 
deletions(-) diff --git a/benchmark/alpa/benchmark.py b/benchmark/alpa/benchmark.py index 63f6caf9c..6fe852f0c 100644 --- a/benchmark/alpa/benchmark.py +++ b/benchmark/alpa/benchmark.py @@ -29,6 +29,7 @@ "gpt.correctness_test_auto": suite_auto_gpt.correctness_test_suite, "gpt_inference.profile": suite_inference_gpt.profile_suite, "gpt_inference.test": suite_inference_gpt.test_suite, + "gpt_inference.search": suite_inference_gpt.search_suite, "gpt_no_embedding_inference.profile": suite_inference_gpt.profile_suite, "moe.tmp": suite_manual_moe.tmp_suite, "moe.tmp_auto": suite_auto_moe.tmp_suite, diff --git a/benchmark/alpa/run_exp.py b/benchmark/alpa/run_exp.py index cb5632d8f..5899796ed 100644 --- a/benchmark/alpa/run_exp.py +++ b/benchmark/alpa/run_exp.py @@ -39,6 +39,10 @@ def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None): "niter": 10, "profile_stage_execution_time": True }), + "gpt_inference_search": ("gpt_inference.search", { + "niter": 10, + "profile_stage_execution_time": True + }), "moe_inference": ("moe_inference.profile", { "niter": 10, "profile_stage_execution_time": True diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index e58c01a53..587245e60 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -139,19 +139,47 @@ def get_config(model_config, submesh_physical_shapes=[(1, 1)] * 8, submesh_logical_shapes=[(1, 1)] * 8, submesh_autosharding_option_dicts=[force_dp_dict] * 8)), - # BenchmarkCase( - # 1, gpt_specs["1.3B"], 1, "search", - # SearchParallelArgs( - # prefer_reduce_scatter, - # use_remat, - # num_auto_layers=50, - # auto_stage_option={ - # "submesh_physical_shape_space": "manual", - # "manually_specified_submeshes": ((1, 1),), - # "submesh_logical_shape_space": "model_parallel_only", - # "layer_profile_mode": "individual", - # "use_hlo_cost_model": True, - # "profiling_database_filename": "prof_database.pkl", - # })), ] } + +search_suite = {} + + +def generate_search_configs(model_config, num_auto_layers, pp_list, op_list): + """Generate search configs.""" + for pp in pp_list: + for op in op_list: + num_gpus = pp * op + if num_gpus not in search_suite: + search_suite[num_gpus] = [] + search_suite[num_gpus].append( + BenchmarkCase( + 1, + model_config, + 1, + "search", + SearchParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=num_auto_layers, + auto_stage_option={ + "submesh_physical_shape_space": + "manual", + "manually_specified_submeshes": ((1, op),), + "submesh_logical_shape_space": + "model_parallel_only", + "layer_profile_mode": + "individual", + # "use_hlo_cost_model": True, + # "profiling_database_filename": + # "prof_database.pkl", + }))) + + +generate_search_configs(gpt_specs["1.3B"], 50, [8], [1]) +# generate_search_configs(gpt_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32], +# [1, 2, 4, 8]) +# generate_search_configs(gpt_specs["2.6B"], 66, [1, 2, 4, 8, 16, 32], +# [1, 2, 4, 8]) +# generate_search_configs(gpt_specs["6.7B"], 66, [1, 2, 4, 8, 16, 32], +# [1, 2, 4, 8]) From eab6cbd25cc2255e2be7351b9db8283aec9b808e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 5 Dec 2022 00:11:44 -0800 Subject: [PATCH 22/37] add metadata --- .../benchmark_one_case_gpt_bert_inference.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py index 32072ed69..194e6c57b 100644 --- 

From eab6cbd25cc2255e2be7351b9db8283aec9b808e Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Mon, 5 Dec 2022 00:11:44 -0800
Subject: [PATCH 22/37] add metadata

---
 .../benchmark_one_case_gpt_bert_inference.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
index 32072ed69..194e6c57b 100644
--- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
+++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
@@ -11,6 +11,7 @@
 from alpa.model.bert_model import BertConfig, FlaxBertLayerCollection
 from alpa.model.gpt_model import FlaxGPTForLMModule
 from alpa.util import print_used_time, GB, write_tsv
+from alpa.pipeline_parallel.stage_construction import get_last_dp_result

 from util import compute_gpt_parameter_count, compute_gpt_tflops
 from benchmark_parallel_utils import (
@@ -209,6 +210,18 @@ def benchmark_gpt_inference_internal(model_type,
             f"{per_stage_peak_mem}", avg_stage_latencies
         ]
     else:
+        (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes,
+         logical_mesh_shapes,
+         autosharding_option_dicts) = get_last_dp_result()
+        metadata = {
+            "compilation_times": compilation_times,
+            "compute_cost_file_name": compute_cost_file_name,
+            "forward_stage_layer_ids": forward_stage_layer_ids,
+            "submesh_shapes": submesh_shapes,
+            "logical_mesh_shapes": logical_mesh_shapes,
+            "autosharding_option_dicts": autosharding_option_dicts,
+        }
+
         timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
         executable.dump_stage_execution_trace(
             f"./chrome_trace/{model_name},"
@@ -217,14 +230,14 @@ def benchmark_gpt_inference_internal(model_type,
         heads = [
             "ModelName", "BS", "#Microbatch", "ParallelArgs", "MeanTime(s)",
             "StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)",
-            "StageLatencies(s)", "TimeStamp"
+            "StageLatencies(s)", "Metadata", "TimeStamp"
         ]
         values = [
             model_name, benchmark_case.batch_size,
             benchmark_case.num_micro_batches, parallel_args,
             f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
             f"{tflops:.2f}", f"{per_stage_weight_mem}",
-            f"{per_stage_peak_mem}", avg_stage_latencies, timestamp
+            f"{per_stage_peak_mem}", avg_stage_latencies, metadata, timestamp
         ]
     write_tsv(heads, values, f"benchmark_results.tsv")

From 759d04d428e9dabc75bbb7fdff610f53451faefd Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Mon, 5 Dec 2022 00:21:04 -0800
Subject: [PATCH 23/37] print values as Python lists

---
 benchmark/alpa/benchmark_one_case_gpt_bert_inference.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
index 194e6c57b..146763bce 100644
--- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
+++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
@@ -236,8 +236,9 @@ def benchmark_gpt_inference_internal(model_type,
             model_name, benchmark_case.batch_size,
             benchmark_case.num_micro_batches, parallel_args,
             f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
-            f"{tflops:.2f}", f"{per_stage_weight_mem}",
-            f"{per_stage_peak_mem}", avg_stage_latencies, metadata, timestamp
+            f"{tflops:.2f}", f"{list(per_stage_weight_mem)}",
+            f"{list(per_stage_peak_mem)}", avg_stage_latencies, metadata,
+            timestamp
         ]
     write_tsv(heads, values, f"benchmark_results.tsv")
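
The list() conversions matter if the per-stage memory stats are numpy arrays: a
numpy array's string form has no commas, which makes the TSV cell ambiguous to
parse back, while a Python list prints as a valid literal. Whether these values
are numpy arrays here is an assumption; the conversion is harmless otherwise:

    import numpy as np

    per_stage_peak_mem = np.array([1.5e9, 1.6e9])
    print(f"{per_stage_peak_mem}")        # [1.5e+09 1.6e+09]
    print(f"{list(per_stage_peak_mem)}")  # [1500000000.0, 1600000000.0]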

From de4a36562fd2294cb0c17fdf39446f020ea7b3bc Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Mon, 5 Dec 2022 00:28:05 -0800
Subject: [PATCH 24/37] fix: log avg_stage_latencies as a plain list

---
 benchmark/alpa/benchmark_one_case_gpt_bert_inference.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
index 146763bce..5067cf87a 100644
--- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
+++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
@@ -237,8 +237,8 @@ def benchmark_gpt_inference_internal(model_type,
             benchmark_case.num_micro_batches, parallel_args,
             f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
             f"{tflops:.2f}", f"{list(per_stage_weight_mem)}",
-            f"{list(per_stage_peak_mem)}", avg_stage_latencies, metadata,
-            timestamp
+            f"{list(per_stage_peak_mem)}",
+            list(avg_stage_latencies), metadata, timestamp
         ]
     write_tsv(heads, values, f"benchmark_results.tsv")

From 6a589a684792f591ca6133f90c12306d94b64cc0 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 00:22:27 -0800
Subject: [PATCH 25/37] fix head ip address

---
 alpa/device_mesh.py         | 14 +-------------
 third_party/tensorflow-alpa |  2 +-
 2 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/alpa/device_mesh.py b/alpa/device_mesh.py
index c91e2aa6d..4443aa25d 100644
--- a/alpa/device_mesh.py
+++ b/alpa/device_mesh.py
@@ -958,7 +958,6 @@ class DistributedPhysicalDeviceMesh(PhysicalDeviceMesh):

     def __init__(self,
                  host_ids: Sequence[int],
                  host_info: Sequence[dict],
-                 head_ip: str,
                  num_devices_per_host: int,
                  parent: Optional["VirtualPhysicalMesh"] = None,
                  devices: Optional[Sequence[Sequence[int]]] = None,
@@ -967,7 +966,6 @@ def __init__(self,
         # host_ids are the indices of hosts in the global DeviceCluster
         self.host_ids = host_ids
         self.host_info = host_info
-        self.head_ip = head_ip
         self.num_hosts = len(host_ids)
         self.num_devices_per_host = num_devices_per_host
         self.parent = parent
@@ -1035,7 +1033,7 @@ def launch_xla_servers(self):
                 port = np.random.randint(20000, 25000)
             used_port_set.add(port)

-        server_address = f"{self.head_ip}:{port}"
+        server_address = f"{ray.util.get_node_ip_address()}:{port}"
         logger.debug(f"Trying to start XLA gRPC server on port: {port}...")
         service_server = xla_client._xla.get_distributed_runtime_service(
             server_address, self.num_hosts, use_coordination_service=False)
@@ -1127,7 +1125,6 @@ def get_virtual_physical_mesh(self):
         return VirtualPhysicalMesh(
             host_ids=self.host_ids,
             host_info=self.host_info,
-            head_ip=self.head_ip,
             num_devices_per_host=self.num_devices_per_host,
             parent=self,
             devices=self.devices)
@@ -1795,14 +1792,12 @@ class VirtualPhysicalMesh:

     def __init__(self,
                  host_ids: Sequence[int],
                  host_info: Sequence[dict],
-                 head_ip,
                  num_devices_per_host,
                  parent: "VirtualPhysicalMesh" = None,
                  devices: Sequence[Sequence[int]] = None):
         # host_ids are the indices of hosts in the global DeviceCluster
         self.host_ids = host_ids
         self.host_info = host_info
-        self.head_ip = head_ip
         self.num_devices_per_host = num_devices_per_host
         self.parent = parent

@@ -1860,7 +1855,6 @@ def slice_1d(self, dim: int, indices: Sequence[int]):
             return VirtualPhysicalMesh(
                 host_ids=host_ids,
                 host_info=host_info,
-                head_ip=self.head_ip,
                 num_devices_per_host=self.num_devices_per_host,
                 parent=self)
         else:
@@ -1873,7 +1867,6 @@ def slice_1d(self, dim: int, indices: Sequence[int]):

             return VirtualPhysicalMesh(host_ids=self.host_ids,
                                        host_info=self.host_info,
-                                       head_ip=self.head_ip,
                                        num_devices_per_host=len(indices[0]),
                                        parent=self,
                                        devices=indices)
@@ -1889,7 +1882,6 @@ def slice_2d(self, host_indices, device_indices):

         return VirtualPhysicalMesh(host_ids=host_ids,
                                    host_info=host_info,
-                                   head_ip=self.head_ip,
                                    num_devices_per_host=len(device_indices[0]),
                                    parent=self,
                                    devices=device_indices)
@@ -1939,7 +1931,6 @@ def get_physical_mesh(self, mesh_id: int = 0):
         self.launched_physical_mesh = DistributedPhysicalDeviceMesh(
             host_ids=self.host_ids,
             host_info=self.host_info,
-            head_ip=self.head_ip,
             num_devices_per_host=self.num_devices_per_host,
             parent=self,
             devices=self.devices,
@@ -2142,7 +2133,6 @@ def __init__(self,
             raise RuntimeError(
                 "Cannot access ray global node. Did you call ray.init?") \
                 from ae
-        self.head_ip = self.head_info["node_ip_address"]

         # Gather host ids
         self.host_info = []
@@ -2265,7 +2255,6 @@ def get_physical_mesh(self,
         return DistributedPhysicalDeviceMesh(
             host_ids=host_ids,
             host_info=host_info,
-            head_ip=self.head_ip,
             num_devices_per_host=num_devices_per_host,
             parent=self,
             namespace=self.namespace)
@@ -2289,7 +2278,6 @@ def get_virtual_physical_mesh(self,

         return VirtualPhysicalMesh(host_ids=host_ids,
                                    host_info=host_info,
-                                   head_ip=self.head_ip,
                                    num_devices_per_host=num_devices_per_host,
                                    parent=self)
diff --git a/third_party/tensorflow-alpa b/third_party/tensorflow-alpa
index 721260d12..cd865615b 160000
--- a/third_party/tensorflow-alpa
+++ b/third_party/tensorflow-alpa
@@ -1 +1 @@
-Subproject commit 721260d122f096040762b2d226b37e8ab23f74b8
+Subproject commit cd865615b9b518bc507fbdc71dc44c7cc76618ac

From 00c81142d9d17b6cf31f9a232853396d19aff67e Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 00:43:13 -0800
Subject: [PATCH 26/37] fix available memory bug

---
 alpa/pipeline_parallel/stage_profiling.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/alpa/pipeline_parallel/stage_profiling.py b/alpa/pipeline_parallel/stage_profiling.py
index 4eb5a7b2b..83e8af630 100644
--- a/alpa/pipeline_parallel/stage_profiling.py
+++ b/alpa/pipeline_parallel/stage_profiling.py
@@ -125,11 +125,8 @@ def add_module_profile_result(self, module_idx, result):
         if self.available_memory is None:
             self.available_memory = result.available_memory
         else:
-            assert self.available_memory == result.available_memory, (
-                f"available_memory is not consistent: {self.available_memory} "
-                f"vs {result.available_memory}. This may be caused by "
-                f"mismatch of loaded profile results and newly profiled "
-                f"results.")
+            self.available_memory = min(self.available_memory,
+                                        result.available_memory)

     def __str__(self):
         total_initial_var_size = sum(self.initial_var_sizes)

From c1211cce634a231b59d9d097f20f2a15b0a28e13 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 01:30:23 -0800
Subject: [PATCH 27/37] fix more available memory

---
 alpa/pipeline_parallel/stage_profiling.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/alpa/pipeline_parallel/stage_profiling.py b/alpa/pipeline_parallel/stage_profiling.py
index 83e8af630..155fcd1e8 100644
--- a/alpa/pipeline_parallel/stage_profiling.py
+++ b/alpa/pipeline_parallel/stage_profiling.py
@@ -769,9 +769,8 @@ def get_merged_stages_memory_stats(
{size}.") initial_size = sum(initial_var_sizes_dict.values()) peak_memory = 0 - available_memory = profile_results[0].available_memory - assert all(result.available_memory == available_memory - for result in profile_results) + available_memory = min( + result.available_memory for result in profile_results) n_stages = len(profile_results) n_modules = profile_results[0].n_modules if inference_mode: From 4e4ca90cfaf3062aa362e1af9fe3b6594042a798 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 6 Dec 2022 02:11:37 -0800 Subject: [PATCH 28/37] use old layer construction and generate moe search config --- alpa/pipeline_parallel/layer_construction.py | 4 +- benchmark/alpa/suite_inference_moe.py | 46 +++++++++++++++++++- 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 2464b4b55..c5289e65b 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -600,8 +600,8 @@ def wrapped(*args): layer_num = search_layer_num(jaxpr, eps, layer_eps) costs = get_layer_construction_costs(jaxpr, cost_criteria=cost_criteria) - sliced_eqns, _ = cluster_jaxpr_by_cost_optimized( - jaxpr, layer_num, costs, cost_criteria=cost_criteria) + sliced_eqns, _ = cluster_jaxpr_by_cost( + jaxpr, layer_num, eps, costs, cost_criteria=cost_criteria) else: sliced_eqns = slice_eqns_by_layer_boundary(jaxpr) diff --git a/benchmark/alpa/suite_inference_moe.py b/benchmark/alpa/suite_inference_moe.py index 2c3437ef7..017e0f8e2 100644 --- a/benchmark/alpa/suite_inference_moe.py +++ b/benchmark/alpa/suite_inference_moe.py @@ -1,6 +1,8 @@ """Benchmark suites for gpt with auto parallelization.""" from suite_manual_moe import moe_specs -from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) +from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs, + LoadSolutionParallelArgs, + SearchParallelArgs) prefer_reduce_scatter = True force_batch_dim_mapping = True @@ -44,3 +46,45 @@ def get_config(model_config, [1, 2, 4, 8, 16]) get_config(moe_specs["10B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) + +search_suite = {} + + +def generate_search_configs(model_config, num_auto_layers, pp_list, op_list): + """Generate search configs.""" + for pp in pp_list: + for op in op_list: + num_gpus = pp * op + if num_gpus not in search_suite: + search_suite[num_gpus] = [] + search_suite[num_gpus].append( + BenchmarkCase( + 1, + model_config, + 1, + "search", + SearchParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=num_auto_layers, + auto_stage_option={ + "submesh_physical_shape_space": + "manual", + "manually_specified_submeshes": ((1, op),), + "submesh_logical_shape_space": + "model_parallel_only", + "layer_profile_mode": + "individual", + # "use_hlo_cost_model": True, + # "profiling_database_filename": + # "prof_database.pkl", + }))) + + +generate_search_configs(moe_specs["1.3B"], 34, [8], [1]) +# generate_search_configs(moe_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32], +# [1, 2, 4, 8]) +# generate_search_configs(moe_specs["2.4B"], 66, [1, 2, 4, 8, 16, 32], +# [1, 2, 4, 8]) +# generate_search_configs(moe_specs["7.1B"], 66, [1, 2, 4, 8, 16, 32], +# [1, 2, 4, 8]) From ce7e4e7685e501fc12d3d92581cd2e8141039662 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 6 Dec 2022 02:13:14 -0800 Subject: [PATCH 29/37] add moe search profile --- benchmark/alpa/run_exp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git 

From ce7e4e7685e501fc12d3d92581cd2e8141039662 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 02:13:14 -0800
Subject: [PATCH 29/37] add moe search profile

---
 benchmark/alpa/run_exp.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/benchmark/alpa/run_exp.py b/benchmark/alpa/run_exp.py
index 5899796ed..e31a3b560 100644
--- a/benchmark/alpa/run_exp.py
+++ b/benchmark/alpa/run_exp.py
@@ -47,6 +47,10 @@ def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None):
         "niter": 10,
         "profile_stage_execution_time": True
     }),
+    "moe_inference_search": ("moe_inference.profile", {
+        "niter": 10,
+        "profile_stage_execution_time": True
+    }),
     "gpt_no_embedding_inference": ("gpt_no_embedding_inference.profile", {}),
     "gpt_inference_streaming": ("gpt_inference.profile", {
         "profile_driver_time": True

From ff2ddbf1280df3d8ab50a2760fe4df6bc1d6f1a9 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 02:15:57 -0800
Subject: [PATCH 30/37] fix: point moe_inference_search at the search suite

---
 benchmark/alpa/run_exp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/alpa/run_exp.py b/benchmark/alpa/run_exp.py
index e31a3b560..1951c5a54 100644
--- a/benchmark/alpa/run_exp.py
+++ b/benchmark/alpa/run_exp.py
@@ -47,7 +47,7 @@ def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None):
         "niter": 10,
         "profile_stage_execution_time": True
     }),
-    "moe_inference_search": ("moe_inference.profile", {
+    "moe_inference_search": ("moe_inference.search", {
         "niter": 10,
         "profile_stage_execution_time": True
     }),

From e9bbb584e103292216c6061675f1bc0fccaeb866 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 02:17:21 -0800
Subject: [PATCH 31/37] fix: register the moe_inference.search suite

---
 benchmark/alpa/benchmark.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/benchmark/alpa/benchmark.py b/benchmark/alpa/benchmark.py
index 6fe852f0c..7128ac0d0 100644
--- a/benchmark/alpa/benchmark.py
+++ b/benchmark/alpa/benchmark.py
@@ -37,6 +37,7 @@
     "moe.perf_test_auto": suite_auto_moe.perf_test_suite,
     "moe.grid_search_auto": suite_auto_moe.grid_search_suite,
     "moe_inference.profile": suite_inference_moe.profile_suite,
+    "moe_inference.search": suite_inference_moe.search_suite,
     "unet.perf_test_auto": suite_unet.perf_test_auto_suite,
     "unet.grid_search_auto": suite_unet.grid_search_auto_suite,
     "wresnet.perf_test_2d": suite_wresnet.perf_test_2d_suite,

From b8f9fd3e5b64fe157e56a2965c9ce7eda9b64e0c Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 02:21:21 -0800
Subject: [PATCH 32/37] fix: handle the search parallel mode in moe inference
 logging

---
 .../alpa/benchmark_one_case_moe_inference.py | 76 ++++++++++++++-----
 1 file changed, 55 insertions(+), 21 deletions(-)

diff --git a/benchmark/alpa/benchmark_one_case_moe_inference.py b/benchmark/alpa/benchmark_one_case_moe_inference.py
index f434c1331..9ab82f130 100644
--- a/benchmark/alpa/benchmark_one_case_moe_inference.py
+++ b/benchmark/alpa/benchmark_one_case_moe_inference.py
@@ -23,7 +23,7 @@ def create_infer_params_aval(rngkey, model, batch):
     params = jax.eval_shape(
         lambda p: jax.tree_util.tree_map(
             lambda x: jnp.asarray(x, dtype=jnp.float16), p), params)
-    return params 
+    return params
@@ -41,7 +41,7 @@ def get_infer_step(parallel_method, model):
         loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1),
                         axis=-1)
         loss = (label_mask * loss).sum() / label_mask.sum()
         return loss
- 
+
     return parallelize(infer_step, method=parallel_method, donate_argnums=())
@@ -167,32 +167,66 @@ def benchmark_moe_inference_internal(benchmark_case,

     # Log per-stage execution information if needed
     if profile_stage_execution_time:
-        model_name = f"moe-{parameter_count/1e9:.1f}b"
-        # dump chrome trace
-        executable.dump_stage_execution_trace(
f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json" - ) + model_name = f"bert-{parameter_count/1e9:.1f}b" # compute and log per-stage latency/memory statistics exec_info = executable.get_stage_execution_info() timelines = list(zip(*exec_info)) # drop warmup case timelines = timelines[1:] avg_stage_latencies = compute_avg_stage_latencies(timelines) - assert len(avg_stage_latencies) == num_manual_pipeline_stages + if num_manual_pipeline_stages is not None: + assert len(avg_stage_latencies) == num_manual_pipeline_stages parallel_args = benchmark_case.parallel_args - dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp - heads = [ - "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", - "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", - "StagePeakMem(B)", "StageLatencies(s)" - ] - values = [ - model_name, benchmark_case.batch_size, - benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, - f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", - f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}", - avg_stage_latencies - ] + if benchmark_case.parallel_mode == "uniform": + dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp + # dump chrome trace + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"op={op},pp={pp}.json") + heads = [ + "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", + "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", + "StagePeakMem(B)", "StageLatencies(s)" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{per_stage_weight_mem}", + f"{per_stage_peak_mem}", avg_stage_latencies + ] + else: + (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes, + logical_mesh_shapes, + autosharding_option_dicts) = get_last_dp_result() + metadata = { + "compilation_times": compilation_times, + "compute_cost_file_name": compute_cost_file_name, + "forward_stage_layer_ids": forward_stage_layer_ids, + "submesh_shapes": submesh_shapes, + "logical_mesh_shapes": logical_mesh_shapes, + "autosharding_option_dicts": autosharding_option_dicts, + } + + timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"{timestamp}.json") + heads = [ + "ModelName", "BS", "#Microbatch", "ParallelArgs", "MeanTime(s)", + "StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)", + "StageLatencies(s)", "Metadata", "TimeStamp" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, parallel_args, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{list(per_stage_weight_mem)}", + f"{list(per_stage_peak_mem)}", + list(avg_stage_latencies), metadata, timestamp + ] write_tsv(heads, values, f"benchmark_results.tsv") metadata = { From 672e17b107aa8c9198953ca917947358d944bb48 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 6 Dec 2022 02:24:59 -0800 Subject: [PATCH 33/37] fix --- benchmark/alpa/benchmark_one_case_moe_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benchmark/alpa/benchmark_one_case_moe_inference.py b/benchmark/alpa/benchmark_one_case_moe_inference.py index 9ab82f130..30c3b444d 100644 --- 

From 672e17b107aa8c9198953ca917947358d944bb48 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 02:24:59 -0800
Subject: [PATCH 33/37] fix: add the missing datetime import

---
 benchmark/alpa/benchmark_one_case_moe_inference.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/benchmark/alpa/benchmark_one_case_moe_inference.py b/benchmark/alpa/benchmark_one_case_moe_inference.py
index 9ab82f130..30c3b444d 100644
--- a/benchmark/alpa/benchmark_one_case_moe_inference.py
+++ b/benchmark/alpa/benchmark_one_case_moe_inference.py
@@ -1,4 +1,6 @@
 """Benchmark one case of inter-op + intra-op parallelism."""
+from datetime import datetime
+
 import jax
 import jax.numpy as jnp
 import numpy as np

From ddcabebdf1e55e3f9cb39e18418a2dc1211803eb Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 11:20:56 -0800
Subject: [PATCH 34/37] name change

---
 benchmark/alpa/benchmark_one_case_moe_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/benchmark/alpa/benchmark_one_case_moe_inference.py b/benchmark/alpa/benchmark_one_case_moe_inference.py
index 30c3b444d..7d4f3502c 100644
--- a/benchmark/alpa/benchmark_one_case_moe_inference.py
+++ b/benchmark/alpa/benchmark_one_case_moe_inference.py
@@ -169,7 +169,7 @@ def benchmark_moe_inference_internal(benchmark_case,

     # Log per-stage execution information if needed
     if profile_stage_execution_time:
-        model_name = f"bert-{parameter_count/1e9:.1f}b"
+        model_name = f"moe-{parameter_count/1e9:.1f}b"
         # compute and log per-stage latency/memory statistics
         exec_info = executable.get_stage_execution_info()
         timelines = list(zip(*exec_info))

From 944d7c53a3b4760085bc50e603154ae617a219a5 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 6 Dec 2022 14:36:26 -0800
Subject: [PATCH 35/37] fix tf version

---
 third_party/tensorflow-alpa | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/tensorflow-alpa b/third_party/tensorflow-alpa
index cd865615b..721260d12 160000
--- a/third_party/tensorflow-alpa
+++ b/third_party/tensorflow-alpa
@@ -1 +1 @@
-Subproject commit cd865615b9b518bc507fbdc71dc44c7cc76618ac
+Subproject commit 721260d122f096040762b2d226b37e8ab23f74b8

From 0c894674634a6db08a172fc307abd0c2ca31b42a Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Thu, 8 Dec 2022 19:29:08 +0000
Subject: [PATCH 36/37] fix: enable cost-model-based search and rework the
 inference suites

---
 alpa/pipeline_parallel/layer_construction.py |   4 +-
 benchmark/alpa/gen_serving_database.py       |   2 +-
 benchmark/alpa/run_exp.py                    |   2 +-
 benchmark/alpa/suite_inference_gpt.py        | 180 +++++++++++--------
 benchmark/alpa/suite_inference_moe.py        |  26 +--
 third_party/tensorflow-alpa                  |   2 +-
 6 files changed, 124 insertions(+), 92 deletions(-)

diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py
index c5289e65b..8858d8b78 100644
--- a/alpa/pipeline_parallel/layer_construction.py
+++ b/alpa/pipeline_parallel/layer_construction.py
@@ -434,7 +434,7 @@ def dp(input_sizes, blocked):
     assert r == 0, "No solution for layer construction."
     solution = list(reversed(reversed_sliced_eqns))
-    log_solution(solution, jaxpr)
+    # log_solution(solution, jaxpr)

     solution_info = {
         "total_cost": value,
@@ -507,7 +507,7 @@ def dp(layer_costs):
     assert r == 0, "No solution for layer construction."
     solution = list(reversed(reversed_sliced_eqns))
-    log_solution(solution, jaxpr)
+    # log_solution(solution, jaxpr)

     solution_info = {
         "total_cost": value,
diff --git a/benchmark/alpa/gen_serving_database.py b/benchmark/alpa/gen_serving_database.py
index d9570a89d..0cd05e898 100644
--- a/benchmark/alpa/gen_serving_database.py
+++ b/benchmark/alpa/gen_serving_database.py
@@ -10,5 +10,5 @@
     args = parser.parse_args()

     database = ProfilingDatabase(args.output, args.new)
-    database.update_from_csv(args.input)
+    database.update_from_auto_csv(args.input)
     database.materialize()
diff --git a/benchmark/alpa/run_exp.py b/benchmark/alpa/run_exp.py
index 1951c5a54..e1f700c11 100644
--- a/benchmark/alpa/run_exp.py
+++ b/benchmark/alpa/run_exp.py
@@ -56,7 +56,7 @@ def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None):
         "profile_driver_time": True
     }),
 }
-cluster_settings = [(8, 8), (4, 8), (3, 8), (2, 8), (1, 8), (1, 4), (1, 2),
-                    (1, 1)]
+cluster_settings = [(1, 8), (1, 4), (1, 2), (1, 1)]

 if __name__ == "__main__":
diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py
index 587245e60..bead07e62 100644
--- a/benchmark/alpa/suite_inference_gpt.py
+++ b/benchmark/alpa/suite_inference_gpt.py
@@ -66,79 +66,105 @@ def get_config(model_config,

 test_suite = {
     8: [
-        BenchmarkCase(
-            1, gpt_specs["1.3B"], 1, "uniform",
-            UniformParallelArgs(
-                prefer_reduce_scatter,
-                use_remat,
-                dp=1,
-                op=1,
-                pp=8,
-                force_batch_dim_mapping=force_batch_dim_mapping)),
-        BenchmarkCase(
-            1, gpt_specs["1.3B"], 1, "load_solution",
-            LoadSolutionParallelArgs(
-                prefer_reduce_scatter,
-                use_remat,
-                num_auto_layers=8,
-                forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6],
-                                         [7]],
-                submesh_physical_shapes=[(1, 1)] * 8,
-                submesh_logical_shapes=[(1, 1)] * 8,
-                submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
-        # 2D + Profile
-        BenchmarkCase(
-            1, gpt_specs["1.3B"], 1, "load_solution",
-            LoadSolutionParallelArgs(
-                prefer_reduce_scatter,
-                use_remat,
-                num_auto_layers=50,
-                forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5],
-                                         [6, 7, 8, 9, 10, 11],
-                                         [12, 13, 14, 15, 16, 17, 18],
-                                         [19, 20, 21, 22, 23, 24, 25],
-                                         [26, 27, 28, 29, 30, 31],
-                                         [32, 33, 34, 35, 36, 37],
-                                         [38, 39, 40, 41, 42, 43, 44],
-                                         [45, 46, 47, 48, 49]],
-                submesh_physical_shapes=[(1, 1)] * 8,
-                submesh_logical_shapes=[(1, 1)] * 8,
-                submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
-        # 1D + Profile
-        BenchmarkCase(
-            1, gpt_specs["1.3B"], 1, "load_solution",
-            LoadSolutionParallelArgs(
-                prefer_reduce_scatter,
-                use_remat,
-                num_auto_layers=50,
-                forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5],
-                                         [6, 7, 8, 9, 10, 11, 12],
-                                         [13, 14, 15, 16, 17, 18],
-                                         [19, 20, 21, 22, 23, 24],
-                                         [25, 26, 27, 28, 29, 30, 31],
-                                         [32, 33, 34, 35, 36, 37],
-                                         [38, 39, 40, 41, 42, 43, 44],
-                                         [45, 46, 47, 48, 49]],
-                submesh_physical_shapes=[(1, 1)] * 8,
-                submesh_logical_shapes=[(1, 1)] * 8,
-                submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
-        # 1D + Cost model
+        # BenchmarkCase(
+        #     1, gpt_specs["1.3B"], 1, "uniform",
+        #     UniformParallelArgs(
+        #         prefer_reduce_scatter,
+        #         use_remat,
+        #         dp=1,
+        #         op=1,
+        #         pp=8,
+        #         force_batch_dim_mapping=force_batch_dim_mapping)),
+        # BenchmarkCase(
+        #     1, gpt_specs["1.3B"], 1, "load_solution",
+        #     LoadSolutionParallelArgs(
+        #         prefer_reduce_scatter,
+        #         use_remat,
+        #         num_auto_layers=8,
+        #         forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6],
+        #                                  [7]],
+        #         submesh_physical_shapes=[(1, 1)] * 8,
+        #         submesh_logical_shapes=[(1, 1)] * 8,
+        #         submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
+        # # 2D + Profile
+        # BenchmarkCase(
+        #     1, gpt_specs["1.3B"], 1, "load_solution",
+        #     LoadSolutionParallelArgs(
+        #         prefer_reduce_scatter,
+        #         use_remat,
+        #         num_auto_layers=50,
+        #         forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5],
+        #                                  [6, 7, 8, 9, 10, 11],
+        #                                  [12, 13, 14, 15, 16, 17, 18],
+        #                                  [19, 20, 21, 22, 23, 24, 25],
+        #                                  [26, 27, 28, 29, 30, 31],
+        #                                  [32, 33, 34, 35, 36, 37],
+        #                                  [38, 39, 40, 41, 42, 43, 44],
+        #                                  [45, 46, 47, 48, 49]],
+        #         submesh_physical_shapes=[(1, 1)] * 8,
+        #         submesh_logical_shapes=[(1, 1)] * 8,
+        #         submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
+        # # 1D + Profile
+        # BenchmarkCase(
+        #     1, gpt_specs["1.3B"], 1, "load_solution",
+        #     LoadSolutionParallelArgs(
+        #         prefer_reduce_scatter,
+        #         use_remat,
+        #         num_auto_layers=50,
+        #         forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5],
+        #                                  [6, 7, 8, 9, 10, 11, 12],
+        #                                  [13, 14, 15, 16, 17, 18],
+        #                                  [19, 20, 21, 22, 23, 24],
+        #                                  [25, 26, 27, 28, 29, 30, 31],
+        #                                  [32, 33, 34, 35, 36, 37],
+        #                                  [38, 39, 40, 41, 42, 43, 44],
+        #                                  [45, 46, 47, 48, 49]],
+        #         submesh_physical_shapes=[(1, 1)] * 8,
+        #         submesh_logical_shapes=[(1, 1)] * 8,
+        #         submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
+        # # 1D + Cost model
+        # BenchmarkCase(
+        #     1, gpt_specs["1.3B"], 1, "load_solution",
+        #     LoadSolutionParallelArgs(
+        #         prefer_reduce_scatter,
+        #         use_remat,
+        #         num_auto_layers=50,
+        #         forward_stage_layer_ids=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10],
+        #                                  [11, 12, 13, 14, 15, 16],
+        #                                  [17, 18, 19, 20, 21, 22, 23],
+        #                                  [24, 25, 26, 27, 28, 29, 30],
+        #                                  [31, 32, 33, 34, 35, 36, 37],
+        #                                  [38, 39, 40, 41, 42, 43, 44],
+        #                                  [45, 46, 47, 48, 49]],
+        #         submesh_physical_shapes=[(1, 1)] * 8,
+        #         submesh_logical_shapes=[(1, 1)] * 8,
+        #         submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
+        # BenchmarkCase(
+        #     1, gpt_specs["1.3B"], 1, "load_solution",
+        #     LoadSolutionParallelArgs(
+        #         prefer_reduce_scatter,
+        #         use_remat,
+        #         num_auto_layers=50,
+        #         forward_stage_layer_ids=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10],
+        #                                  [11, 12, 13, 14, 15, 16],
+        #                                  [17, 18, 19, 20, 21, 22, 23],
+        #                                  [24, 25, 26, 27, 28, 29, 30],
+        #                                  [31, 32, 33, 34, 35, 36, 37],
+        #                                  [38, 39, 40, 41, 42, 43, 44],
+        #                                  [45, 46, 47, 48, 49]],
+        #         submesh_physical_shapes=[(1, 1)] * 8,
+        #         submesh_logical_shapes=[(1, 1)] * 8,
+        #         submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
         BenchmarkCase(
             1, gpt_specs["1.3B"], 1, "load_solution",
             LoadSolutionParallelArgs(
                 prefer_reduce_scatter,
                 use_remat,
                 num_auto_layers=50,
-                forward_stage_layer_ids=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10],
-                                         [11, 12, 13, 14, 15, 16],
-                                         [17, 18, 19, 20, 21, 22, 23],
-                                         [24, 25, 26, 27, 28, 29, 30],
-                                         [31, 32, 33, 34, 35, 36, 37],
-                                         [38, 39, 40, 41, 42, 43, 44],
-                                         [45, 46, 47, 48, 49]],
-                submesh_physical_shapes=[(1, 1)] * 8,
-                submesh_logical_shapes=[(1, 1)] * 8,
-                submesh_autosharding_option_dicts=[force_dp_dict] * 8)),
+                forward_stage_layer_ids=[list(range(25)), list(range(25, 50))],
+                submesh_physical_shapes=[(1, 4)] * 2,
+                submesh_logical_shapes=[(1, 4)] * 2,
+                submesh_autosharding_option_dicts=[force_dp_dict] * 2)),
     ]
 }
@@ -170,16 +196,16 @@ def generate_search_configs(model_config, num_auto_layers, pp_list, op_list):
                             "model_parallel_only",
                             "layer_profile_mode":
                                 "individual",
-                            # "use_hlo_cost_model": True,
-                            # "profiling_database_filename":
-                            #     "prof_database.pkl",
+                            "use_hlo_cost_model": True,
+                            "profiling_database_filename":
+                                "prof_database.pkl",
                         })))

-generate_search_configs(gpt_specs["1.3B"], 50, [8], [1])
-# generate_search_configs(gpt_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32],
-#                         [1, 2, 4, 8])
-# generate_search_configs(gpt_specs["2.6B"], 66, [1, 2, 4, 8, 16, 32],
-#                         [1, 2, 4, 8])
-# generate_search_configs(gpt_specs["6.7B"], 66, [1, 2, 4, 8, 16, 32],
-#                         [1, 2, 4, 8])
+# generate_search_configs(gpt_specs["1.3B"], 50, [8], [1])
+generate_search_configs(gpt_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32],
+                        [1, 2, 4, 8])
+generate_search_configs(gpt_specs["2.6B"], 66, [1, 2, 4, 8, 16, 32],
+                        [1, 2, 4, 8])
+generate_search_configs(gpt_specs["6.7B"], 66, [1, 2, 4, 8, 16, 32],
+                        [1, 2, 4, 8])
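
The one case left active splits its 50 auto-layers evenly across two (1, 4)
submeshes via list(range(25)) and list(range(25, 50)). A small hypothetical
helper that reproduces such even splits (the hand-written stage lists commented
out above are unevenly sized, so this only covers the divisible case):

    def even_stage_split(num_layers, num_stages):
        # even_stage_split(50, 2) == [list(range(25)), list(range(25, 50))]
        assert num_layers % num_stages == 0, "only handles even splits"
        per_stage = num_layers // num_stages
        return [list(range(i * per_stage, (i + 1) * per_stage))
                for i in range(num_stages)]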
generate_search_configs(gpt_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32], -# [1, 2, 4, 8]) -# generate_search_configs(gpt_specs["2.6B"], 66, [1, 2, 4, 8, 16, 32], -# [1, 2, 4, 8]) -# generate_search_configs(gpt_specs["6.7B"], 66, [1, 2, 4, 8, 16, 32], -# [1, 2, 4, 8]) +# generate_search_configs(gpt_specs["1.3B"], 50, [8], [1]) +generate_search_configs(gpt_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(gpt_specs["2.6B"], 66, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(gpt_specs["6.7B"], 66, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) diff --git a/benchmark/alpa/suite_inference_moe.py b/benchmark/alpa/suite_inference_moe.py index 017e0f8e2..602744d9b 100644 --- a/benchmark/alpa/suite_inference_moe.py +++ b/benchmark/alpa/suite_inference_moe.py @@ -75,16 +75,22 @@ def generate_search_configs(model_config, num_auto_layers, pp_list, op_list): "model_parallel_only", "layer_profile_mode": "individual", - # "use_hlo_cost_model": True, - # "profiling_database_filename": - # "prof_database.pkl", + "use_hlo_cost_model": True, + "profiling_database_filename": + "prof_database.pkl", }))) -generate_search_configs(moe_specs["1.3B"], 34, [8], [1]) -# generate_search_configs(moe_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32], -# [1, 2, 4, 8]) -# generate_search_configs(moe_specs["2.4B"], 66, [1, 2, 4, 8, 16, 32], -# [1, 2, 4, 8]) -# generate_search_configs(moe_specs["7.1B"], 66, [1, 2, 4, 8, 16, 32], -# [1, 2, 4, 8]) +# generate_search_configs(moe_specs["1.3B"], 34, [8], [1]) +generate_search_configs(moe_specs["1.3B"], 34, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(moe_specs["2.4B"], 34, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(moe_specs["7.1B"], 34, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +# generate_search_configs(moe_specs["1.3B"], 16, [1, 2, 4, 8], +# [1]) +# generate_search_configs(moe_specs["2.4B"], 16, [1, 2, 4, 8], +# [1]) +# generate_search_configs(moe_specs["7.1B"], 16, [1, 2, 4, 8], +# [1]) diff --git a/third_party/tensorflow-alpa b/third_party/tensorflow-alpa index 721260d12..cd865615b 160000 --- a/third_party/tensorflow-alpa +++ b/third_party/tensorflow-alpa @@ -1 +1 @@ -Subproject commit 721260d122f096040762b2d226b37e8ab23f74b8 +Subproject commit cd865615b9b518bc507fbdc71dc44c7cc76618ac From c4c9b69dec72f6b4754c27369556ec8ec88cce4d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 8 Dec 2022 19:54:51 +0000 Subject: [PATCH 37/37] fix tf and benchmarking longer sequences --- benchmark/alpa/suite_manual_gpt.py | 6 +++--- benchmark/alpa/suite_manual_moe.py | 6 +++--- third_party/tensorflow-alpa | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/benchmark/alpa/suite_manual_gpt.py b/benchmark/alpa/suite_manual_gpt.py index 78aab953f..1b49b6515 100644 --- a/benchmark/alpa/suite_manual_gpt.py +++ b/benchmark/alpa/suite_manual_gpt.py @@ -18,9 +18,9 @@ "125M": GPTModelConfig(1024, 768, 12, 12, 51200), "350M": GPTModelConfig(1024, 1024, 24, 16, 51200), "760M": GPTModelConfig(1024, 1536, 24, 16, 51200), - "1.3B": GPTModelConfig(1024, 2048, 24, 32, 51200), - "2.6B": GPTModelConfig(1024, 2560, 32, 32, 51200), - "6.7B": GPTModelConfig(1024, 4096, 32, 32, 51200), + "1.3B": GPTModelConfig(2048, 2048, 24, 32, 51200), + "2.6B": GPTModelConfig(2048, 2560, 32, 32, 51200), + "6.7B": GPTModelConfig(2048, 4096, 32, 32, 51200), "15B": GPTModelConfig(1024, 5120, 48, 40, 51200), "39B": GPTModelConfig(1024, 8192, 48, 64, 51200), "76B": GPTModelConfig(1024, 10240, 60, 80, 51200), diff --git 
diff --git a/benchmark/alpa/suite_manual_moe.py b/benchmark/alpa/suite_manual_moe.py
index df08efbc6..d90ed0cfb 100644
--- a/benchmark/alpa/suite_manual_moe.py
+++ b/benchmark/alpa/suite_manual_moe.py
@@ -18,9 +18,9 @@
     # S, H, L, head, V, E, S_
     "380M": MoEModelConfig(1024, 768, 8, 16, 32000, 8, 2048),
     "690M": MoEModelConfig(1024, 768, 8, 16, 32000, 16, 2048),
-    "1.3B": MoEModelConfig(1024, 768, 16, 16, 32000, 16, 2048),
-    "2.4B": MoEModelConfig(1024, 1024, 16, 16, 32000, 16, 2048),
-    "7.1B": MoEModelConfig(1024, 1280, 16, 16, 32000, 32, 2048),
+    "1.3B": MoEModelConfig(4096, 768, 16, 16, 32000, 16, 2048),
+    "2.4B": MoEModelConfig(4096, 1024, 16, 16, 32000, 16, 2048),
+    "7.1B": MoEModelConfig(4096, 1280, 16, 16, 32000, 32, 2048),
     "10B": MoEModelConfig(1024, 1536, 16, 16, 32000, 32, 2048),
     "27B": MoEModelConfig(1024, 2048, 16, 16, 32000, 48, 2048),
     "70B": MoEModelConfig(1024, 2048, 32, 16, 32000, 64, 2048),
diff --git a/third_party/tensorflow-alpa b/third_party/tensorflow-alpa
index cd865615b..721260d12 160000
--- a/third_party/tensorflow-alpa
+++ b/third_party/tensorflow-alpa
@@ -1 +1 @@
-Subproject commit cd865615b9b518bc507fbdc71dc44c7cc76618ac
+Subproject commit 721260d122f096040762b2d226b37e8ab23f74b8