diff --git a/alpa/pipeline_parallel/layer_construction.py b/alpa/pipeline_parallel/layer_construction.py index 7eeda03ba..8858d8b78 100644 --- a/alpa/pipeline_parallel/layer_construction.py +++ b/alpa/pipeline_parallel/layer_construction.py @@ -434,22 +434,7 @@ def dp(input_sizes, blocked): assert r == 0, "No solution for layer construction." solution = list(reversed(reversed_sliced_eqns)) - # print("dp solution") - # for i, eqns in enumerate(solution): - # invars = OrderedSet() - # for eqn in eqns: - # invars.update([var for var in eqn.invars if isinstance(var, Var)]) - # invars.intersection_update(jaxpr.jaxpr.invars) - # print(f"mesh: {i}, set_shapes: " - # f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") - # - # invars = [] - # for eqn in eqns: - # tmp_set = set([var for var in eqn.invars if isinstance(var, Var)]) - # tmp_set.intersection_update(jaxpr.jaxpr.invars) - # invars.extend(list(tmp_set)) - # print(f"mesh: {i}, list_shapes: " - # f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") + # log_solution(solution, jaxpr) solution_info = { "total_cost": value, @@ -457,6 +442,110 @@ def dp(input_sizes, blocked): return solution, solution_info +def cluster_jaxpr_by_cost_optimized(jaxpr: Jaxpr, layer_num: int, costs, + cost_criteria): + """Clusters the jaxpr by cost.""" + layer_num = int(layer_num) + length = len(jaxpr.eqns) + _, input_sizes, compute_costs = costs + assert cost_criteria == "flops" + FLOPS_NORMALIZER = 30 * 1e12 # 30 TFLOPS + NETWORK_NORMALIZER = 2 * 1e9 # 2 GB / s + + @maybe_numba_jit + def init_layer_costs(): + layer_costs = np.full((length, length), np.inf, dtype=np.float32) + for l in range(0, length): + layer_flops = 0 + for r in range(l, length): + layer_flops += compute_costs[r] + layer_costs[l, r] = (layer_flops / FLOPS_NORMALIZER + + input_sizes[l, r + 1] / NETWORK_NORMALIZER) + return layer_costs + + @maybe_numba_jit + def dp(layer_costs): + max_cost = np.full((length + 1, layer_num + 1), + np.inf, + dtype=np.float32) + sum_cost_under_max = np.full((length + 1, layer_num + 1), + np.inf, + dtype=np.float32) + squared_cost_under_max = np.full((length + 1, layer_num + 1), + np.inf, + dtype=np.float32) + max_cost_argmin = np.full((length + 1, layer_num + 1), + -1, + dtype=np.int32) + max_cost[0, 0] = 0 + sum_cost_under_max[0, 0] = 0 + squared_cost_under_max[0, 0] = 0 + for q in range(1, layer_num + 1): + for r in range(1, length + 1): + for k in range(0, r): + new_value = max(max_cost[k, q - 1], layer_costs[k, r - 1]) + new_squared_sum = (squared_cost_under_max[k, q - 1] + + layer_costs[k, r - 1]**2) + if (new_value < max_cost[r, q] or + (new_value <= max_cost[r, q] * (1 + 1e-4) and + new_squared_sum < squared_cost_under_max[r, q])): + max_cost[r, q] = new_value + squared_cost_under_max[r, q] = new_squared_sum + max_cost_argmin[r, q] = k + return max_cost_argmin, max_cost[length, layer_num] + + layer_costs = init_layer_costs() + a_argmin, value = dp(layer_costs) + + reversed_sliced_eqns = [] + + r = length + for q in range(layer_num, 0, -1): + k = a_argmin[r, q] + reversed_sliced_eqns.append(jaxpr.eqns[k:r]) + r = k + assert r == 0, "No solution for layer construction." 
+ solution = list(reversed(reversed_sliced_eqns)) + + # log_solution(solution, jaxpr) + + solution_info = { + "total_cost": value, + } + return solution, solution_info + + +def log_solution(solution, jaxpr): + print("-" * 80) + print(f"Layer construction solution ({len(solution)} layers):") + for i, eqns in enumerate(solution): + print("-" * 40) + print(f"Layer {i}:") + + total_flops = 0 + for j, eqn in enumerate(eqns): + flops = eqn_flops(eqn) + print(f"Eqn {j}: {eqn}, Flops: {flops}") + total_flops += flops + print(f"Total flops: {total_flops}") + invars = OrderedSet() + for eqn in eqns: + invars.update([var for var in eqn.invars if isinstance(var, Var)]) + invars.intersection_update(jaxpr.jaxpr.invars) + + print(f"set_invar_shapes: " + f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") + + invars = [] + for eqn in eqns: + tmp_set = {var for var in eqn.invars if isinstance(var, Var)} + tmp_set.intersection_update(jaxpr.jaxpr.invars) + invars.extend(list(tmp_set)) + print(f"list_invar_shapes: " + f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}") + print("-" * 80) + + def search_layer_num(jaxpr, eps, layer_eps=0, @@ -511,11 +600,8 @@ def wrapped(*args): layer_num = search_layer_num(jaxpr, eps, layer_eps) costs = get_layer_construction_costs(jaxpr, cost_criteria=cost_criteria) - sliced_eqns, _ = cluster_jaxpr_by_cost(jaxpr, - layer_num, - eps, - costs, - cost_criteria=cost_criteria) + sliced_eqns, _ = cluster_jaxpr_by_cost( + jaxpr, layer_num, eps, costs, cost_criteria=cost_criteria) else: sliced_eqns = slice_eqns_by_layer_boundary(jaxpr) diff --git a/benchmark/alpa/benchmark.py b/benchmark/alpa/benchmark.py index 1d9afc5c5..7128ac0d0 100644 --- a/benchmark/alpa/benchmark.py +++ b/benchmark/alpa/benchmark.py @@ -28,6 +28,8 @@ "gpt.grid_search_auto": suite_auto_gpt.grid_search_suite, "gpt.correctness_test_auto": suite_auto_gpt.correctness_test_suite, "gpt_inference.profile": suite_inference_gpt.profile_suite, + "gpt_inference.test": suite_inference_gpt.test_suite, + "gpt_inference.search": suite_inference_gpt.search_suite, "gpt_no_embedding_inference.profile": suite_inference_gpt.profile_suite, "moe.tmp": suite_manual_moe.tmp_suite, "moe.tmp_auto": suite_auto_moe.tmp_suite, @@ -35,6 +37,7 @@ "moe.perf_test_auto": suite_auto_moe.perf_test_suite, "moe.grid_search_auto": suite_auto_moe.grid_search_suite, "moe_inference.profile": suite_inference_moe.profile_suite, + "moe_inference.search": suite_inference_moe.search_suite, "unet.perf_test_auto": suite_unet.perf_test_auto_suite, "unet.grid_search_auto": suite_unet.grid_search_auto_suite, "wresnet.perf_test_2d": suite_wresnet.perf_test_2d_suite, diff --git a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py index 18f305787..5067cf87a 100644 --- a/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py +++ b/benchmark/alpa/benchmark_one_case_gpt_bert_inference.py @@ -1,5 +1,6 @@ """Benchmark one case of inter-op + intra-op parallelism.""" import os +from datetime import datetime import jax import jax.numpy as jnp @@ -10,6 +11,7 @@ from alpa.model.bert_model import BertConfig, FlaxBertLayerCollection from alpa.model.gpt_model import FlaxGPTForLMModule from alpa.util import print_used_time, GB, write_tsv +from alpa.pipeline_parallel.stage_construction import get_last_dp_result from util import compute_gpt_parameter_count, compute_gpt_tflops from benchmark_parallel_utils import ( @@ -179,31 +181,65 @@ def benchmark_gpt_inference_internal(model_type, # Log per-stage 
execution information if needed if profile_stage_execution_time: model_name = f"bert-{parameter_count/1e9:.1f}b" - # dump chrome trace - executable.dump_stage_execution_trace( - f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json" - ) # compute and log per-stage latency/memory statistics exec_info = executable.get_stage_execution_info() timelines = list(zip(*exec_info)) # drop warmup case timelines = timelines[1:] avg_stage_latencies = compute_avg_stage_latencies(timelines) - assert len(avg_stage_latencies) == num_manual_pipeline_stages + if num_manual_pipeline_stages is not None: + assert len(avg_stage_latencies) == num_manual_pipeline_stages parallel_args = benchmark_case.parallel_args - dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp - heads = [ - "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", - "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", - "StagePeakMem(B)", "StageLatencies(s)" - ] - values = [ - model_name, benchmark_case.batch_size, - benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, - f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", - f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}", - avg_stage_latencies - ] + if benchmark_case.parallel_mode == "uniform": + dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp + # dump chrome trace + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"op={op},pp={pp}.json") + heads = [ + "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", + "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", + "StagePeakMem(B)", "StageLatencies(s)" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{per_stage_weight_mem}", + f"{per_stage_peak_mem}", avg_stage_latencies + ] + else: + (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes, + logical_mesh_shapes, + autosharding_option_dicts) = get_last_dp_result() + metadata = { + "compilation_times": compilation_times, + "compute_cost_file_name": compute_cost_file_name, + "forward_stage_layer_ids": forward_stage_layer_ids, + "submesh_shapes": submesh_shapes, + "logical_mesh_shapes": logical_mesh_shapes, + "autosharding_option_dicts": autosharding_option_dicts, + } + + timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"{timestamp}.json") + heads = [ + "ModelName", "BS", "#Microbatch", "ParallelArgs", "MeanTime(s)", + "StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)", + "StageLatencies(s)", "Metadata", "TimeStamp" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, parallel_args, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{list(per_stage_weight_mem)}", + f"{list(per_stage_peak_mem)}", + list(avg_stage_latencies), metadata, timestamp + ] write_tsv(heads, values, f"benchmark_results.tsv") metadata = { diff --git a/benchmark/alpa/benchmark_one_case_moe_inference.py b/benchmark/alpa/benchmark_one_case_moe_inference.py index f434c1331..7d4f3502c 100644 --- a/benchmark/alpa/benchmark_one_case_moe_inference.py +++ b/benchmark/alpa/benchmark_one_case_moe_inference.py @@ -1,4 +1,6 @@ """Benchmark one case of 
inter-op + intra-op parallelism.""" +from datetime import datetime + import jax import jax.numpy as jnp import numpy as np @@ -23,7 +25,7 @@ def create_infer_params_aval(rngkey, model, batch): params = jax.eval_shape( lambda p: jax.tree_util.tree_map( lambda x: jnp.asarray(x, dtype=jnp.float16), p), params) - return params + return params def get_infer_step(parallel_method, model): @@ -41,7 +43,7 @@ def infer_step(params, batch, rng_key): loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1) loss = (label_mask * loss).sum() / label_mask.sum() return loss - + return parallelize(infer_step, method=parallel_method, donate_argnums=()) @@ -168,31 +170,65 @@ def benchmark_moe_inference_internal(benchmark_case, # Log per-stage execution information if needed if profile_stage_execution_time: model_name = f"moe-{parameter_count/1e9:.1f}b" - # dump chrome trace - executable.dump_stage_execution_trace( - f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json" - ) # compute and log per-stage latency/memory statistics exec_info = executable.get_stage_execution_info() timelines = list(zip(*exec_info)) # drop warmup case timelines = timelines[1:] avg_stage_latencies = compute_avg_stage_latencies(timelines) - assert len(avg_stage_latencies) == num_manual_pipeline_stages + if num_manual_pipeline_stages is not None: + assert len(avg_stage_latencies) == num_manual_pipeline_stages parallel_args = benchmark_case.parallel_args - dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp - heads = [ - "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", - "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", - "StagePeakMem(B)", "StageLatencies(s)" - ] - values = [ - model_name, benchmark_case.batch_size, - benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, - f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", - f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}", - avg_stage_latencies - ] + if benchmark_case.parallel_mode == "uniform": + dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp + # dump chrome trace + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"op={op},pp={pp}.json") + heads = [ + "ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU", + "MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)", + "StagePeakMem(B)", "StageLatencies(s)" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{per_stage_weight_mem}", + f"{per_stage_peak_mem}", avg_stage_latencies + ] + else: + (compute_cost_file_name, forward_stage_layer_ids, submesh_shapes, + logical_mesh_shapes, + autosharding_option_dicts) = get_last_dp_result() + metadata = { + "compilation_times": compilation_times, + "compute_cost_file_name": compute_cost_file_name, + "forward_stage_layer_ids": forward_stage_layer_ids, + "submesh_shapes": submesh_shapes, + "logical_mesh_shapes": logical_mesh_shapes, + "autosharding_option_dicts": autosharding_option_dicts, + } + + timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + executable.dump_stage_execution_trace( + f"./chrome_trace/{model_name}," + f"bs={benchmark_case.batch_size}," + f"{timestamp}.json") + heads = [ + "ModelName", "BS", "#Microbatch", "ParallelArgs", "MeanTime(s)", + "StdTime(s)", "TFLOPs", 
"StageWeights(B)", "StagePeakMem(B)", + "StageLatencies(s)", "Metadata", "TimeStamp" + ] + values = [ + model_name, benchmark_case.batch_size, + benchmark_case.num_micro_batches, parallel_args, + f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}", + f"{tflops:.2f}", f"{list(per_stage_weight_mem)}", + f"{list(per_stage_peak_mem)}", + list(avg_stage_latencies), metadata, timestamp + ] write_tsv(heads, values, f"benchmark_results.tsv") metadata = { diff --git a/benchmark/alpa/gen_serving_database.py b/benchmark/alpa/gen_serving_database.py index d9570a89d..0cd05e898 100644 --- a/benchmark/alpa/gen_serving_database.py +++ b/benchmark/alpa/gen_serving_database.py @@ -10,5 +10,5 @@ args = parser.parse_args() database = ProfilingDatabase(args.output, args.new) - database.update_from_csv(args.input) + database.update_from_auto_csv(args.input) database.materialize() diff --git a/benchmark/alpa/run_exp.py b/benchmark/alpa/run_exp.py index cb5632d8f..e1f700c11 100644 --- a/benchmark/alpa/run_exp.py +++ b/benchmark/alpa/run_exp.py @@ -39,16 +39,24 @@ def run_exp(exp_name, cluster_settings, suite_name, benchmark_settings=None): "niter": 10, "profile_stage_execution_time": True }), + "gpt_inference_search": ("gpt_inference.search", { + "niter": 10, + "profile_stage_execution_time": True + }), "moe_inference": ("moe_inference.profile", { "niter": 10, "profile_stage_execution_time": True }), + "moe_inference_search": ("moe_inference.search", { + "niter": 10, + "profile_stage_execution_time": True + }), "gpt_no_embedding_inference": ("gpt_no_embedding_inference.profile", {}), "gpt_inference_streaming": ("gpt_inference.profile", { "profile_driver_time": True }), } -cluster_settings = [(8, 8), (4, 8), (3, 8), (2, 8), (1, 8), (1, 4), (1, 2), +cluster_settings = [(1, 8), (1, 4), (1, 2), (1, 1)] if __name__ == "__main__": diff --git a/benchmark/alpa/suite_inference_gpt.py b/benchmark/alpa/suite_inference_gpt.py index 3d1d345a8..bead07e62 100644 --- a/benchmark/alpa/suite_inference_gpt.py +++ b/benchmark/alpa/suite_inference_gpt.py @@ -1,6 +1,9 @@ """Benchmark suites for gpt with auto parallelization.""" +from alpa import AutoStageOption from suite_manual_gpt import gpt_specs -from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) +from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs, + LoadSolutionParallelArgs, + SearchParallelArgs) prefer_reduce_scatter = True force_batch_dim_mapping = True @@ -60,3 +63,149 @@ def get_config(model_config, [1, 2, 4, 8, 16]) get_config(gpt_specs["15B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) + +test_suite = { + 8: [ + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "uniform", + # UniformParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # dp=1, + # op=1, + # pp=8, + # force_batch_dim_mapping=force_batch_dim_mapping)), + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "load_solution", + # LoadSolutionParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # num_auto_layers=8, + # forward_stage_layer_ids=[[0], [1], [2], [3], [4], [5], [6], + # [7]], + # submesh_physical_shapes=[(1, 1)] * 8, + # submesh_logical_shapes=[(1, 1)] * 8, + # submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # # 2D + Profile + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "load_solution", + # LoadSolutionParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # num_auto_layers=50, + # forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5], + # [6, 7, 8, 9, 10, 11], + # [12, 13, 14, 15, 16, 17, 18], + # [19, 20, 21, 22, 23, 24, 25], + # 
[26, 27, 28, 29, 30, 31], + # [32, 33, 34, 35, 36, 37], + # [38, 39, 40, 41, 42, 43, 44], + # [45, 46, 47, 48, 49]], + # submesh_physical_shapes=[(1, 1)] * 8, + # submesh_logical_shapes=[(1, 1)] * 8, + # submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # # 1D + Profile + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "load_solution", + # LoadSolutionParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # num_auto_layers=50, + # forward_stage_layer_ids=[[0, 1, 2, 3, 4, 5], + # [6, 7, 8, 9, 10, 11, 12], + # [13, 14, 15, 16, 17, 18], + # [19, 20, 21, 22, 23, 24], + # [25, 26, 27, 28, 29, 30, 31], + # [32, 33, 34, 35, 36, 37], + # [38, 39, 40, 41, 42, 43, 44], + # [45, 46, 47, 48, 49]], + # submesh_physical_shapes=[(1, 1)] * 8, + # submesh_logical_shapes=[(1, 1)] * 8, + # submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # # 1D + Cost model + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "load_solution", + # LoadSolutionParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # num_auto_layers=50, + # forward_stage_layer_ids=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10], + # [11, 12, 13, 14, 15, 16], + # [17, 18, 19, 20, 21, 22, 23], + # [24, 25, 26, 27, 28, 29, 30], + # [31, 32, 33, 34, 35, 36, 37], + # [38, 39, 40, 41, 42, 43, 44], + # [45, 46, 47, 48, 49]], + # submesh_physical_shapes=[(1, 1)] * 8, + # submesh_logical_shapes=[(1, 1)] * 8, + # submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + # BenchmarkCase( + # 1, gpt_specs["1.3B"], 1, "load_solution", + # LoadSolutionParallelArgs( + # prefer_reduce_scatter, + # use_remat, + # num_auto_layers=50, + # forward_stage_layer_ids=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10], + # [11, 12, 13, 14, 15, 16], + # [17, 18, 19, 20, 21, 22, 23], + # [24, 25, 26, 27, 28, 29, 30], + # [31, 32, 33, 34, 35, 36, 37], + # [38, 39, 40, 41, 42, 43, 44], + # [45, 46, 47, 48, 49]], + # submesh_physical_shapes=[(1, 1)] * 8, + # submesh_logical_shapes=[(1, 1)] * 8, + # submesh_autosharding_option_dicts=[force_dp_dict] * 8)), + BenchmarkCase( + 1, gpt_specs["1.3B"], 1, "load_solution", + LoadSolutionParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=50, + forward_stage_layer_ids=[list(range(25)), list(range(25, 50))], + submesh_physical_shapes=[(1, 4)] * 2, + submesh_logical_shapes=[(1, 4)] * 2, + submesh_autosharding_option_dicts=[force_dp_dict] * 2)), + ] +} + +search_suite = {} + + +def generate_search_configs(model_config, num_auto_layers, pp_list, op_list): + """Generate search configs.""" + for pp in pp_list: + for op in op_list: + num_gpus = pp * op + if num_gpus not in search_suite: + search_suite[num_gpus] = [] + search_suite[num_gpus].append( + BenchmarkCase( + 1, + model_config, + 1, + "search", + SearchParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=num_auto_layers, + auto_stage_option={ + "submesh_physical_shape_space": + "manual", + "manually_specified_submeshes": ((1, op),), + "submesh_logical_shape_space": + "model_parallel_only", + "layer_profile_mode": + "individual", + "use_hlo_cost_model": True, + "profiling_database_filename": + "prof_database.pkl", + }))) + + +# generate_search_configs(gpt_specs["1.3B"], 50, [8], [1]) +generate_search_configs(gpt_specs["1.3B"], 50, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(gpt_specs["2.6B"], 66, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(gpt_specs["6.7B"], 66, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) diff --git a/benchmark/alpa/suite_inference_moe.py b/benchmark/alpa/suite_inference_moe.py index 2c3437ef7..602744d9b 
100644 --- a/benchmark/alpa/suite_inference_moe.py +++ b/benchmark/alpa/suite_inference_moe.py @@ -1,6 +1,8 @@ """Benchmark suites for gpt with auto parallelization.""" from suite_manual_moe import moe_specs -from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs) +from benchmark_parallel_utils import (BenchmarkCase, UniformParallelArgs, + LoadSolutionParallelArgs, + SearchParallelArgs) prefer_reduce_scatter = True force_batch_dim_mapping = True @@ -44,3 +46,51 @@ def get_config(model_config, [1, 2, 4, 8, 16]) get_config(moe_specs["10B"], [1, 2, 4, 8, 16], [1], [1, 2, 4, 8], [1], [1, 2, 4, 8, 16]) + +search_suite = {} + + +def generate_search_configs(model_config, num_auto_layers, pp_list, op_list): + """Generate search configs.""" + for pp in pp_list: + for op in op_list: + num_gpus = pp * op + if num_gpus not in search_suite: + search_suite[num_gpus] = [] + search_suite[num_gpus].append( + BenchmarkCase( + 1, + model_config, + 1, + "search", + SearchParallelArgs( + prefer_reduce_scatter, + use_remat, + num_auto_layers=num_auto_layers, + auto_stage_option={ + "submesh_physical_shape_space": + "manual", + "manually_specified_submeshes": ((1, op),), + "submesh_logical_shape_space": + "model_parallel_only", + "layer_profile_mode": + "individual", + "use_hlo_cost_model": True, + "profiling_database_filename": + "prof_database.pkl", + }))) + + +# generate_search_configs(moe_specs["1.3B"], 34, [8], [1]) +generate_search_configs(moe_specs["1.3B"], 34, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(moe_specs["2.4B"], 34, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +generate_search_configs(moe_specs["7.1B"], 34, [1, 2, 4, 8, 16, 32], + [1, 2, 4, 8]) +# generate_search_configs(moe_specs["1.3B"], 16, [1, 2, 4, 8], +# [1]) +# generate_search_configs(moe_specs["2.4B"], 16, [1, 2, 4, 8], +# [1]) +# generate_search_configs(moe_specs["7.1B"], 16, [1, 2, 4, 8], +# [1]) diff --git a/benchmark/alpa/suite_manual_gpt.py b/benchmark/alpa/suite_manual_gpt.py index 78aab953f..1b49b6515 100644 --- a/benchmark/alpa/suite_manual_gpt.py +++ b/benchmark/alpa/suite_manual_gpt.py @@ -18,9 +18,9 @@ "125M": GPTModelConfig(1024, 768, 12, 12, 51200), "350M": GPTModelConfig(1024, 1024, 24, 16, 51200), "760M": GPTModelConfig(1024, 1536, 24, 16, 51200), - "1.3B": GPTModelConfig(1024, 2048, 24, 32, 51200), - "2.6B": GPTModelConfig(1024, 2560, 32, 32, 51200), - "6.7B": GPTModelConfig(1024, 4096, 32, 32, 51200), + "1.3B": GPTModelConfig(2048, 2048, 24, 32, 51200), + "2.6B": GPTModelConfig(2048, 2560, 32, 32, 51200), + "6.7B": GPTModelConfig(2048, 4096, 32, 32, 51200), "15B": GPTModelConfig(1024, 5120, 48, 40, 51200), "39B": GPTModelConfig(1024, 8192, 48, 64, 51200), "76B": GPTModelConfig(1024, 10240, 60, 80, 51200), diff --git a/benchmark/alpa/suite_manual_moe.py b/benchmark/alpa/suite_manual_moe.py index df08efbc6..d90ed0cfb 100644 --- a/benchmark/alpa/suite_manual_moe.py +++ b/benchmark/alpa/suite_manual_moe.py @@ -18,9 +18,9 @@ # S, H, L, head, V, E, S_ "380M": MoEModelConfig(1024, 768, 8, 16, 32000, 8, 2048), "690M": MoEModelConfig(1024, 768, 8, 16, 32000, 16, 2048), - "1.3B": MoEModelConfig(1024, 768, 16, 16, 32000, 16, 2048), - "2.4B": MoEModelConfig(1024, 1024, 16, 16, 32000, 16, 2048), - "7.1B": MoEModelConfig(1024, 1280, 16, 16, 32000, 32, 2048), + "1.3B": MoEModelConfig(4096, 768, 16, 16, 32000, 16, 2048), + "2.4B": MoEModelConfig(4096, 1024, 16, 16, 32000, 16, 2048), + "7.1B": MoEModelConfig(4096, 1280, 16, 16, 32000, 32, 2048), "10B": MoEModelConfig(1024, 1536, 16, 16, 32000, 32, 2048), 
"27B": MoEModelConfig(1024, 2048, 16, 16, 32000, 48, 2048), "70B": MoEModelConfig(1024, 2048, 32, 16, 32000, 64, 2048),