Skip to content

Commit 123ff28

Browse files
benchmark_inference: Allow passing cache option as cli arg (#2695)
Co-authored-by: Riccardo Felluga <11768013+riccardofelluga@users.noreply.github.com>
1 parent 77261d1 commit 123ff28

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

thunder/benchmarks/benchmark_inference.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ class InferenceBenchmarkConfig:
157157
disable_moe_replacement: bool
158158
attn_implementation: str | None
159159
profile: bool
160+
thunder_cache: str | None
160161
enable_thunder_cudagraph: bool
161162

162163

@@ -297,6 +298,8 @@ def _thunder_jit_options(self) -> dict[str, Any]:
297298
res["executors"] = [self._mask_transform.get_executor(), *thunder.get_default_executors()]
298299
if self.config.enable_thunder_cudagraph:
299300
res["transforms"].append(CUDAGraphTransform())
301+
if self.config.thunder_cache is not None:
302+
res["cache"] = self.config.thunder_cache
300303

301304
return res
302305

@@ -685,6 +688,12 @@ def parse_args() -> argparse.Namespace:
685688

686689
parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
687690
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
691+
parser.add_argument(
692+
"--thunder-cache",
693+
type=str,
694+
default=None,
695+
help="Cache option: no caching, same input, constant values, symbolic values. See `cache` argument of `thunder.jit` for more details.",
696+
)
688697
parser.add_argument("--enable-thunder-cudagraph", action="store_true", help="Pass CUDAGraphTransform to Thunder")
689698
parser.add_argument("--attn-implementation", type=str, default=None, help="Attention implementation")
690699

@@ -720,6 +729,7 @@ def main():
720729
disable_moe_replacement=args.disable_moe_replacement,
721730
attn_implementation=args.attn_implementation,
722731
profile=args.profile,
732+
thunder_cache=args.thunder_cache,
723733
enable_thunder_cudagraph=args.enable_thunder_cudagraph,
724734
)
725735
benchmark = InferenceBenchmark(config)

0 commit comments

Comments
 (0)