Skip to content

Commit f6347a8

Browse files
committed
Fix OOM.
Signed-off-by: Zheyu Fu <zheyuf@NVIDIA.com>
1 parent b5af379 commit f6347a8

File tree

3 files changed

+5
-0
lines changed

3 files changed

+5
-0
lines changed

tests/unittest/_torch/speculative/test_draft_len_schedule.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def enforce_single_worker():
3333
],
3434
)
3535
@pytest.mark.high_cuda_memory
36+
@pytest.mark.xdist_group("speculative_high_mem")
3637
def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
3738
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
3839
memory_required = 30 if drafter_type == "model_drafter" else 20
@@ -146,6 +147,7 @@ def test_correctness_across_batch_sizes(drafter_type: str, schedule: dict):
146147
],
147148
)
148149
@pytest.mark.high_cuda_memory
150+
@pytest.mark.xdist_group("speculative_high_mem")
149151
def test_draft_len_schedule_functionality(drafter_type: str, draft_schedule: dict):
150152
if not torch.cuda.is_available():
151153
pytest.skip("CUDA not available")

tests/unittest/_torch/speculative/test_dynamic_spec_decode.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def enforce_single_worker(monkeypatch):
2424

2525
@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
2626
@pytest.mark.high_cuda_memory
27+
@pytest.mark.xdist_group("speculative_high_mem")
2728
def test_dynamic_spec_decode(enforce_single_worker,
2829
disable_overlap_scheduler: bool):
2930
# mock_should_use_spec_decode doesn't work with multiple processes,
@@ -124,6 +125,7 @@ def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
124125
# Later: len(requests): 1, max_batch_size: 3, token_cap: 1638 -> num_effective_requests: 1, self.max_concurrency: 2 -> spec decode ON
125126
@pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
126127
@pytest.mark.high_cuda_memory
128+
@pytest.mark.xdist_group("speculative_high_mem")
127129
def test_dynamic_spec_decode_without_force_single_process(
128130
disable_overlap_scheduler: bool):
129131
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9

tests/unittest/_torch/speculative/test_spec_gate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
# This test set the max_concurrency to a large value to prevent spec decode turned off due to number of effective requests > max_concurrency,
2222
# So that we can only focus on the turning off effect from the SpeculationGate.
2323
@pytest.mark.high_cuda_memory
24+
@pytest.mark.xdist_group("speculative_high_mem")
2425
def test_spec_gate_e2e():
2526
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
2627
if total_mem_gb < 35:

0 commit comments

Comments (0)