Skip to content

Commit 5607984

Browse files
committed
Update
[ghstack-poisoned]
2 parents a044032 + dd927ae commit 5607984

File tree

7 files changed

+349
-23
lines changed

7 files changed

+349
-23
lines changed

test/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,27 @@ def maybe_fork_ParallelEnv(request):
155155
):
156156
return functools.partial(ParallelEnv, mp_start_method="fork")
157157
return ParallelEnv
158+
159+
160+
# LLM testing fixtures
161+
@pytest.fixture
162+
def mock_transformer_model():
163+
"""Fixture that provides a mock transformer model factory."""
164+
from torchrl.testing import MockTransformerModel
165+
166+
def _make_model(
167+
vocab_size: int = 1024, device: torch.device | str | int = "cpu"
168+
) -> MockTransformerModel:
169+
"""Make a mock transformer model."""
170+
device = torch.device(device)
171+
return MockTransformerModel(vocab_size, device)
172+
173+
return _make_model
174+
175+
176+
@pytest.fixture
177+
def mock_tokenizer():
178+
"""Fixture that provides a mock tokenizer."""
179+
from transformers import AutoTokenizer
180+
181+
return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

test/llm/libs/test_mlgym.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from torchrl.envs.llm import make_mlgym
1717
from torchrl.modules.llm import TransformersWrapper
1818

19+
pytest.importorskip("mlgym")
20+
1921

2022
class TestMLGYM:
2123
def test_mlgym_specs(self):
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
"""Tests for PythonExecutorService with Ray service registry."""
2+
3+
import pytest
4+
5+
# Skip all tests if Ray is not available
6+
pytest.importorskip("ray")
7+
8+
import ray
9+
from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter
10+
from torchrl.services import get_services
11+
12+
13+
@pytest.fixture
14+
def ray_init():
15+
"""Initialize Ray for tests."""
16+
if not ray.is_initialized():
17+
ray.init()
18+
yield
19+
if ray.is_initialized():
20+
ray.shutdown()
21+
22+
23+
class TestPythonExecutorService:
24+
"""Test suite for PythonExecutorService."""
25+
26+
def test_service_initialization(self, ray_init):
27+
"""Test that the service can be created and registered."""
28+
namespace = "test_executor_init"
29+
services = get_services(backend="ray", namespace=namespace)
30+
31+
try:
32+
services.register(
33+
"python_executor",
34+
PythonExecutorService,
35+
pool_size=2,
36+
timeout=5.0,
37+
num_cpus=2,
38+
max_concurrency=2,
39+
)
40+
41+
# Verify it was registered
42+
assert "python_executor" in services
43+
44+
# Get the service
45+
executor = services["python_executor"]
46+
assert executor is not None
47+
48+
finally:
49+
services.reset()
50+
51+
def test_service_execution(self, ray_init):
52+
"""Test that the service can execute Python code."""
53+
namespace = "test_executor_exec"
54+
services = get_services(backend="ray", namespace=namespace)
55+
56+
try:
57+
services.register(
58+
"python_executor",
59+
PythonExecutorService,
60+
pool_size=2,
61+
timeout=5.0,
62+
num_cpus=2,
63+
max_concurrency=2,
64+
)
65+
66+
executor = services["python_executor"]
67+
68+
# Execute simple code
69+
code = """
70+
x = 10
71+
y = 20
72+
result = x + y
73+
print(f"Result: {result}")
74+
"""
75+
result = ray.get(executor.execute.remote(code), timeout=2)
76+
77+
assert result["success"] is True
78+
assert "Result: 30" in result["stdout"]
79+
assert result["returncode"] == 0
80+
81+
finally:
82+
services.reset()
83+
84+
def test_service_execution_error(self, ray_init):
85+
"""Test that the service handles execution errors."""
86+
namespace = "test_executor_error"
87+
services = get_services(backend="ray", namespace=namespace)
88+
89+
try:
90+
services.register(
91+
"python_executor",
92+
PythonExecutorService,
93+
pool_size=2,
94+
timeout=5.0,
95+
num_cpus=2,
96+
max_concurrency=2,
97+
)
98+
99+
executor = services["python_executor"]
100+
101+
# Execute code with an error
102+
code = "raise ValueError('Test error')"
103+
result = ray.get(executor.execute.remote(code), timeout=2)
104+
105+
assert result["success"] is False
106+
assert "ValueError: Test error" in result["stderr"]
107+
108+
finally:
109+
services.reset()
110+
111+
def test_multiple_executions(self, ray_init):
112+
"""Test multiple concurrent executions."""
113+
namespace = "test_executor_multi"
114+
services = get_services(backend="ray", namespace=namespace)
115+
116+
try:
117+
services.register(
118+
"python_executor",
119+
PythonExecutorService,
120+
pool_size=4,
121+
timeout=5.0,
122+
num_cpus=4,
123+
max_concurrency=4,
124+
)
125+
126+
executor = services["python_executor"]
127+
128+
# Submit multiple executions
129+
futures = []
130+
for i in range(8):
131+
code = f"print('Execution {i}')"
132+
futures.append(executor.execute.remote(code))
133+
134+
# Wait for all to complete
135+
results = ray.get(futures, timeout=5)
136+
137+
# All should succeed
138+
assert len(results) == 8
139+
for i, result in enumerate(results):
140+
assert result["success"] is True
141+
assert f"Execution {i}" in result["stdout"]
142+
143+
finally:
144+
services.reset()
145+
146+
147+
class TestPythonInterpreterWithService:
148+
"""Test suite for PythonInterpreter using the service."""
149+
150+
def test_interpreter_with_service(self, ray_init):
151+
"""Test that PythonInterpreter can use the service."""
152+
namespace = "test_interp_service"
153+
services = get_services(backend="ray", namespace=namespace)
154+
155+
try:
156+
# Register service
157+
services.register(
158+
"python_executor",
159+
PythonExecutorService,
160+
pool_size=2,
161+
timeout=5.0,
162+
num_cpus=2,
163+
max_concurrency=2,
164+
)
165+
166+
# Create interpreter with service
167+
interpreter = PythonInterpreter(
168+
services="ray",
169+
service_name="python_executor",
170+
namespace=namespace,
171+
)
172+
173+
# Verify it's using the service
174+
assert interpreter.python_service is not None
175+
assert interpreter.processes is None
176+
assert interpreter.services == "ray"
177+
178+
finally:
179+
services.reset()
180+
181+
def test_interpreter_without_service(self):
182+
"""Test that PythonInterpreter works without service."""
183+
# Create interpreter without service
184+
interpreter = PythonInterpreter(
185+
services=None,
186+
persistent=True,
187+
)
188+
189+
# Verify it's using local processes
190+
assert interpreter.python_service is None
191+
assert interpreter.processes is not None
192+
assert interpreter.services is None
193+
194+
def test_interpreter_execution_with_service(self, ray_init):
195+
"""Test code execution through interpreter with service."""
196+
namespace = "test_interp_exec"
197+
services = get_services(backend="ray", namespace=namespace)
198+
199+
try:
200+
# Register service
201+
services.register(
202+
"python_executor",
203+
PythonExecutorService,
204+
pool_size=2,
205+
timeout=5.0,
206+
num_cpus=2,
207+
max_concurrency=2,
208+
)
209+
210+
# Create interpreter with service
211+
interpreter = PythonInterpreter(services="ray", namespace=namespace)
212+
213+
# Execute code
214+
code = "print('Hello from service')"
215+
result = interpreter._execute_python_code(code, 0)
216+
217+
assert result["success"] is True
218+
assert "Hello from service" in result["stdout"]
219+
220+
finally:
221+
services.reset()
222+
223+
def test_interpreter_clone_preserves_service(self, ray_init):
224+
"""Test that cloning an interpreter preserves service settings."""
225+
namespace = "test_interp_clone"
226+
services = get_services(backend="ray", namespace=namespace)
227+
228+
try:
229+
# Register service
230+
services.register(
231+
"python_executor",
232+
PythonExecutorService,
233+
pool_size=2,
234+
timeout=5.0,
235+
num_cpus=2,
236+
max_concurrency=2,
237+
)
238+
239+
# Create interpreter with service
240+
interpreter1 = PythonInterpreter(
241+
services="ray",
242+
service_name="python_executor",
243+
namespace=namespace,
244+
)
245+
246+
# Clone it
247+
interpreter2 = interpreter1.clone()
248+
249+
# Verify clone has same settings
250+
assert interpreter2.services == "ray"
251+
assert interpreter2.service_name == "python_executor"
252+
assert interpreter2.python_service is not None
253+
254+
finally:
255+
services.reset()
256+
257+
def test_interpreter_invalid_service_backend(self):
258+
"""Test that invalid service backend raises error."""
259+
with pytest.raises(ValueError, match="Invalid services backend"):
260+
PythonInterpreter(services="invalid")
261+
262+
def test_interpreter_missing_service(self, ray_init):
263+
"""Test that missing service raises error."""
264+
with pytest.raises(RuntimeError, match="Failed to get Ray service"):
265+
PythonInterpreter(services="ray", service_name="nonexistent_service")
266+
267+
268+
if __name__ == "__main__":
269+
pytest.main([__file__, "-v"])
File renamed without changes.
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Test fixtures for service tests that need to be importable by Ray workers."""
7+
8+
from typing import Any
9+
10+
11+
class SimpleService:
12+
"""A simple service for testing."""
13+
14+
def __init__(self, value: int = 0):
15+
self.value = value
16+
17+
def get_value(self):
18+
return self.value
19+
20+
def set_value(self, value: int):
21+
self.value = value
22+
23+
def getattr(self, val: str, **kwargs) -> Any:
24+
if "default" in kwargs:
25+
default = kwargs["default"]
26+
return getattr(self, val, default)
27+
return getattr(self, val)
28+
29+
30+
class TokenizerService:
31+
"""Mock tokenizer service."""
32+
33+
def __init__(self, vocab_size: int = 1000):
34+
self.vocab_size = vocab_size
35+
36+
def encode(self, text: str):
37+
return [hash(c) % self.vocab_size for c in text]
38+
39+
def decode(self, tokens: list):
40+
return "".join([str(t) for t in tokens])
41+
42+
def getattr(self, val: str, **kwargs) -> Any:
43+
if "default" in kwargs:
44+
default = kwargs["default"]
45+
return getattr(self, val, default)
46+
return getattr(self, val)

torchrl/testing/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99
particularly for distributed and Ray-based tests that require importable classes.
1010
"""
1111

12+
from torchrl.testing.llm_mocks import (
13+
MockTransformerConfig,
14+
MockTransformerModel,
15+
MockTransformerOutput,
16+
)
1217
from torchrl.testing.ray_helpers import (
1318
WorkerTransformerDoubleBuffer,
1419
WorkerTransformerNCCL,
@@ -21,4 +26,7 @@
2126
"WorkerTransformerNCCL",
2227
"WorkerVLLMDoubleBuffer",
2328
"WorkerTransformerDoubleBuffer",
29+
"MockTransformerConfig",
30+
"MockTransformerModel",
31+
"MockTransformerOutput",
2432
]

0 commit comments

Comments
 (0)