diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py
index f718c7fbf..78c8a4c89 100644
--- a/server/lorax_server/adapters/lora.py
+++ b/server/lorax_server/adapters/lora.py
@@ -252,68 +252,36 @@ def load(
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices
 
-        lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
-        lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}
-
-        segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights]
-        if not segment_ranks:
+        lora_a, lora_b, adapter_index_configs = {}, {}, {}
+        max_rank, rank_indices = 0, defaultdict(list)
+        for segment_idx, adapter_idx in enumerate(segment_indices):
+            if adapter_idx in adapter_weights:
+                adapter_weight = adapter_weights[adapter_idx]
+                adapter_index_configs[adapter_idx] = adapter_weight.adapter_config
+                max_rank = max(max_rank, adapter_weight.lora_a_r)
+                rank_indices[adapter_weight.lora_a_r].append(segment_idx)
+                lora_a[adapter_idx] = adapter_weight.weights_a
+                lora_b[adapter_idx] = adapter_weight.weights_b
+
+        if not max_rank:
             return None
 
-        max_rank = max(segment_ranks)
-        if prefill or max_rank > BGMV_MAX_RANK:
-            use_sgmv = True
-            lora_a_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-            lora_b_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-        else:
-            use_sgmv = False
-            lora_a_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-            lora_b_ptr = torch.tensor(
-                [
-                    (adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
-                    for idx in segment_indices
-                ],
-                dtype=torch.int64,
-                device=device,
-            )
-
-        adapter_index_configs = {
-            idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
-        }
-
-        rank_indices = defaultdict(list)
-        for segment_idx, adapter_idx in enumerate(segment_indices):
-            if adapter_idx not in adapter_weights:
-                continue
-            rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx)
+        use_sgmv = prefill or max_rank > BGMV_MAX_RANK
 
         if prefill_head_indices is not None:
             j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0]
+            # prefill_head_indices is used to slice the tokens in the batch such
+            # that we only forward the last token of each request through the lm_head
+            # there can be multiple head_index values associated with each adapter segment
            for head_index in prefill_head_indices:
-                # j cannot go out of bounds as that would mean there are tokens without corresponding adapters
+                # j cannot go out of bounds as that would mean there are tokens without segments
                 if head_index < meta.adapter_segments[j]:
+                    # head_index is part of the current adapter segment,
+                    # so increment the current segment end
                     prefill_head_segment_ends[-1] += 1
                 else:
+                    # head_index is not part of the current adapter segment,
+                    # so close the previous segment and start a new one
                     prefill_head_segment_starts.append(prefill_head_segment_ends[-1])
                     prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1)
                     j += 1
@@ -325,44 +293,51 @@ def load(
             segment_starts = None
             segment_ends = None
             batch_indices = None
+            lora_a_ptr_indices = []
+            lora_b_ptr_indices = []
 
             if use_sgmv:
-                lora_a_ptr_indices = lora_a_ptr[indices]
-                tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device)
+                for segment_idx in indices:
+                    adapter_weight = adapter_weights[segment_indices[segment_idx]]
+                    lora_a_ptr_indices.append(adapter_weight.weights_a.data_ptr())
+                    lora_b_ptr_indices.append(adapter_weight.weights_b.data_ptr())
+                tmp_shrink, tmp_expand = get_tmp_tensors(len(lora_a_ptr_indices), rank, device)
                 segment_starts = meta.adapter_segments[indices]
                 segment_ends = meta.adapter_segments[[i + 1 for i in indices]]
                 if prefill_head_indices is not None:
-                    for i, segment_index in enumerate(indices):
-                        segment_starts[i] = prefill_head_segment_starts[segment_index]
-                        segment_ends[i] = prefill_head_segment_ends[segment_index]
+                    # since prefill_head_indices is present, the segment starts and ends
+                    # need to be adjusted according to the number of head tokens in each segment
+                    for i, segment_idx in enumerate(indices):
+                        segment_starts[i] = prefill_head_segment_starts[segment_idx]
+                        segment_ends[i] = prefill_head_segment_ends[segment_idx]
             else:
-                # `indices` indexes the `segment_indices` which contains segment wise adapter index
-                # `lora_a_ptr` contains segment wise pointers to lora weights
-                # lengths of `lora_a_ptr` and `segment_indices` must be same
-                # `indices` will be used to slice the `lora_a_ptr` tensor
-                # first, find the mapping between adapter index and its location in the `indices` array
-                idx_locs = {}
-                for loc, idx in enumerate(indices):
-                    # use the idx to find the adapter index
-                    if segment_indices[idx] not in idx_locs:
-                        # save the first location of encountering a particular adapter index
-                        idx_locs[segment_indices[idx]] = loc
-                # second, iterate over the adapter index for each token and find its location in the `indices` array
-                batch_indices = torch.tensor(
-                    [
-                        idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1
-                        for idx in meta.adapter_indices.tolist()
-                    ],
-                    dtype=torch.int64,
-                    device=device,
-                )
+                adapter_idx_to_pointer_idx = {}
+                # find out which adapters are present in the segments for this rank:
+                # iterate over each segment index and use it to look up the adapter index and weights
+                for segment_idx in indices:
+                    adapter_idx = segment_indices[segment_idx]
+                    adapter_weight = adapter_weights[adapter_idx]
+                    # if the adapter hasn't been seen before, append its weight pointers
+                    # and save the index of the just-added pointers for later
+                    if adapter_idx not in adapter_idx_to_pointer_idx:
+                        lora_a_ptr_indices.append(adapter_weight.weights_a_t.data_ptr())
+                        lora_b_ptr_indices.append(adapter_weight.weights_b_t.data_ptr())
+                        adapter_idx_to_pointer_idx[adapter_idx] = len(lora_a_ptr_indices) - 1
+                # for each token in the batch, check whether its adapter is present in the segments for this rank:
+                # if present, store the index of its weight pointers, otherwise store -1
+                batch_indices = torch.tensor([
+                    adapter_idx_to_pointer_idx.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist()
+                ], dtype=torch.int64, device=device)
+
+            lora_a_ptr_indices = torch.tensor(lora_a_ptr_indices, dtype=torch.int64, device=device)
+            lora_b_ptr_indices = torch.tensor(lora_b_ptr_indices, dtype=torch.int64, device=device)
 
             rank_data[rank] = RankSegments(
                 rank=rank,
                 tmp_shrink=tmp_shrink,
                 tmp_expand=tmp_expand,
-                lora_a_ptr=lora_a_ptr[indices],
-                lora_b_ptr=lora_b_ptr[indices],
+                lora_a_ptr=lora_a_ptr_indices,
+                lora_b_ptr=lora_b_ptr_indices,
                 segment_starts=segment_starts,
                 segment_ends=segment_ends,
                 indices=batch_indices,
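
For reference, the decode (BGMV) branch above boils down to the standalone sketch below. It is not part of the patch: plain Python lists and strings stand in for the tensors and weight data pointers, the `pointer_list` name is illustrative, and the inputs mirror the decode test parametrization (adapters with ranks 8, 8, 16 over ten tokens).

from collections import defaultdict

# Hypothetical inputs mirroring the decode test in test_lora.py.
lora_ranks = {0: 8, 1: 8, 2: 16}                    # adapter index -> rank
adapter_indices = [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]    # adapter used by each token

# find_segments-style grouping: one entry per contiguous run of the same adapter index.
segment_indices = []                                 # adapter owning each segment
for i, adapter_idx in enumerate(adapter_indices):
    if i == 0 or adapter_idx != adapter_indices[i - 1]:
        segment_indices.append(adapter_idx)

# Group segment positions by rank, as the new load() loop does.
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
    rank_indices[lora_ranks[adapter_idx]].append(segment_idx)

for rank, indices in rank_indices.items():
    adapter_idx_to_pointer_idx = {}   # adapter index -> slot in the per-rank pointer list
    pointer_list = []                 # stands in for lora_a_ptr_indices / lora_b_ptr_indices
    for segment_idx in indices:
        adapter_idx = segment_indices[segment_idx]
        if adapter_idx not in adapter_idx_to_pointer_idx:
            adapter_idx_to_pointer_idx[adapter_idx] = len(pointer_list)
            pointer_list.append(f"ptr(adapter {adapter_idx})")
    # -1 marks tokens whose adapter is handled by a different rank group.
    batch_indices = [adapter_idx_to_pointer_idx.get(a, -1) for a in adapter_indices]
    print(rank, pointer_list, batch_indices)
    # rank 8  -> 2 pointer slots, indices [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]
    # rank 16 -> 1 pointer slot,  indices [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]
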
diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py
index bbc3546d6..fcdb84800 100644
--- a/server/tests/utils/test_lora.py
+++ b/server/tests/utils/test_lora.py
@@ -8,6 +8,7 @@
 from lorax_server.adapters.lora import LoraWeights
 from lorax_server.adapters.types import LORA
 from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights
+from lorax_server.utils.segments import find_segments
 from lorax_server.utils.sgmv import MIN_RANK_CUSTOM
 
 
@@ -42,16 +43,37 @@ def load(
 
 
 @pytest.mark.parametrize(
-    "lora_ranks",
+    "lora_ranks,adapter_indices,expected",
     [
-        [8, 16],
-        [32, 64],
+        (
+            [8, 8, 16],  # ranks of adapters
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],  # adapter indices for each token in the batch
+            {
+                8: (  # rank
+                    [0, 2, 4, 6],  # expected segment starts
+                    [2, 4, 6, 8],  # expected segment ends
+                    [0, 1, 0, 1],  # expected adapter indices
+                ),
+                16: ([8], [10], [2]),
+            }
+        ),
+        (
+            [4, 8, 16],
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            {
+                4: ([0, 4], [2, 6], [0, 0]),
+                8: ([2, 6], [4, 8], [1, 1]),
+                16: ([8], [10], [2]),
+            }
+        ),
     ],
 )
-def test_batched_lora_weights(lora_ranks: List[int]):
-    # batch meta is hardcoded with this assumption below
-    assert len(lora_ranks) == 2
-
+def test_batched_lora_weights(
+    lora_ranks: List[int],
+    adapter_indices: List[int],
+    expected: Dict[int, Tuple[List[int], List[int], List[int]]]
+):
+    num_adapters = len(lora_ranks)
     batched_weights = LayerAdapterWeights()
     assert batched_weights.is_empty()
@@ -68,59 +90,73 @@ def test_batched_lora_weights(lora_ranks: List[int]):
         batched_weights.add_adapter(idx, weights)
 
     assert not batched_weights.is_empty()
-    assert len(batched_weights.adapter_weights) == 2
+    assert len(batched_weights.adapter_weights) == num_adapters
+
+    segments, segment_indices = find_segments(adapter_indices)
 
     meta = AdapterBatchMetadata(
-        adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64),
-        adapter_set={0, 1},
-        adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64),
-        segment_indices=[0, 1, 0, 1],
+        adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64),
+        adapter_set=set(adapter_indices),
+        adapter_segments=torch.tensor(segments, dtype=torch.int64),
+        segment_indices=segment_indices,
     )
 
     with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))):
         data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA)
 
-    assert len(data.lora_a) == 2
+    assert len(data.lora_a) == num_adapters
+    assert len(data.lora_b) == num_adapters
     assert data.lora_a.keys() == meta.adapter_set
-    assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h))
-    assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h))
-
-    assert len(data.lora_b) == 2
     assert data.lora_b.keys() == meta.adapter_set
-    assert data.lora_b[0].shape == (1, lora_ranks[0], h)
-    assert data.lora_b[1].shape == (1, lora_ranks[1], h)
+    for i in range(num_adapters):
+        assert data.lora_a[i].shape == (
+            (1, h, lora_ranks[i]) if lora_ranks[i] < MIN_RANK_CUSTOM else (1, lora_ranks[i], h)
+        )
+        assert data.lora_b[i].shape == (1, lora_ranks[i], h)
 
-    assert len(data.rank_data) == 2
-    assert data.rank_data.keys() == set(lora_ranks)
     for lora_rank, rd in data.rank_data.items():
         assert rd.rank == lora_rank
-
-        # shape in all cases is the number of segments with this rank
-        assert rd.lora_a_ptr.shape == (2,)
-        assert rd.lora_b_ptr.shape == (2,)
-        assert rd.segment_starts.shape == (2,)
-        assert rd.segment_ends.shape == (2,)
-
+        expected_lora_a_ptr = []
+        expected_lora_b_ptr = []
+        for adapter_idx in expected[lora_rank][2]:
+            expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a.data_ptr())
+            expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b.data_ptr())
+        expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device)
+        expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device)
+        assert all(rd.lora_a_ptr == expected_lora_a_ptr)
+        assert all(rd.lora_b_ptr == expected_lora_b_ptr)
+
+        expected_segment_starts = torch.tensor(
+            expected[lora_rank][0], dtype=rd.segment_starts.dtype, device=rd.segment_starts.device
+        )
+        expected_segment_ends = torch.tensor(
+            expected[lora_rank][1], dtype=rd.segment_ends.dtype, device=rd.segment_ends.device
+        )
+        assert all(rd.segment_ends == expected_segment_ends)
+        assert all(rd.segment_starts == expected_segment_starts)
 
 
 @pytest.mark.parametrize(
     "lora_ranks,adapter_indices,expected",
     [
         (
-            [8, 8, 16],
-            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
+            [8, 8, 16],  # ranks of adapters
+            [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],  # adapter indices for each token in the batch
             {
-                8: (4, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]),
-                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0])
+                8: (  # rank
+                    [0, 1],  # expected adapter indices
+                    [0, 0, 1, 1, 0, 0, 1, 1, -1, -1],  # expected per-token indices
+                ),
+                16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
             }
         ),
         (
             [4, 8, 16],
             [0, 0, 1, 1, 0, 0, 1, 1, 2, 2],
             {
-                4: (2, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
-                8: (2, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
-                16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
+                4: ([0], [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]),
+                8: ([1], [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]),
+                16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]),
            }
        ),
    ],
@@ -128,9 +164,8 @@ def test_batched_lora_weights(lora_ranks: List[int]):
 def test_batched_lora_weights_decode(
     lora_ranks: List[int],
     adapter_indices: List[int],
-    expected: Dict[int, Tuple[int, List[int]]]
+    expected: Dict[int, Tuple[List[int], List[int]]]
 ):
-    from lorax_server.utils.segments import find_segments
     batched_weights = LayerAdapterWeights()
     assert batched_weights.is_empty()
 
@@ -156,10 +191,19 @@ def test_batched_lora_weights_decode(
         data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA)
 
     for lora_rank, rd in data.rank_data.items():
+        expected_lora_a_ptr = []
+        expected_lora_b_ptr = []
+        for adapter_idx in expected[lora_rank][0]:
+            expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a_t.data_ptr())
+            expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b_t.data_ptr())
+        expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device)
+        expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device)
+        assert all(rd.lora_a_ptr == expected_lora_a_ptr)
+        assert all(rd.lora_b_ptr == expected_lora_b_ptr)
+
         expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device)
-        assert rd.lora_a_ptr.shape == (expected[lora_rank][0],)
-        assert rd.lora_b_ptr.shape == (expected[lora_rank][0],)
         assert all(rd.indices == expected_indices)
+
         assert rd.segment_starts is None
         assert rd.segment_ends is None
         assert rd.tmp_shrink is None
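
As a companion to the prefill expectations above, the SGMV path keeps one weight pointer and one start/end pair per segment rather than per unique adapter. The standalone sketch below, not part of the patch, reproduces the expected values of the first prefill parametrization; the `owners` name is illustrative and the segment boundaries are hard-coded to what find_segments would produce for these inputs.

from collections import defaultdict

lora_ranks = {0: 8, 1: 8, 2: 16}                    # adapter index -> rank
adapter_indices = [0, 0, 1, 1, 0, 0, 1, 1, 2, 2]    # adapter used by each token
adapter_segments = [0, 2, 4, 6, 8, 10]              # token boundaries of each segment
segment_indices = [0, 1, 0, 1, 2]                   # adapter owning each segment

# Group segment positions by rank, then slice the global boundaries per rank.
rank_indices = defaultdict(list)
for segment_idx, adapter_idx in enumerate(segment_indices):
    rank_indices[lora_ranks[adapter_idx]].append(segment_idx)

for rank, indices in rank_indices.items():
    segment_starts = [adapter_segments[i] for i in indices]
    segment_ends = [adapter_segments[i + 1] for i in indices]
    owners = [segment_indices[i] for i in indices]
    print(rank, segment_starts, segment_ends, owners)
    # rank 8  -> starts [0, 2, 4, 6], ends [2, 4, 6, 8], adapters [0, 1, 0, 1]
    # rank 16 -> starts [8], ends [10], adapters [2]
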