diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 3eb9181f6c1..eb47ee0ed6a 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -504,6 +504,15 @@ def should_stop_processing(self): return self.is_shutdown and len(self.active_requests) == 0 and \ self.executor_request_queue.get_waiting_queue_size() == 0 + def _get_ranked_trace_path(self): + """Return a per-rank torch trace path based on TLLM_TORCH_PROFILE_TRACE.""" + rank = getattr(self.dist, "rank", 0) + trace_path = self.torch_trace_path + if os.path.isdir(trace_path): + return os.path.join(trace_path, f"trace_{rank}.json") + base, ext = os.path.splitext(trace_path) + return f"{base}_{rank}{ext or '.json'}" + @contextmanager def _profiler(self): it = -1 @@ -518,12 +527,12 @@ def _profiler(self): start_event_2 = None end_event_2 = torch.cuda.Event(enable_timing=True) prev_device_step_time = None - - torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None) + + self.torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None) profile_start_stop = os.environ.get(PROFILE_START_STOP_ENV_VAR_NAME, None) - enable_torch_trace = bool(torch_trace_path and profile_start_stop) - if torch_trace_path and profile_start_stop is None: + enable_torch_trace = bool(self.torch_trace_path and profile_start_stop) + if self.torch_trace_path is not None and profile_start_stop is None: logger.warning( f"{PROFILE_START_STOP_ENV_VAR_NAME} environment variable " "needs to be set to enable the torch trace. 
Example to profile " @@ -546,6 +555,7 @@ def profile_step(): assert enabled, "Inconsistent CUDA profiling state" if enable_torch_trace: torch_profiler.stop() + torch_trace_path = self._get_ranked_trace_path() torch_profiler.export_chrome_trace(torch_trace_path) logger.info(f"Profiling stopped at iteration {it}, " f"trace saved to {torch_trace_path}") @@ -612,6 +622,7 @@ def profile_step(): # Stop on early exit / exception if enable_torch_trace: torch_profiler.stop() + torch_trace_path = self._get_ranked_trace_path() torch_profiler.export_chrome_trace(torch_trace_path) logger.info(f"Profiling stopped at iteration {it}, " f"trace saved to {torch_trace_path}")