@@ -157,6 +157,7 @@ class InferenceBenchmarkConfig:
157157 disable_moe_replacement : bool
158158 attn_implementation : str | None
159159 profile : bool
160+ thunder_cache : str | None
160161 enable_thunder_cudagraph : bool
161162
162163
@@ -297,6 +298,8 @@ def _thunder_jit_options(self) -> dict[str, Any]:
297298 res ["executors" ] = [self ._mask_transform .get_executor (), * thunder .get_default_executors ()]
298299 if self .config .enable_thunder_cudagraph :
299300 res ["transforms" ].append (CUDAGraphTransform ())
301+ if self .config .thunder_cache is not None :
302+ res ["cache" ] = self .config .thunder_cache
300303
301304 return res
302305
@@ -685,6 +688,12 @@ def parse_args() -> argparse.Namespace:
685688
686689 parser .add_argument ("--save-results" , action = "store_true" , help = "Save results to JSON file" )
687690 parser .add_argument ("--output-dir" , type = str , default = "./results" , help = "Directory to save results" )
691+ parser .add_argument (
692+ "--thunder-cache" ,
693+ type = str ,
694+ default = None ,
695+ help = "Cache option: no caching, same input, constant values, symbolic values. See `cache` argument of `thunder.jit` for more details." ,
696+ )
688697 parser .add_argument ("--enable-thunder-cudagraph" , action = "store_true" , help = "Pass CUDAGraphTransform to Thunder" )
689698 parser .add_argument ("--attn-implementation" , type = str , default = None , help = "Attention implementation" )
690699
@@ -720,6 +729,7 @@ def main():
720729 disable_moe_replacement = args .disable_moe_replacement ,
721730 attn_implementation = args .attn_implementation ,
722731 profile = args .profile ,
732+ thunder_cache = args .thunder_cache ,
723733 enable_thunder_cudagraph = args .enable_thunder_cudagraph ,
724734 )
725735 benchmark = InferenceBenchmark (config )
0 commit comments