Skip to content

Commit 77261d1

Browse files
authored
benchmark_inference: Add CLI option to enable thunder CUDAGraph Transform (#2697)
1 parent 9ada615 commit 77261d1

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

thunder/benchmarks/benchmark_inference.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
)
5050
from thunder.torch.custom_op import _register_custom_op
5151
from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel
52+
from thunder.transforms.cudagraph import CUDAGraphTransform
5253

5354
if TYPE_CHECKING:
5455
from typing import Any
@@ -156,6 +157,7 @@ class InferenceBenchmarkConfig:
156157
disable_moe_replacement: bool
157158
attn_implementation: str | None
158159
profile: bool
160+
enable_thunder_cudagraph: bool
159161

160162

161163
@dataclass
@@ -282,16 +284,20 @@ def __init__(self, config: InferenceBenchmarkConfig):
282284
def _thunder_jit_options(self) -> dict[str, Any]:
283285
# `nv_enable_linear=True` might fail with distributed run
284286
# ref: https://github.com/NVIDIA/Fuser/issues/4507
285-
res = {}
287+
res = {"transforms": []}
286288
if self.config.enable_nv_linear:
287-
res = {"nv_enable_linear": True, "nv_enable_matmul": True}
289+
res["nv_enable_linear"] = True
290+
res["nv_enable_matmul"] = True
288291
if self.config.mode == "thunderjit":
289292
from thunder.recipes.hf_transformers import SDPAMaskTransform
290293

291294
if not hasattr(self, "_mask_transform"):
292295
self._mask_transform = SDPAMaskTransform()
293-
res["transforms"] = [self._mask_transform]
296+
res["transforms"].append(self._mask_transform)
294297
res["executors"] = [self._mask_transform.get_executor(), *thunder.get_default_executors()]
298+
if self.config.enable_thunder_cudagraph:
299+
res["transforms"].append(CUDAGraphTransform())
300+
295301
return res
296302

297303
def _compile_model(self, model):
@@ -679,6 +685,7 @@ def parse_args() -> argparse.Namespace:
679685

680686
parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
681687
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
688+
parser.add_argument("--enable-thunder-cudagraph", action="store_true", help="Pass CUDAGraphTransform to Thunder")
682689
parser.add_argument("--attn-implementation", type=str, default=None, help="Attention implementation")
683690

684691
args = parser.parse_args()
@@ -713,6 +720,7 @@ def main():
713720
disable_moe_replacement=args.disable_moe_replacement,
714721
attn_implementation=args.attn_implementation,
715722
profile=args.profile,
723+
enable_thunder_cudagraph=args.enable_thunder_cudagraph,
716724
)
717725
benchmark = InferenceBenchmark(config)
718726

0 commit comments

Comments
 (0)