 )
 from thunder.torch.custom_op import _register_custom_op
 from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel
+from thunder.transforms.cudagraph import CUDAGraphTransform
 
 if TYPE_CHECKING:
     from typing import Any
@@ -156,6 +157,7 @@ class InferenceBenchmarkConfig:
     disable_moe_replacement: bool
     attn_implementation: str | None
     profile: bool
+    enable_thunder_cudagraph: bool
 
 
 @dataclass
@@ -282,16 +284,20 @@ def __init__(self, config: InferenceBenchmarkConfig):
     def _thunder_jit_options(self) -> dict[str, Any]:
         # `nv_enable_linear=True` might fail with distributed run
         # ref: https://github.com/NVIDIA/Fuser/issues/4507
-        res = {}
+        res = {"transforms": []}
         if self.config.enable_nv_linear:
-            res = {"nv_enable_linear": True, "nv_enable_matmul": True}
+            res["nv_enable_linear"] = True
+            res["nv_enable_matmul"] = True
         if self.config.mode == "thunderjit":
             from thunder.recipes.hf_transformers import SDPAMaskTransform
 
             if not hasattr(self, "_mask_transform"):
                 self._mask_transform = SDPAMaskTransform()
-            res["transforms"] = [self._mask_transform]
+            res["transforms"].append(self._mask_transform)
             res["executors"] = [self._mask_transform.get_executor(), *thunder.get_default_executors()]
+        if self.config.enable_thunder_cudagraph:
+            res["transforms"].append(CUDAGraphTransform())
+
         return res
 
     def _compile_model(self, model):
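
With this hunk, `CUDAGraphTransform` is appended to the shared `transforms` list only when the new flag is set, alongside whatever else `_thunder_jit_options` collects. A minimal sketch of what that amounts to at the `thunder.jit` call site, assuming a CUDA device and a toy `torch.nn.Linear` module that is not part of the benchmark:

```python
# Minimal sketch (not the benchmark script): hand CUDAGraphTransform to
# thunder.jit the same way the options dict above does when the flag is set.
import torch
import thunder
from thunder.transforms.cudagraph import CUDAGraphTransform

model = torch.nn.Linear(16, 16).cuda()
jitted = thunder.jit(model, transforms=[CUDAGraphTransform()])
out = jitted(torch.randn(8, 16, device="cuda"))
print(out.shape)  # torch.Size([8, 16]); the transform handles graph capture internally
```
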
@@ -679,6 +685,7 @@ def parse_args() -> argparse.Namespace:
 
     parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
     parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
+    parser.add_argument("--enable-thunder-cudagraph", action="store_true", help="Pass CUDAGraphTransform to Thunder")
     parser.add_argument("--attn-implementation", type=str, default=None, help="Attention implementation")
 
     args = parser.parse_args()
@@ -713,6 +720,7 @@ def main():
         disable_moe_replacement=args.disable_moe_replacement,
         attn_implementation=args.attn_implementation,
         profile=args.profile,
+        enable_thunder_cudagraph=args.enable_thunder_cudagraph,
     )
     benchmark = InferenceBenchmark(config)
|
|