Skip to content

Commit eb8a885

Browse files
committed
[BugFix,Test,Refactor] Refactor tests
ghstack-source-id: 24c338b Pull-Request: #3232
1 parent 9ca0e40 commit eb8a885

File tree

10 files changed

+357
-23
lines changed

10 files changed

+357
-23
lines changed

test/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,27 @@ def maybe_fork_ParallelEnv(request):
155155
):
156156
return functools.partial(ParallelEnv, mp_start_method="fork")
157157
return ParallelEnv
158+
159+
160+
# LLM testing fixtures
161+
@pytest.fixture
162+
def mock_transformer_model():
163+
"""Fixture that provides a mock transformer model factory."""
164+
from torchrl.testing import MockTransformerModel
165+
166+
def _make_model(
167+
vocab_size: int = 1024, device: torch.device | str | int = "cpu"
168+
) -> MockTransformerModel:
169+
"""Make a mock transformer model."""
170+
device = torch.device(device)
171+
return MockTransformerModel(vocab_size, device)
172+
173+
return _make_model
174+
175+
176+
@pytest.fixture
177+
def mock_tokenizer():
178+
"""Fixture that provides a mock tokenizer."""
179+
from transformers import AutoTokenizer
180+
181+
return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

test/llm/libs/test_mlgym.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from torchrl.envs.llm import make_mlgym
1717
from torchrl.modules.llm import TransformersWrapper
1818

19+
pytest.importorskip("mlgym")
20+
1921

2022
class TestMLGYM:
2123
def test_mlgym_specs(self):
Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""Tests for PythonExecutorService with Ray service registry."""
2+
from __future__ import annotations
3+
4+
import pytest
5+
6+
# Skip all tests if Ray is not available
7+
pytest.importorskip("ray")
8+
9+
import ray
10+
from torchrl.envs.llm.transforms import PythonExecutorService, PythonInterpreter
11+
from torchrl.services import get_services
12+
13+
14+
@pytest.fixture
15+
def ray_init():
16+
"""Initialize Ray for tests."""
17+
if not ray.is_initialized():
18+
ray.init()
19+
yield
20+
if ray.is_initialized():
21+
ray.shutdown()
22+
23+
24+
class TestPythonExecutorService:
25+
"""Test suite for PythonExecutorService."""
26+
27+
def test_service_initialization(self, ray_init):
28+
"""Test that the service can be created and registered."""
29+
namespace = "test_executor_init"
30+
services = get_services(backend="ray", namespace=namespace)
31+
32+
try:
33+
services.register(
34+
"python_executor",
35+
PythonExecutorService,
36+
pool_size=2,
37+
timeout=5.0,
38+
num_cpus=2,
39+
max_concurrency=2,
40+
)
41+
42+
# Verify it was registered
43+
assert "python_executor" in services
44+
45+
# Get the service
46+
executor = services["python_executor"]
47+
assert executor is not None
48+
49+
finally:
50+
services.reset()
51+
52+
def test_service_execution(self, ray_init):
53+
"""Test that the service can execute Python code."""
54+
namespace = "test_executor_exec"
55+
services = get_services(backend="ray", namespace=namespace)
56+
57+
try:
58+
services.register(
59+
"python_executor",
60+
PythonExecutorService,
61+
pool_size=2,
62+
timeout=5.0,
63+
num_cpus=2,
64+
max_concurrency=2,
65+
)
66+
67+
executor = services["python_executor"]
68+
69+
# Execute simple code
70+
code = """
71+
x = 10
72+
y = 20
73+
result = x + y
74+
print(f"Result: {result}")
75+
"""
76+
result = ray.get(executor.execute.remote(code), timeout=2)
77+
78+
assert result["success"] is True
79+
assert "Result: 30" in result["stdout"]
80+
assert result["returncode"] == 0
81+
82+
finally:
83+
services.reset()
84+
85+
def test_service_execution_error(self, ray_init):
86+
"""Test that the service handles execution errors."""
87+
namespace = "test_executor_error"
88+
services = get_services(backend="ray", namespace=namespace)
89+
90+
try:
91+
services.register(
92+
"python_executor",
93+
PythonExecutorService,
94+
pool_size=2,
95+
timeout=5.0,
96+
num_cpus=2,
97+
max_concurrency=2,
98+
)
99+
100+
executor = services["python_executor"]
101+
102+
# Execute code with an error
103+
code = "raise ValueError('Test error')"
104+
result = ray.get(executor.execute.remote(code), timeout=2)
105+
106+
assert result["success"] is False
107+
assert "ValueError: Test error" in result["stderr"]
108+
109+
finally:
110+
services.reset()
111+
112+
def test_multiple_executions(self, ray_init):
113+
"""Test multiple concurrent executions."""
114+
namespace = "test_executor_multi"
115+
services = get_services(backend="ray", namespace=namespace)
116+
117+
try:
118+
services.register(
119+
"python_executor",
120+
PythonExecutorService,
121+
pool_size=4,
122+
timeout=5.0,
123+
num_cpus=4,
124+
max_concurrency=4,
125+
)
126+
127+
executor = services["python_executor"]
128+
129+
# Submit multiple executions
130+
futures = []
131+
for i in range(8):
132+
code = f"print('Execution {i}')"
133+
futures.append(executor.execute.remote(code))
134+
135+
# Wait for all to complete
136+
results = ray.get(futures, timeout=5)
137+
138+
# All should succeed
139+
assert len(results) == 8
140+
for i, result in enumerate(results):
141+
assert result["success"] is True
142+
assert f"Execution {i}" in result["stdout"]
143+
144+
finally:
145+
services.reset()
146+
147+
148+
class TestPythonInterpreterWithService:
149+
"""Test suite for PythonInterpreter using the service."""
150+
151+
def test_interpreter_with_service(self, ray_init):
152+
"""Test that PythonInterpreter can use the service."""
153+
namespace = "test_interp_service"
154+
services = get_services(backend="ray", namespace=namespace)
155+
156+
try:
157+
# Register service
158+
services.register(
159+
"python_executor",
160+
PythonExecutorService,
161+
pool_size=2,
162+
timeout=5.0,
163+
num_cpus=2,
164+
max_concurrency=2,
165+
)
166+
167+
# Create interpreter with service
168+
interpreter = PythonInterpreter(
169+
services="ray",
170+
service_name="python_executor",
171+
namespace=namespace,
172+
)
173+
174+
# Verify it's using the service
175+
assert interpreter.python_service is not None
176+
assert interpreter.processes is None
177+
assert interpreter.services == "ray"
178+
179+
finally:
180+
services.reset()
181+
182+
def test_interpreter_without_service(self):
183+
"""Test that PythonInterpreter works without service."""
184+
# Create interpreter without service
185+
interpreter = PythonInterpreter(
186+
services=None,
187+
persistent=True,
188+
)
189+
190+
# Verify it's using local processes
191+
assert interpreter.python_service is None
192+
assert interpreter.processes is not None
193+
assert interpreter.services is None
194+
195+
def test_interpreter_execution_with_service(self, ray_init):
196+
"""Test code execution through interpreter with service."""
197+
namespace = "test_interp_exec"
198+
services = get_services(backend="ray", namespace=namespace)
199+
200+
try:
201+
# Register service
202+
services.register(
203+
"python_executor",
204+
PythonExecutorService,
205+
pool_size=2,
206+
timeout=5.0,
207+
num_cpus=2,
208+
max_concurrency=2,
209+
)
210+
211+
# Create interpreter with service
212+
interpreter = PythonInterpreter(services="ray", namespace=namespace)
213+
214+
# Execute code
215+
code = "print('Hello from service')"
216+
result = interpreter._execute_python_code(code, 0)
217+
218+
assert result["success"] is True
219+
assert "Hello from service" in result["stdout"]
220+
221+
finally:
222+
services.reset()
223+
224+
def test_interpreter_clone_preserves_service(self, ray_init):
225+
"""Test that cloning an interpreter preserves service settings."""
226+
namespace = "test_interp_clone"
227+
services = get_services(backend="ray", namespace=namespace)
228+
229+
try:
230+
# Register service
231+
services.register(
232+
"python_executor",
233+
PythonExecutorService,
234+
pool_size=2,
235+
timeout=5.0,
236+
num_cpus=2,
237+
max_concurrency=2,
238+
)
239+
240+
# Create interpreter with service
241+
interpreter1 = PythonInterpreter(
242+
services="ray",
243+
service_name="python_executor",
244+
namespace=namespace,
245+
)
246+
247+
# Clone it
248+
interpreter2 = interpreter1.clone()
249+
250+
# Verify clone has same settings
251+
assert interpreter2.services == "ray"
252+
assert interpreter2.service_name == "python_executor"
253+
assert interpreter2.python_service is not None
254+
255+
finally:
256+
services.reset()
257+
258+
def test_interpreter_invalid_service_backend(self):
259+
"""Test that invalid service backend raises error."""
260+
with pytest.raises(ValueError, match="Invalid services backend"):
261+
PythonInterpreter(services="invalid")
262+
263+
def test_interpreter_missing_service(self, ray_init):
264+
"""Test that missing service raises error."""
265+
with pytest.raises(RuntimeError, match="Failed to get Ray service"):
266+
PythonInterpreter(services="ray", service_name="nonexistent_service")
267+
268+
269+
if __name__ == "__main__":
270+
pytest.main([__file__, "-v"])

test/test_services.py renamed to test/services/test_services.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from __future__ import annotations
7+
68
import pytest
79

810
pytest.importorskip("ray")
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
"""Test fixtures for service tests that need to be importable by Ray workers."""
7+
8+
from __future__ import annotations
9+
10+
from typing import Any
11+
12+
13+
class SimpleService:
14+
"""A simple service for testing."""
15+
16+
def __init__(self, value: int = 0):
17+
self.value = value
18+
19+
def get_value(self):
20+
return self.value
21+
22+
def set_value(self, value: int):
23+
self.value = value
24+
25+
def getattr(self, val: str, **kwargs) -> Any:
26+
if "default" in kwargs:
27+
default = kwargs["default"]
28+
return getattr(self, val, default)
29+
return getattr(self, val)
30+
31+
32+
class TokenizerService:
33+
"""Mock tokenizer service."""
34+
35+
def __init__(self, vocab_size: int = 1000):
36+
self.vocab_size = vocab_size
37+
38+
def encode(self, text: str):
39+
return [hash(c) % self.vocab_size for c in text]
40+
41+
def decode(self, tokens: list):
42+
return "".join([str(t) for t in tokens])
43+
44+
def getattr(self, val: str, **kwargs) -> Any:
45+
if "default" in kwargs:
46+
default = kwargs["default"]
47+
return getattr(self, val, default)
48+
return getattr(self, val)

torchrl/services/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
>>> tokenizer = services["tokenizer"]
2121
>>> result = tokenizer.encode.remote(text)
2222
"""
23+
from __future__ import annotations
2324

2425
from torchrl.services.base import ServiceBase
2526
from torchrl.services.ray_service import RayService

torchrl/services/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
from abc import ABC, abstractmethod
78
from typing import Any

torchrl/services/ray_service.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
56

67
from typing import Any
78

0 commit comments

Comments
 (0)