tensorrt_llm/_torch/pyexecutor/sampler.py (4 additions, 2 deletions)
@@ -673,15 +673,17 @@ def get_generator(self, device: torch.device) -> torch.Generator:
         assert self._generator.device == device
         return self._generator

-    def get_spec_tree_manager(self, resource_manager: ResourceManager) -> Optional[SpecTreeManager]:
+    def get_spec_tree_manager(
+        self, resource_manager: Optional[ResourceManager]
+    ) -> Optional[SpecTreeManager]:
         if resource_manager is None:
             return None
         spec_resource_manager = resource_manager.get_resource_manager(
             ResourceManagerType.SPEC_RESOURCE_MANAGER
         )
         if spec_resource_manager is None or not hasattr(spec_resource_manager, "spec_tree_manager"):
             return None
-        return spec_resource_manager.spec_tree_manager
+        return spec_resource_manager.spec_tree_manager  # type: ignore

     @staticmethod
     def _meet_max_token_stop_criteria(request: LlmRequest, max_seq_len: int):
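The sampler.py change widens the parameter to Optional[ResourceManager], matching the None guard the body already performs, and adds a # type: ignore where spec_tree_manager is accessed only after a runtime hasattr check. As a quick illustration of what the widened signature now permits (hypothetical caller, not part of this PR):

    # Hypothetical usage, for illustration only: with the Optional annotation,
    # passing resource_manager=None now type-checks and short-circuits to None.
    spec_tree_manager = sampler.get_spec_tree_manager(resource_manager=None)
    assert spec_tree_manager is None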
tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py (10 additions, 13 deletions)
@@ -91,17 +91,14 @@ def _make_tensor(data: list, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
     def _prepare_logits_with_temperature(
         logits: torch.Tensor,
         group_logit_indices: Optional[torch.Tensor],
-        temperature: Optional[torch.Tensor],
+        temperature: torch.Tensor,
     ) -> torch.Tensor:
-        if temperature is not None:
-            temperature = temperature.unsqueeze(-1)
-            if group_logit_indices is not None:
-                logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
-                logits /= temperature
-            else:
-                logits = logits / temperature  # not inplace
-        elif group_logit_indices is not None:
-            logits = logits[group_logit_indices]
+        temperature = temperature.unsqueeze(-1)
+        if group_logit_indices is not None:
+            logits = torch.index_select(logits, 0, group_logit_indices)  # ensures copy
+            logits /= temperature
+        else:
+            logits = logits / temperature  # not inplace
         return logits

     @staticmethod
@@ -112,12 +109,12 @@ def _prepare_probs_with_temperature(
     ) -> torch.Tensor:
         if group_logit_indices is not None:
             logits = logits[group_logit_indices]
-        logits = flashinfer.sampling.softmax(
+        probs = flashinfer.sampling.softmax(
             logits,
             temperature,
             enable_pdl=ENABLE_PDL,
         )
-        return logits
+        return probs

     @classmethod
     def _sample_from_probs(
@@ -151,7 +148,7 @@ def _sample_with_probs(
         group_logit_indices: Optional[torch.Tensor],
         top_k: Optional[torch.Tensor],
         top_p: Optional[torch.Tensor],
-        temperature: Optional[torch.Tensor],
+        temperature: torch.Tensor,
         generator: Optional[torch.Generator],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         if top_k is not None:
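In sampling_utils_flashinfer.py, temperature becomes a required torch.Tensor in _prepare_logits_with_temperature and _sample_with_probs, which removes the dead "temperature is None" branch, and _prepare_probs_with_temperature now returns a variable named probs, since flashinfer.sampling.softmax yields probabilities rather than logits. A minimal standalone sketch of the simplified scaling logic in plain PyTorch (it mirrors the diff above; names and shapes are illustrative, not the library's internals):

    import torch
    from typing import Optional

    def prepare_logits_with_temperature(
        logits: torch.Tensor,
        group_logit_indices: Optional[torch.Tensor],
        temperature: torch.Tensor,
    ) -> torch.Tensor:
        # temperature is always a tensor now; broadcast over the vocab dim.
        temperature = temperature.unsqueeze(-1)
        if group_logit_indices is not None:
            # index_select materializes a copy, so the in-place divide below
            # cannot clobber the caller's tensor.
            logits = torch.index_select(logits, 0, group_logit_indices)
            logits /= temperature
        else:
            # Out-of-place divide keeps the caller's logits untouched.
            logits = logits / temperature
        return logits

    # Example: per-request temperatures for a batch of two requests.
    logits = torch.randn(2, 32000)
    scaled = prepare_logits_with_temperature(logits, None, torch.tensor([0.7, 1.3]))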