This repository was archived by the owner on Oct 19, 2024. It is now read-only.
Commits (43):
44a4828  add test cases on inference layer construction (zhuohan123, Dec 2, 2022)
2376277  Merge branch 'main' into inference-layer-construction (zhuohan123, Dec 2, 2022)
e5ddd93  fix bugs (zhuohan123, Dec 3, 2022)
1893ab1  fix bugs (zhuohan123, Dec 3, 2022)
e3e90b8  fix (zhuohan123, Dec 3, 2022)
f20fd9e  fix (zhuohan123, Dec 3, 2022)
4f6f1a4  test search (zhuohan123, Dec 3, 2022)
a7f5e2d  fix (zhuohan123, Dec 3, 2022)
413bb34  add 1d inference stage construction (zhuohan123, Dec 3, 2022)
7fa3cfc  fix (zhuohan123, Dec 3, 2022)
bf6c932  fix (zhuohan123, Dec 3, 2022)
bf8cf0c  add bert tests (zhuohan123, Dec 3, 2022)
6c23c7c  fix (zhuohan123, Dec 4, 2022)
9615712  Merge branch 'inference-1d-stage-construction' into inference-layer-c… (zhuohan123, Dec 4, 2022)
2da2103  Merge branch 'main' into inference-layer-construction (zhuohan123, Dec 4, 2022)
8d41a5f  test 1d stage construction (zhuohan123, Dec 4, 2022)
e7b67db  add layer construction solution logging (zhuohan123, Dec 4, 2022)
27af75b  add some searched results (zhuohan123, Dec 4, 2022)
8efe119  test new layer construction (zhuohan123, Dec 4, 2022)
7bc7583  fix indexing error (zhuohan123, Dec 4, 2022)
a6cc06e  add squared cost (zhuohan123, Dec 5, 2022)
006c781  fix bug (zhuohan123, Dec 5, 2022)
20c3ddb  do not use sum, only use squared sum (zhuohan123, Dec 5, 2022)
04177f8  add search suite (zhuohan123, Dec 5, 2022)
eab6cbd  add metadata (zhuohan123, Dec 5, 2022)
759d04d  print python list (zhuohan123, Dec 5, 2022)
de4a365  fix (zhuohan123, Dec 5, 2022)
6a589a6  fix head ip address (zhuohan123, Dec 6, 2022)
00c8114  fix available memory bug (zhuohan123, Dec 6, 2022)
acf744a  Merge branch 'fix-head-ip' into inference-layer-construction (zhuohan123, Dec 6, 2022)
c1211cc  fix more available memory (zhuohan123, Dec 6, 2022)
239e611  Merge branch 'fix-head-ip' into inference-layer-construction (zhuohan123, Dec 6, 2022)
4e4ca90  use old layer construction and generate moe search config (zhuohan123, Dec 6, 2022)
ce7e4e7  add moe search profile (zhuohan123, Dec 6, 2022)
ff2ddbf  fix (zhuohan123, Dec 6, 2022)
e9bbb58  fix (zhuohan123, Dec 6, 2022)
b8f9fd3  fix (zhuohan123, Dec 6, 2022)
672e17b  fix (zhuohan123, Dec 6, 2022)
ddcabeb  name change (zhuohan123, Dec 6, 2022)
bd81a66  Merge branch 'main' into inference-layer-construction (zhuohan123, Dec 6, 2022)
944d7c5  fix tf version (zhuohan123, Dec 6, 2022)
0c89467  fix (zhuohan123, Dec 8, 2022)
c4c9b69  fix tf and benchmarking longer sequences (zhuohan123, Dec 8, 2022)
128 changes: 107 additions & 21 deletions alpa/pipeline_parallel/layer_construction.py
@@ -434,29 +434,118 @@ def dp(input_sizes, blocked):
assert r == 0, "No solution for layer construction."
solution = list(reversed(reversed_sliced_eqns))

# print("dp solution")
# for i, eqns in enumerate(solution):
# invars = OrderedSet()
# for eqn in eqns:
# invars.update([var for var in eqn.invars if isinstance(var, Var)])
# invars.intersection_update(jaxpr.jaxpr.invars)
# print(f"mesh: {i}, set_shapes: "
# f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}")
#
# invars = []
# for eqn in eqns:
# tmp_set = set([var for var in eqn.invars if isinstance(var, Var)])
# tmp_set.intersection_update(jaxpr.jaxpr.invars)
# invars.extend(list(tmp_set))
# print(f"mesh: {i}, list_shapes: "
# f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}")
# log_solution(solution, jaxpr)

solution_info = {
"total_cost": value,
}
return solution, solution_info


def cluster_jaxpr_by_cost_optimized(jaxpr: Jaxpr, layer_num: int, costs,
cost_criteria):
"""Clusters the jaxpr by cost."""
layer_num = int(layer_num)
length = len(jaxpr.eqns)
_, input_sizes, compute_costs = costs
assert cost_criteria == "flops"
FLOPS_NORMALIZER = 30 * 1e12 # 30 TFLOPS
NETWORK_NORMALIZER = 2 * 1e9 # 2 GB / s

@maybe_numba_jit
def init_layer_costs():
layer_costs = np.full((length, length), np.inf, dtype=np.float32)
for l in range(0, length):
layer_flops = 0
for r in range(l, length):
layer_flops += compute_costs[r]
layer_costs[l, r] = (layer_flops / FLOPS_NORMALIZER +
input_sizes[l, r + 1] / NETWORK_NORMALIZER)
return layer_costs
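    # A worked instance of the cost model above: a block of eqns totaling
    # 6e12 FLOPs whose inputs occupy 4e9 bytes scores
    # 6e12 / 30e12 + 4e9 / 2e9 = 0.2 + 2.0 = 2.2, i.e. seconds of compute
    # on a 30-TFLOPS device plus seconds to move the inputs at 2 GB/s.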

@maybe_numba_jit
def dp(layer_costs):
max_cost = np.full((length + 1, layer_num + 1),
np.inf,
dtype=np.float32)
sum_cost_under_max = np.full((length + 1, layer_num + 1),
np.inf,
dtype=np.float32)
squared_cost_under_max = np.full((length + 1, layer_num + 1),
np.inf,
dtype=np.float32)
max_cost_argmin = np.full((length + 1, layer_num + 1),
-1,
dtype=np.int32)
max_cost[0, 0] = 0
sum_cost_under_max[0, 0] = 0
squared_cost_under_max[0, 0] = 0
for q in range(1, layer_num + 1):
for r in range(1, length + 1):
for k in range(0, r):
new_value = max(max_cost[k, q - 1], layer_costs[k, r - 1])
new_squared_sum = (squared_cost_under_max[k, q - 1] +
layer_costs[k, r - 1]**2)
if (new_value < max_cost[r, q] or
(new_value <= max_cost[r, q] * (1 + 1e-4) and
new_squared_sum < squared_cost_under_max[r, q])):
max_cost[r, q] = new_value
squared_cost_under_max[r, q] = new_squared_sum
max_cost_argmin[r, q] = k
return max_cost_argmin, max_cost[length, layer_num]

layer_costs = init_layer_costs()
a_argmin, value = dp(layer_costs)

reversed_sliced_eqns = []

r = length
for q in range(layer_num, 0, -1):
k = a_argmin[r, q]
reversed_sliced_eqns.append(jaxpr.eqns[k:r])
r = k
assert r == 0, "No solution for layer construction."
solution = list(reversed(reversed_sliced_eqns))

# log_solution(solution, jaxpr)

solution_info = {
"total_cost": value,
}
return solution, solution_info


def log_solution(solution, jaxpr):
print("-" * 80)
print(f"Layer construction solution ({len(solution)} layers):")
for i, eqns in enumerate(solution):
print("-" * 40)
print(f"Layer {i}:")

total_flops = 0
for j, eqn in enumerate(eqns):
flops = eqn_flops(eqn)
print(f"Eqn {j}: {eqn}, Flops: {flops}")
total_flops += flops
print(f"Total flops: {total_flops}")
invars = OrderedSet()
for eqn in eqns:
invars.update([var for var in eqn.invars if isinstance(var, Var)])
invars.intersection_update(jaxpr.jaxpr.invars)

print(f"set_invar_shapes: "
f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}")

invars = []
for eqn in eqns:
tmp_set = {var for var in eqn.invars if isinstance(var, Var)}
tmp_set.intersection_update(jaxpr.jaxpr.invars)
invars.extend(list(tmp_set))
print(f"list_invar_shapes: "
f"{[x.aval.shape for x in invars if len(x.aval.shape) > 1]}")
print("-" * 80)


def search_layer_num(jaxpr,
eps,
layer_eps=0,
@@ -511,11 +600,8 @@ def wrapped(*args):
layer_num = search_layer_num(jaxpr, eps, layer_eps)
costs = get_layer_construction_costs(jaxpr,
cost_criteria=cost_criteria)
sliced_eqns, _ = cluster_jaxpr_by_cost(jaxpr,
layer_num,
eps,
costs,
cost_criteria=cost_criteria)
sliced_eqns, _ = cluster_jaxpr_by_cost(
jaxpr, layer_num, eps, costs, cost_criteria=cost_criteria)
else:
sliced_eqns = slice_eqns_by_layer_boundary(jaxpr)

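The new cluster_jaxpr_by_cost_optimized above is a min-max interval partition: it splits the length eqns into layer_num contiguous layers so that the most expensive layer is as cheap as possible, and among splits whose maxima agree within a 1e-4 relative tolerance it prefers the smaller sum of squared layer costs, i.e. the more balanced split (matching the commits "add squared cost" and "do not use sum, only use squared sum"). A self-contained sketch of the same recurrence on a toy cost matrix; the names and the driver at the bottom are illustrative, not Alpa APIs:

import numpy as np

def balanced_partition(layer_costs, layer_num):
    """Toy re-implementation of dp() above: split eqns [0, length) into
    layer_num contiguous blocks, minimizing the maximum block cost and
    breaking near-ties by the smaller sum of squared block costs."""
    length = layer_costs.shape[0]
    max_cost = np.full((length + 1, layer_num + 1), np.inf)
    sq_cost = np.full((length + 1, layer_num + 1), np.inf)
    argmin = np.full((length + 1, layer_num + 1), -1, dtype=np.int32)
    max_cost[0, 0] = sq_cost[0, 0] = 0.0
    for q in range(1, layer_num + 1):
        for r in range(1, length + 1):
            for k in range(r):
                new_max = max(max_cost[k, q - 1], layer_costs[k, r - 1])
                new_sq = sq_cost[k, q - 1] + layer_costs[k, r - 1] ** 2
                if (new_max < max_cost[r, q] or
                        (new_max <= max_cost[r, q] * (1 + 1e-4) and
                         new_sq < sq_cost[r, q])):
                    max_cost[r, q], sq_cost[r, q] = new_max, new_sq
                    argmin[r, q] = k
    blocks, r = [], length  # backtrack exactly like the caller above
    for q in range(layer_num, 0, -1):
        k = int(argmin[r, q])
        blocks.append((k, r))
        r = k
    return list(reversed(blocks)), float(max_cost[length, layer_num])

# Per-eqn costs 1, 1, 2, 2; layer_costs[l, r - 1] = total cost of eqns l..r-1.
costs = np.array([1.0, 1.0, 2.0, 2.0])
cum = np.concatenate([[0.0], np.cumsum(costs)])
layer_costs = cum[None, 1:] - cum[:-1, None]
print(balanced_partition(layer_costs, 2))
# -> ([(0, 2), (2, 4)], 4.0): the balanced split [1, 1] | [2, 2]
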
3 changes: 3 additions & 0 deletions benchmark/alpa/benchmark.py
@@ -28,13 +28,16 @@
"gpt.grid_search_auto": suite_auto_gpt.grid_search_suite,
"gpt.correctness_test_auto": suite_auto_gpt.correctness_test_suite,
"gpt_inference.profile": suite_inference_gpt.profile_suite,
"gpt_inference.test": suite_inference_gpt.test_suite,
"gpt_inference.search": suite_inference_gpt.search_suite,
"gpt_no_embedding_inference.profile": suite_inference_gpt.profile_suite,
"moe.tmp": suite_manual_moe.tmp_suite,
"moe.tmp_auto": suite_auto_moe.tmp_suite,
"moe.perf_test_fast_2d": suite_manual_moe.perf_test_fast_2d_suite,
"moe.perf_test_auto": suite_auto_moe.perf_test_suite,
"moe.grid_search_auto": suite_auto_moe.grid_search_suite,
"moe_inference.profile": suite_inference_moe.profile_suite,
"moe_inference.search": suite_inference_moe.search_suite,
"unet.perf_test_auto": suite_unet.perf_test_auto_suite,
"unet.grid_search_auto": suite_unet.grid_search_auto_suite,
"wresnet.perf_test_2d": suite_wresnet.perf_test_2d_suite,
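Each new entry above maps a suite name to a collection of benchmark cases. A self-contained sketch of the lookup pattern; the GPU-count keying and the stand-in objects are assumptions for illustration, not verified against benchmark.py:

# Stand-in registry mirroring the mapping above; the suite contents and
# GPU-count keying are illustrative assumptions.
search_suite = {8: ["toy-case-a", "toy-case-b"]}
benchmark_suites = {"gpt_inference.search": search_suite}

num_gpus = 8
for case in benchmark_suites["gpt_inference.search"].get(num_gpus, []):
    print("would run:", case)
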
72 changes: 54 additions & 18 deletions benchmark/alpa/benchmark_one_case_gpt_bert_inference.py
@@ -1,5 +1,6 @@
"""Benchmark one case of inter-op + intra-op parallelism."""
import os
from datetime import datetime

import jax
import jax.numpy as jnp
@@ -10,6 +11,7 @@
from alpa.model.bert_model import BertConfig, FlaxBertLayerCollection
from alpa.model.gpt_model import FlaxGPTForLMModule
from alpa.util import print_used_time, GB, write_tsv
from alpa.pipeline_parallel.stage_construction import get_last_dp_result

from util import compute_gpt_parameter_count, compute_gpt_tflops
from benchmark_parallel_utils import (
@@ -179,31 +181,65 @@ def benchmark_gpt_inference_internal(model_type,
# Log per-stage execution information if needed
if profile_stage_execution_time:
model_name = f"bert-{parameter_count/1e9:.1f}b"
# dump chrome trace
executable.dump_stage_execution_trace(
f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json"
)
# compute and log per-stage latency/memory statistics
exec_info = executable.get_stage_execution_info()
timelines = list(zip(*exec_info))
# drop warmup case
timelines = timelines[1:]
avg_stage_latencies = compute_avg_stage_latencies(timelines)
assert len(avg_stage_latencies) == num_manual_pipeline_stages
if num_manual_pipeline_stages is not None:
assert len(avg_stage_latencies) == num_manual_pipeline_stages
parallel_args = benchmark_case.parallel_args
dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp
heads = [
"ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU",
"MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)",
"StagePeakMem(B)", "StageLatencies(s)"
]
values = [
model_name, benchmark_case.batch_size,
benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp,
f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}",
avg_stage_latencies
]
if benchmark_case.parallel_mode == "uniform":
dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp
# dump chrome trace
executable.dump_stage_execution_trace(
f"./chrome_trace/{model_name},"
f"bs={benchmark_case.batch_size},"
f"op={op},pp={pp}.json")
heads = [
"ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU",
"MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)",
"StagePeakMem(B)", "StageLatencies(s)"
]
values = [
model_name, benchmark_case.batch_size,
benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp,
f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
f"{tflops:.2f}", f"{per_stage_weight_mem}",
f"{per_stage_peak_mem}", avg_stage_latencies
]
else:
(compute_cost_file_name, forward_stage_layer_ids, submesh_shapes,
logical_mesh_shapes,
autosharding_option_dicts) = get_last_dp_result()
metadata = {
"compilation_times": compilation_times,
"compute_cost_file_name": compute_cost_file_name,
"forward_stage_layer_ids": forward_stage_layer_ids,
"submesh_shapes": submesh_shapes,
"logical_mesh_shapes": logical_mesh_shapes,
"autosharding_option_dicts": autosharding_option_dicts,
}

timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
executable.dump_stage_execution_trace(
f"./chrome_trace/{model_name},"
f"bs={benchmark_case.batch_size},"
f"{timestamp}.json")
heads = [
"ModelName", "BS", "#Microbatch", "ParallelArgs", "MeanTime(s)",
"StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)",
"StageLatencies(s)", "Metadata", "TimeStamp"
]
values = [
model_name, benchmark_case.batch_size,
benchmark_case.num_micro_batches, parallel_args,
f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
f"{tflops:.2f}", f"{list(per_stage_weight_mem)}",
f"{list(per_stage_peak_mem)}",
list(avg_stage_latencies), metadata, timestamp
]
        write_tsv(heads, values, "benchmark_results.tsv")

metadata = {
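The per-stage bookkeeping above transposes stage-major execution info into run-major timelines (zip(*exec_info)), drops the warmup run, and averages per-stage durations. A runnable toy version of that step; the (start, end) tuple layout and the helper are illustrative assumptions, since the real compute_avg_stage_latencies lives in benchmark_parallel_utils:

import numpy as np

def avg_stage_latencies_sketch(timelines):
    # Average each stage's (end - start) duration across the remaining runs.
    per_run = [[end - start for (start, end) in run] for run in timelines]
    return np.mean(per_run, axis=0)

# Stage-major data, as get_stage_execution_info() is used above: one list
# per stage, one (start, end) pair per run (numbers made up).
exec_info = [
    [(0.0, 1.2), (0.0, 1.0)],   # stage 0: warmup run, then run 1
    [(1.2, 2.6), (1.0, 2.1)],   # stage 1
]
timelines = list(zip(*exec_info))[1:]   # transpose to run-major, drop warmup
print(avg_stage_latencies_sketch(timelines))   # -> [1.  1.1]
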
76 changes: 56 additions & 20 deletions benchmark/alpa/benchmark_one_case_moe_inference.py
@@ -1,4 +1,6 @@
"""Benchmark one case of inter-op + intra-op parallelism."""
from datetime import datetime

import jax
import jax.numpy as jnp
import numpy as np
@@ -23,7 +25,7 @@ def create_infer_params_aval(rngkey, model, batch):
params = jax.eval_shape(
lambda p: jax.tree_util.tree_map(
lambda x: jnp.asarray(x, dtype=jnp.float16), p), params)
    return params

def get_infer_step(parallel_method, model):

@@ -41,7 +43,7 @@ def infer_step(params, batch, rng_key):
loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
loss = (label_mask * loss).sum() / label_mask.sum()
return loss

return parallelize(infer_step, method=parallel_method, donate_argnums=())
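
As a concrete check of the loss computed in infer_step: with uniform logits every unmasked token contributes -log(1/vocab), and masked positions are excluded from the average. A minimal, model-free sketch of the same two lines, with toy shapes chosen for illustration:

import jax
import jax.numpy as jnp

# One sequence of 3 tokens over a vocab of 4; the last token is padding.
logits = jnp.zeros((1, 3, 4))                       # uniform logits
labels = jax.nn.one_hot(jnp.array([[1, 2, 3]]), 4)  # one-hot targets
label_mask = jnp.array([[1.0, 1.0, 0.0]])

loss = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)
loss = (label_mask * loss).sum() / label_mask.sum()
print(loss)  # log(4) ≈ 1.386: -log(1/4) per unmasked token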


@@ -168,31 +170,65 @@ def benchmark_moe_inference_internal(benchmark_case,
# Log per-stage execution information if needed
if profile_stage_execution_time:
model_name = f"moe-{parameter_count/1e9:.1f}b"
# dump chrome trace
executable.dump_stage_execution_trace(
f"./chrome_trace/{model_name},bs={benchmark_case.batch_size},op={benchmark_case.parallel_args.op},pp={benchmark_case.parallel_args.pp}.json"
)
# compute and log per-stage latency/memory statistics
exec_info = executable.get_stage_execution_info()
timelines = list(zip(*exec_info))
# drop warmup case
timelines = timelines[1:]
avg_stage_latencies = compute_avg_stage_latencies(timelines)
assert len(avg_stage_latencies) == num_manual_pipeline_stages
if num_manual_pipeline_stages is not None:
assert len(avg_stage_latencies) == num_manual_pipeline_stages
parallel_args = benchmark_case.parallel_args
dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp
heads = [
"ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU",
"MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)",
"StagePeakMem(B)", "StageLatencies(s)"
]
values = [
model_name, benchmark_case.batch_size,
benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp,
f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
f"{tflops:.2f}", f"{per_stage_weight_mem}", f"{per_stage_peak_mem}",
avg_stage_latencies
]
if benchmark_case.parallel_mode == "uniform":
dp, op, pp = parallel_args.dp, parallel_args.op, parallel_args.pp
# dump chrome trace
executable.dump_stage_execution_trace(
f"./chrome_trace/{model_name},"
f"bs={benchmark_case.batch_size},"
f"op={op},pp={pp}.json")
heads = [
"ModelName", "BS", "#Microbatch", "DP", "OP", "PP", "#GPU",
"MeanTime(s)", "StdTime(s)", "TFLOPs", "StageWeights(B)",
"StagePeakMem(B)", "StageLatencies(s)"
]
values = [
model_name, benchmark_case.batch_size,
benchmark_case.num_micro_batches, dp, op, pp, dp * op * pp,
f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
f"{tflops:.2f}", f"{per_stage_weight_mem}",
f"{per_stage_peak_mem}", avg_stage_latencies
]
else:
(compute_cost_file_name, forward_stage_layer_ids, submesh_shapes,
logical_mesh_shapes,
autosharding_option_dicts) = get_last_dp_result()
metadata = {
"compilation_times": compilation_times,
"compute_cost_file_name": compute_cost_file_name,
"forward_stage_layer_ids": forward_stage_layer_ids,
"submesh_shapes": submesh_shapes,
"logical_mesh_shapes": logical_mesh_shapes,
"autosharding_option_dicts": autosharding_option_dicts,
}

timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
executable.dump_stage_execution_trace(
f"./chrome_trace/{model_name},"
f"bs={benchmark_case.batch_size},"
f"{timestamp}.json")
heads = [
"ModelName", "BS", "#Microbatch", "ParallelArgs", "MeanTime(s)",
"StdTime(s)", "TFLOPs", "StageWeights(B)", "StagePeakMem(B)",
"StageLatencies(s)", "Metadata", "TimeStamp"
]
values = [
model_name, benchmark_case.batch_size,
benchmark_case.num_micro_batches, parallel_args,
f"{np.mean(latencies):.3f}", f"{np.std(latencies):.3f}",
f"{tflops:.2f}", f"{list(per_stage_weight_mem)}",
f"{list(per_stage_peak_mem)}",
list(avg_stage_latencies), metadata, timestamp
]
        write_tsv(heads, values, "benchmark_results.tsv")

metadata = {
2 changes: 1 addition & 1 deletion benchmark/alpa/gen_serving_database.py
@@ -10,5 +10,5 @@
args = parser.parse_args()

database = ProfilingDatabase(args.output, args.new)
database.update_from_csv(args.input)
database.update_from_auto_csv(args.input)
database.materialize()