6 changes: 6 additions & 0 deletions lmdeploy/messages.py
@@ -289,6 +289,9 @@ class PytorchEngineConfig:
session_len (int): Max session length. Default None.
max_batch_size (int): Max batch size. If it is not specified,
the engine will automatically set it according to the device
attn_tp_size (int): tp size for attention. Takes effect only when dp > 1
mlp_tp_size (int): tp size for mlp. Takes effect only when dp > 1
moe_tp_size (int): tp size for moe. Takes effect only when dp > 1
cache_max_entry_count (float): the percentage of gpu memory occupied
by the k/v cache. For lmdeploy versions greater than `v0.2.1`,
it defaults to 0.8, signifying the percentage of FREE GPU memory
@@ -350,6 +353,9 @@ class PytorchEngineConfig:
ep: int = 1
session_len: int = None
max_batch_size: int = None
attn_tp_size: int = None
mlp_tp_size: int = None
moe_tp_size: int = None
cache_max_entry_count: float = 0.8
prefill_interval: int = 16
block_size: int = 64
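The three new fields let attention, MLP and MoE layers each use a different tensor-parallel degree when data parallelism is enabled. A minimal usage sketch; the `dp` keyword is an assumption inferred from the "dp > 1" wording in the docstrings and is not shown in this diff:

```python
# Sketch only: per-module TP sizes under data parallelism.
# `dp=2` is a hypothetical field implied by the docstrings, not part of this diff.
from lmdeploy import PytorchEngineConfig

engine_config = PytorchEngineConfig(
    dp=2,             # assumed data-parallel degree
    attn_tp_size=2,   # attention reduced over a 2-way TP group
    mlp_tp_size=4,    # dense MLP layers over a 4-way TP group
    moe_tp_size=1,    # MoE left un-sharded by TP (expert parallelism via `ep` instead)
)
```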
7 changes: 6 additions & 1 deletion lmdeploy/pytorch/backends/awq_modules.py
@@ -17,7 +17,12 @@ def update_weights(self,
return qweight, scales, qzeros, bias

@abstractmethod
def forward(self, x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, all_reduce: bool = False):
def forward(self,
x,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
raise NotImplementedError

2 changes: 2 additions & 0 deletions lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -3,6 +3,7 @@
from typing import List, Optional

import torch
import torch.distributed as dist


class LinearBlockedF8Impl(ABC):
@@ -19,6 +20,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/attention.py
@@ -90,7 +90,7 @@ def __init__(
self.flash_attention_fwd = flash_attention_fwd

# for alibi attention
world_size, rank = get_tp_world_rank()
world_size, rank = get_tp_world_rank('attn')
self.alibi_head_offset = self.num_heads * rank
self.alibi_num_heads = self.num_heads * world_size
self.block_sparse_size = block_sparse_size
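`get_tp_world_rank` now takes a role key because attention, MLP and MoE can each live in their own TP group, so the `(world_size, rank)` pair that sizes the ALiBi head slice depends on which group attention belongs to. The registry below is purely illustrative and is not lmdeploy's actual implementation:

```python
# Illustration only: a hypothetical role-keyed (world_size, rank) lookup.
_TP_WORLD_RANK = {
    'attn': (2, 0),   # e.g. attn_tp_size=2, this process is rank 0 of that group
    'mlp':  (4, 0),
    'moe':  (1, 0),
}

def get_tp_world_rank(role: str):
    """Return (world_size, rank) of the TP group owning `role` modules."""
    return _TP_WORLD_RANK[role]

num_heads = 8                                # local head count, hypothetical
world_size, rank = get_tp_world_rank('attn')
alibi_head_offset = num_heads * rank         # as in the diff above
alibi_num_heads = num_heads * world_size
```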
5 changes: 3 additions & 2 deletions lmdeploy/pytorch/backends/cuda/awq_modules.py
@@ -55,12 +55,13 @@ def forward(self,
scales: torch.Tensor,
qzeros: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False):
all_reduce: bool = False,
group: Optional[torch.distributed.ProcessGroup] = None):
"""forward."""
out_features = scales.size(1)
out = wq_gemm_forward(x, qweight, qzeros, scales, self.w_bit, self.group_size, bias, out_features)
if all_reduce:
dist.all_reduce(out)
dist.all_reduce(out, group=group)
return out


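The extra `group` argument scopes the output all-reduce to the module's own TP group instead of the default world group. A standalone sketch with plain `torch.distributed` (4 ranks split into two 2-way TP groups; the launch setup is illustrative, not lmdeploy code):

```python
# Run with: torchrun --nproc_per_node=4 this_script.py  (illustrative)
import torch
import torch.distributed as dist

dist.init_process_group('nccl')
rank = dist.get_rank()
torch.cuda.set_device(rank)

# Every rank must create every group; each rank then uses the one it belongs to.
group_01 = dist.new_group([0, 1])
group_23 = dist.new_group([2, 3])
tp_group = group_01 if rank < 2 else group_23

out = torch.full((4, 4), float(rank), device='cuda')
dist.all_reduce(out, group=tp_group)   # sums only within the 2-way TP group
```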
20 changes: 6 additions & 14 deletions lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -13,15 +13,6 @@
logger = get_logger('lmdeploy')


def _reduce_scatter_input(out: torch.Tensor, rank: int, tp_sizes: List[int]):
"""Reduce scatter."""
outs = out.split(tp_sizes, -2)
out = outs[rank]
outs = list(outs)
dist.reduce_scatter(out, outs)
return out


class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
"""Triton linear blocked f8 implementation."""

@@ -37,6 +28,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -52,7 +44,7 @@

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)
return out
@@ -117,6 +109,7 @@ def forward(self,
scale: torch.Tensor,
bias: Optional[torch.Tensor] = None,
all_reduce: bool = False,
group: Optional[dist.ProcessGroup] = None,
rank: int = 0,
scatter_size: List[int] = None):
"""forward."""
@@ -128,12 +121,11 @@ def forward(self,
out = out[:x.size(0)]
if bias is not None:
out += bias
out = out.unflatten(0, x_shape[:-1])

if all_reduce:
if scatter_size is not None:
out = _reduce_scatter_input(out, rank, scatter_size)
out = dist.reduce_scatter_by_tp_sizes(out, rank, scatter_size, group=group)
else:
dist.all_reduce(out)

out = out.unflatten(0, x_shape[:-1])
dist.all_reduce(out, group=group)
return out
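The removed local helper is replaced by `dist.reduce_scatter_by_tp_sizes`, which also threads the process group through. Its body is not part of this diff; a plausible sketch modeled on the deleted `_reduce_scatter_input` (an assumption, not the actual lmdeploy code):

```python
from typing import List, Optional

import torch
import torch.distributed as dist


def reduce_scatter_by_tp_sizes(out: torch.Tensor,
                               rank: int,
                               tp_sizes: List[int],
                               group: Optional[dist.ProcessGroup] = None):
    """Reduce-scatter into unevenly sized per-rank chunks (sketch)."""
    outs = list(out.split(tp_sizes, -2))   # one chunk per rank; sizes may differ
    out = outs[rank]                       # this rank keeps only its own slice
    dist.reduce_scatter(out, outs, group=group)
    return out
```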
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/backends/cuda/graph_runner.py
@@ -258,7 +258,7 @@ def update_inputs(self, inputs):
meta = self.get_meta()
padding_batch_size = meta.padding_batch_size
tp_size = self._get_capture_tokens(padding_batch_size)
dp_meta.tp_sizes = [tp_size] * len(dp_meta.tp_sizes)
dp_meta.sync_tp_size(tp_size)
return inputs

def get_capture_batch_sizes(self) -> List[int]:
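`dp_meta.sync_tp_size` replaces the direct list assignment. Judging from the removed line, it most likely writes the same padded token count into every entry of `tp_sizes`; a hypothetical stand-in to show the intended behavior:

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DPMeta:
    """Hypothetical stand-in for the dp_meta object in update_inputs."""
    tp_sizes: List[int] = field(default_factory=list)

    def sync_tp_size(self, tp_size: int):
        """Give every DP rank the same (padded) token count."""
        self.tp_sizes = [tp_size] * len(self.tp_sizes)


meta = DPMeta(tp_sizes=[3, 5, 4, 2])
meta.sync_tp_size(8)        # e.g. the capture size chosen for the padded batch
print(meta.tp_sizes)        # [8, 8, 8, 8]
```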