From 225d7370f915f9a4fd8319a05b98d380da8123ef Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 27 Jun 2024 14:22:16 -0700 Subject: [PATCH 1/2] refactor the lora load function for clarity and simplicity This should prevent some nasty illegal memory access errors 1. Consolidate individual list comprehensions into a single for loop 2. Distinct code to create the lora weight pointers tensor 3. Add test cases to test segments, indices, and weight pointers 4. Add comments explaining the code in some tricky places --- server/lorax_server/adapters/lora.py | 127 +++++++++++---------------- server/tests/utils/test_lora.py | 122 +++++++++++++++++-------- 2 files changed, 136 insertions(+), 113 deletions(-) diff --git a/server/lorax_server/adapters/lora.py b/server/lorax_server/adapters/lora.py index 07666bd15..78c8a4c89 100644 --- a/server/lorax_server/adapters/lora.py +++ b/server/lorax_server/adapters/lora.py @@ -252,68 +252,36 @@ def load( device = first_weights.weights_a.device segment_indices = meta.segment_indices - lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights} - lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights} - - segment_ranks = [adapter_weights[idx].lora_a_r for idx in segment_indices if idx in adapter_weights] - if not segment_ranks: + lora_a, lora_b, adapter_index_configs = {}, {}, {} + max_rank, rank_indices = 0, defaultdict(list) + for segment_idx, adapter_idx in enumerate(segment_indices): + if adapter_idx in adapter_weights: + adapter_weight = adapter_weights[adapter_idx] + adapter_index_configs[adapter_idx] = adapter_weight.adapter_config + max_rank = max(max_rank, adapter_weight.lora_a_r) + rank_indices[adapter_weight.lora_a_r].append(segment_idx) + lora_a[adapter_idx] = adapter_weight.weights_a + lora_b[adapter_idx] = adapter_weight.weights_b + + if not max_rank: return None - max_rank = max(segment_ranks) - if prefill or max_rank > BGMV_MAX_RANK: - use_sgmv = True - lora_a_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - lora_b_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - else: - use_sgmv = False - lora_a_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_a_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - lora_b_ptr = torch.tensor( - [ - (adapter_weights[idx].weights_b_t.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr()) - for idx in segment_indices - ], - dtype=torch.int64, - device=device, - ) - - adapter_index_configs = { - idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights - } - - rank_indices = defaultdict(list) - for segment_idx, adapter_idx in enumerate(segment_indices): - if adapter_idx not in adapter_weights: - continue - rank_indices[adapter_weights[adapter_idx].lora_a_r].append(segment_idx) + use_sgmv = prefill or max_rank > BGMV_MAX_RANK if prefill_head_indices is not None: j, prefill_head_segment_starts, prefill_head_segment_ends = 1, [0], [0] + # prefill_head_indices is used to slice the tokens in the batch such + # that we only forward the last token for each request through lm_head + # there can be multiple head_index associated with each adapter segment for head_index in prefill_head_indices: - # j cannot go out of bounds as that would mean there are tokens without corresponding adapters + # j cannot go out of bounds as that would mean there are tokens without segments if head_index < meta.adapter_segments[j]: + # head_index is part of the current adapter + # so increment the current segment end prefill_head_segment_ends[-1] += 1 else: + # head_index in not part of the current adapter + # close the previous segment and start a new one prefill_head_segment_starts.append(prefill_head_segment_ends[-1]) prefill_head_segment_ends.append(prefill_head_segment_ends[-1] + 1) j += 1 @@ -325,40 +293,51 @@ def load( segment_starts = None segment_ends = None batch_indices = None + lora_a_ptr_indices = [] + lora_b_ptr_indices = [] if use_sgmv: - lora_a_ptr_indices = lora_a_ptr[indices] - tmp_shrink, tmp_expand = get_tmp_tensors(lora_a_ptr_indices.size(0), rank, device) + for segment_idx in indices: + adapter_weight = adapter_weights[segment_indices[segment_idx]] + lora_a_ptr_indices.append(adapter_weight.weights_a.data_ptr()) + lora_b_ptr_indices.append(adapter_weight.weights_b.data_ptr()) + tmp_shrink, tmp_expand = get_tmp_tensors(len(lora_a_ptr_indices), rank, device) segment_starts = meta.adapter_segments[indices] segment_ends = meta.adapter_segments[[i + 1 for i in indices]] if prefill_head_indices is not None: - for i, segment_index in enumerate(indices): - segment_starts[i] = prefill_head_segment_starts[segment_index] - segment_ends[i] = prefill_head_segment_ends[segment_index] + # since prefill_head_indices is present the segment starts and ends + # need to be adjusted according to the number of head tokens in each + for i, segment_idx in enumerate(indices): + segment_starts[i] = prefill_head_segment_starts[segment_idx] + segment_ends[i] = prefill_head_segment_ends[segment_idx] else: - # `indices` indexes the `segment_indices` which contains segment wise adapter index - # `lora_a_ptr` contains segment wise pointers to lora weights - # lengths of `lora_a_ptr` and `segment_indices` must be same - # `indices` will be used to slice the `lora_a_ptr` tensor - # first, find the mapping between adapter index and its location in the `indices` array - idx_locs = {} - for loc, idx in enumerate(indices): - # use the idx to find the adapter index - if segment_indices[idx] not in idx_locs: - # save the first location of encountering a particular adapter index - idx_locs[segment_indices[idx]] = loc - # second, iterate over the adapter index for each token and find its location in the `indices` array + adapter_idx_to_pointer_idx = {} + # find out which adapters are present in the segments for this rank + # iterate over each segment index and use it to find adapter index and weights + for segment_idx in indices: + adapter_idx = segment_indices[segment_idx] + adapter_weight = adapter_weights[adapter_idx] + # if the adapter hasn't been seen before, then append its weight pointers + # and save the index to the just added pointers for later + if adapter_idx not in adapter_idx_to_pointer_idx: + lora_a_ptr_indices.append(adapter_weight.weights_a_t.data_ptr()) + lora_b_ptr_indices.append(adapter_weight.weights_b_t.data_ptr()) + adapter_idx_to_pointer_idx[adapter_idx] = len(lora_a_ptr_indices) - 1 + # for each token in the batch, see if its adapter is present in the segments for this rank + # if present, then store the index of its weight pointers otherwise store -1 batch_indices = torch.tensor([ - idx_locs[idx] if idx in adapter_weights and adapter_weights[idx].lora_a_r == rank else -1 - for idx in meta.adapter_indices.tolist() + adapter_idx_to_pointer_idx.get(adapter_idx, -1) for adapter_idx in meta.adapter_indices.tolist() ], dtype=torch.int64, device=device) + lora_a_ptr_indices = torch.tensor(lora_a_ptr_indices, dtype=torch.int64, device=device) + lora_b_ptr_indices = torch.tensor(lora_b_ptr_indices, dtype=torch.int64, device=device) + rank_data[rank] = RankSegments( rank=rank, tmp_shrink=tmp_shrink, tmp_expand=tmp_expand, - lora_a_ptr=lora_a_ptr[indices], - lora_b_ptr=lora_b_ptr[indices], + lora_a_ptr=lora_a_ptr_indices, + lora_b_ptr=lora_b_ptr_indices, segment_starts=segment_starts, segment_ends=segment_ends, indices=batch_indices, diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index de6949711..4c2fad227 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -9,6 +9,7 @@ from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights from lorax_server.utils.sgmv import MIN_RANK_CUSTOM +from lorax_server.utils.segments import find_segments class FakeAdapterWeights(AdapterWeights): @@ -42,16 +43,37 @@ def load( @pytest.mark.parametrize( - "lora_ranks", + "lora_ranks,adapter_indices,expected", [ - [8, 16], - [32, 64], + ( + [8, 8, 16], # ranks of adapters + [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], # adapter indices for each token in the batch + { + 8: ( # rank + [0, 2, 4, 6], # expected segment starts + [2, 4, 6, 8], # expected segment ends + [0, 1, 0, 1], # expected adapter indices + ), + 16: ([8], [10], [2]), + } + ), + ( + [4, 8, 16], + [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], + { + 4: ([0, 4], [2, 6], [0, 0]), + 8: ([2, 6], [4, 8], [1, 1]), + 16: ([8], [10], [2]), + } + ), ], ) -def test_batched_lora_weights(lora_ranks: List[int]): - # batch meta is hardcoded with this assumption below - assert len(lora_ranks) == 2 - +def test_batched_lora_weights( + lora_ranks: List[int], + adapter_indices: List[int], + expected: Dict[int, Tuple[List[int], Tuple[int], Tuple[int]]] +): + num_adapters = len(lora_ranks) batched_weights = LayerAdapterWeights() assert batched_weights.is_empty() @@ -68,59 +90,73 @@ def test_batched_lora_weights(lora_ranks: List[int]): batched_weights.add_adapter(idx, weights) assert not batched_weights.is_empty() - assert len(batched_weights.adapter_weights) == 2 + assert len(batched_weights.adapter_weights) == num_adapters + + segments, segment_indices = find_segments(adapter_indices) meta = AdapterBatchMetadata( - adapter_indices=torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int64), - adapter_set={0, 1}, - adapter_segments=torch.tensor([0, 2, 4, 6, 8], dtype=torch.int64), - segment_indices=[0, 1, 0, 1], + adapter_indices=torch.tensor(adapter_indices, dtype=torch.int64), + adapter_set=set(adapter_indices), + adapter_segments=torch.tensor(segments, dtype=torch.int64), + segment_indices=segment_indices, ) with mock.patch("lorax_server.adapters.lora.get_tmp_tensors", return_value=(torch.empty(0), torch.empty(0))): data = batched_weights.get_data(meta, prefill=True, prefill_head_indices=None).get(LORA) - assert len(data.lora_a) == 2 + assert len(data.lora_a) == num_adapters + assert len(data.lora_b) == num_adapters assert data.lora_a.keys() == meta.adapter_set - assert data.lora_a[0].shape == ((1, h, lora_ranks[0]) if lora_ranks[0] < MIN_RANK_CUSTOM else (1, lora_ranks[0], h)) - assert data.lora_a[1].shape == ((1, h, lora_ranks[1]) if lora_ranks[1] < MIN_RANK_CUSTOM else (1, lora_ranks[1], h)) - - assert len(data.lora_b) == 2 assert data.lora_b.keys() == meta.adapter_set - assert data.lora_b[0].shape == (1, lora_ranks[0], h) - assert data.lora_b[1].shape == (1, lora_ranks[1], h) + for i in range(num_adapters): + assert data.lora_a[i].shape == ( + (1, h, lora_ranks[i]) if lora_ranks[i] < MIN_RANK_CUSTOM else (1, lora_ranks[i], h) + ) + assert data.lora_b[i].shape == (1, lora_ranks[i], h) - assert len(data.rank_data) == 2 - assert data.rank_data.keys() == set(lora_ranks) for lora_rank, rd in data.rank_data.items(): assert rd.rank == lora_rank - - # shape in all cases is the number of segments with this rank - assert rd.lora_a_ptr.shape == (2,) - assert rd.lora_b_ptr.shape == (2,) - assert rd.segment_starts.shape == (2,) - assert rd.segment_ends.shape == (2,) - + expected_lora_a_ptr = [] + expected_lora_b_ptr = [] + for adapter_idx in expected[lora_rank][2]: + expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a.data_ptr()) + expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b.data_ptr()) + expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device) + expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device) + assert all(rd.lora_a_ptr == expected_lora_a_ptr) + assert all(rd.lora_b_ptr == expected_lora_b_ptr) + + expected_segment_starts = torch.tensor( + expected[lora_rank][0], dtype=rd.segment_starts.dtype, device=rd.segment_starts.device + ) + expected_segment_ends = torch.tensor( + expected[lora_rank][1], dtype=rd.segment_ends.dtype, device=rd.segment_ends.device + ) + assert all(rd.segment_ends == expected_segment_ends) + assert all(rd.segment_starts == expected_segment_starts) @pytest.mark.parametrize( "lora_ranks,adapter_indices,expected", [ ( - [8, 8, 16], - [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], + [8, 8, 16], # ranks of adapters + [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], # adapter indices for each token in the batch { - 8: (4, [0, 0, 1, 1, 0, 0, 1, 1, -1, -1]), - 16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]) + 8: ( # rank + [0, 1], # expected adapter indices + [0, 0, 1, 1, 0, 0, 1, 1, -1, -1] # expected indices + ), + 16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]), } ), ( [4, 8, 16], [0, 0, 1, 1, 0, 0, 1, 1, 2, 2], { - 4: (2, [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]), - 8: (2, [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]), - 16: (1, [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]), + 4: ([0], [0, 0, -1, -1, 0, 0, -1, -1, -1, -1]), + 8: ([1], [-1, -1, 0, 0, -1, -1, 0, 0, -1, -1]), + 16: ([2], [-1, -1, -1, -1, -1, -1, -1, -1, 0, 0]), } ), ], @@ -128,9 +164,8 @@ def test_batched_lora_weights(lora_ranks: List[int]): def test_batched_lora_weights_decode( lora_ranks: List[int], adapter_indices: List[int], - expected: Dict[int, Tuple[int, List[int]]] + expected: Dict[int, Tuple[List[int], List[int]]] ): - from lorax_server.utils.segments import find_segments batched_weights = LayerAdapterWeights() assert batched_weights.is_empty() @@ -156,10 +191,19 @@ def test_batched_lora_weights_decode( data = batched_weights.get_data(meta, prefill=False, prefill_head_indices=None).get(LORA) for lora_rank, rd in data.rank_data.items(): + expected_lora_a_ptr = [] + expected_lora_b_ptr = [] + for adapter_idx in expected[lora_rank][0]: + expected_lora_a_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_a_t.data_ptr()) + expected_lora_b_ptr.append(batched_weights.adapter_weights[adapter_idx].weights_b_t.data_ptr()) + expected_lora_a_ptr = torch.tensor(expected_lora_a_ptr, dtype=rd.lora_a_ptr.dtype, device=rd.lora_a_ptr.device) + expected_lora_b_ptr = torch.tensor(expected_lora_b_ptr, dtype=rd.lora_b_ptr.dtype, device=rd.lora_b_ptr.device) + assert all(rd.lora_a_ptr == expected_lora_a_ptr) + assert all(rd.lora_b_ptr == expected_lora_b_ptr) + expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device) - assert rd.lora_a_ptr.shape == (expected[lora_rank][0],) - assert rd.lora_b_ptr.shape == (expected[lora_rank][0],) assert all(rd.indices == expected_indices) + assert rd.segment_starts == None assert rd.segment_ends == None assert rd.tmp_shrink == None From 74307716f94cab03c15dad8e9888dcf437004948 Mon Sep 17 00:00:00 2001 From: Ajinkya Tejankar Date: Thu, 27 Jun 2024 14:29:09 -0700 Subject: [PATCH 2/2] formatting fix to keep ruff happy --- server/tests/utils/test_lora.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/server/tests/utils/test_lora.py b/server/tests/utils/test_lora.py index 4c2fad227..fcdb84800 100644 --- a/server/tests/utils/test_lora.py +++ b/server/tests/utils/test_lora.py @@ -8,8 +8,8 @@ from lorax_server.adapters.lora import LoraWeights from lorax_server.adapters.types import LORA from lorax_server.adapters.weights import AdapterBatchMetadata, AdapterWeights, BatchAdapterWeights, LayerAdapterWeights -from lorax_server.utils.sgmv import MIN_RANK_CUSTOM from lorax_server.utils.segments import find_segments +from lorax_server.utils.sgmv import MIN_RANK_CUSTOM class FakeAdapterWeights(AdapterWeights): @@ -204,10 +204,10 @@ def test_batched_lora_weights_decode( expected_indices = torch.tensor(expected[lora_rank][1], dtype=rd.indices.dtype, device=rd.indices.device) assert all(rd.indices == expected_indices) - assert rd.segment_starts == None - assert rd.segment_ends == None - assert rd.tmp_shrink == None - assert rd.tmp_expand == None + assert rd.segment_starts is None + assert rd.segment_ends is None + assert rd.tmp_shrink is None + assert rd.tmp_expand is None def test_batched_lora_weights_no_segments(): batched_weights = LayerAdapterWeights()