Commit 1daa98f

Update
[ghstack-poisoned]
2 parents 15ff85b + ae87a4e commit 1daa98f

Some content is hidden: large commits have part of the diff collapsed by default, so only a subset of the 47 changed files is shown below.

47 files changed: +1026, -213 lines

.github/workflows/docs.yml

Lines changed: 2 additions & 2 deletions
@@ -26,7 +26,7 @@ jobs:
   build-docs:
     strategy:
       matrix:
-        python_version: [ "3.9" ]
+        python_version: [ "3.12" ]
         cuda_arch_version: [ "12.8" ]
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
@@ -60,7 +60,7 @@ jobs:
       bash ./miniconda.sh -b -f -p "${conda_dir}"
       eval "$(${conda_dir}/bin/conda shell.bash hook)"
       printf "* Creating a test environment\n"
-      conda create --prefix "${env_dir}" -y python=3.9
+      conda create --prefix "${env_dir}" -y python=3.12
       printf "* Activating\n"
       conda activate "${env_dir}"


docs/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ sphinx_design
 torchvision
 dm_control
 mujoco<3.3.6
-gym[classic_control,accept-rom-license,ale-py,atari]
+gymnasium[classic_control,atari]
 pygame
 tqdm
 ipython
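
The requirements change swaps the legacy gym package for gymnasium and drops the accept-rom-license/ale-py extras. As a rough illustration of the API that the classic_control extra provides (the environment id and seed below are illustrative, not part of this commit):

# Illustration only: the gymnasium API pulled in by docs/requirements.txt.
import gymnasium as gym

env = gym.make("CartPole-v1")                  # classic_control environment
obs, info = env.reset(seed=0)                  # gymnasium reset returns (obs, info)
action = env.action_space.sample()
obs, reward, terminated, truncated, info = env.step(action)  # 5-tuple step API
env.close()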

docs/source/reference/config.rst

Lines changed: 1 addition & 1 deletion
@@ -507,7 +507,7 @@ Training and Optimization Configurations
     SparseAdamConfig

 Logging Configurations
-~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~

 .. currentmodule:: torchrl.trainers.algorithms.configs.logging


docs/source/reference/envs.rst

Lines changed: 0 additions & 1 deletion
@@ -1123,7 +1123,6 @@ to be able to create this other composition:
     ExcludeTransform
     FiniteTensorDictCheck
     FlattenObservation
-    FlattenTensorDict
     FrameSkipTransform
     GrayScale
     Hash

docs/source/reference/llms.rst

Lines changed: 3 additions & 3 deletions
@@ -118,9 +118,9 @@ Usage
 Adding Custom Templates
 ^^^^^^^^^^^^^^^^^^^^^^^

-You can add custom chat templates for new model families using the :func:`torchrl.data.llm.chat.add_chat_template` function.
+You can add custom chat templates for new model families using the :func:`torchrl.data.llm.add_chat_template` function.

-.. autofunction:: torchrl.data.llm.chat.add_chat_template
+.. autofunction:: torchrl.data.llm.add_chat_template

 Usage Examples
 ^^^^^^^^^^^^^^
@@ -130,7 +130,7 @@ Adding a Llama Template

 .. code-block:: python

-    >>> from torchrl.data.llm.chat import add_chat_template, History
+    >>> from torchrl.data.llm import add_chat_template, History
     >>> from transformers import AutoTokenizer
     >>>
     >>> # Define the Llama chat template
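
The documentation update reflects the public import path moving from torchrl.data.llm.chat to torchrl.data.llm. A minimal sketch of the new-style import is below; the tokenizer model id is illustrative, and no call to add_chat_template is made because its signature is not shown in this diff:

# Old (pre-commit) path:
#   from torchrl.data.llm.chat import add_chat_template, History
# New path documented by this commit:
from torchrl.data.llm import add_chat_template, History
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")  # illustrative model id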

docs/source/reference/utils.rst

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 .. currentmodule:: torchrl

 torchrl._utils package
-====================
+======================

 Set of utility methods that are used internally by the library.


pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -149,4 +149,4 @@ first_party_detection = false
 [project.entry-points."vllm.general_plugins"]
 # Ensure FP32 overrides are registered in all vLLM processes (main, workers, and
 # the registry subprocess) before resolving model classes.
-fp32_overrides = "torchrl.modules.llm.backends.vllm_plugin:register_fp32_overrides"
+fp32_overrides = "torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides"
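
The fp32_overrides entry point now resolves to torchrl.modules.llm.backends.vllm.vllm_plugin. A small sketch of how a "vllm.general_plugins" entry point is discovered and loaded through the standard entry-point machinery (illustration only, assumes Python 3.10+ for the group= keyword; per the pyproject.toml comment, vLLM performs this step before resolving model classes):

from importlib.metadata import entry_points

# Enumerate plugins registered under the group declared in pyproject.toml.
for ep in entry_points(group="vllm.general_plugins"):
    if ep.name == "fp32_overrides":
        # ep.value == "torchrl.modules.llm.backends.vllm.vllm_plugin:register_fp32_overrides"
        register = ep.load()  # imports the module and returns the callable
        register()            # registers the FP32 overrides in this process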

test/_utils_internal.py

Lines changed: 2 additions & 8 deletions
@@ -21,13 +21,7 @@
 from tensordict.nn import TensorDictModuleBase
 from torch import nn, vmap

-from torchrl._utils import (
-    implement_for,
-    logger,
-    logger as torchrl_logger,
-    RL_WARNINGS,
-    seed_generator,
-)
+from torchrl._utils import implement_for, logger, RL_WARNINGS, seed_generator
 from torchrl.data.utils import CloudpickleWrapper
 from torchrl.envs import MultiThreadedEnv, ObservationNorm
 from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
@@ -230,7 +224,7 @@ def f_retry(*args, **kwargs):
                 return f(*args, **kwargs)
             except ExceptionToCheck as e:
                 msg = "%s, Retrying in %d seconds..." % (str(e), mdelay)
-                torchrl_logger.info(msg)
+                logger.info(msg)
                 time.sleep(mdelay)
                 mtries -= 1
                 try:
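
The second hunk sits inside a retry helper, where the torchrl_logger alias is replaced by the plain logger import. A minimal, self-contained sketch of that retry-decorator pattern, assuming names (retry, tries, delay) that are not shown in the diff:

import functools
import time

from torchrl._utils import logger


def retry(ExceptionToCheck, tries=3, delay=3):
    """Sketch of the retry pattern visible in the hunk above (simplified)."""

    def deco_retry(f):
        @functools.wraps(f)
        def f_retry(*args, **kwargs):
            mtries, mdelay = tries, delay
            while mtries > 1:
                try:
                    return f(*args, **kwargs)
                except ExceptionToCheck as e:
                    msg = "%s, Retrying in %d seconds..." % (str(e), mdelay)
                    logger.info(msg)  # plain `logger`; the torchrl_logger alias is gone
                    time.sleep(mdelay)
                    mtries -= 1
            return f(*args, **kwargs)  # final attempt, exceptions propagate

        return f_retry

    return deco_retry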

test/llm/test_objectives.py

Lines changed: 69 additions & 2 deletions
@@ -16,7 +16,13 @@
 from torchrl.envs.llm.transforms.kl import RetrieveLogProb
 from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
 from torchrl.modules.llm.policies.common import ChatHistory, Masks, Text, Tokens
-from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage
+from torchrl.objectives.llm.grpo import (
+    CISPO,
+    CISPOLossOutput,
+    GRPOLoss,
+    GRPOLossOutput,
+    MCAdvantage,
+)
 from torchrl.objectives.llm.sft import SFTLoss

 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -203,7 +209,6 @@ def test_grpo(self, mock_transformer_model, dapo):
         loss_vals = loss_fn(data)

         # Assertions: Check output type and structure
-        from torchrl.objectives.llm.grpo import GRPOLossOutput

         assert isinstance(
             loss_vals, GRPOLossOutput
@@ -240,6 +245,68 @@ def test_grpo(self, mock_transformer_model, dapo):
             0 <= loss_vals.clip_fraction <= 1
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"

+    def test_cispo(self, mock_transformer_model):
+        """Test CISPO loss computation with mock models."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+        eps = 0.20
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+
+        loss_fn = CISPO(actor_network, clip_epsilon=eps)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Compute loss
+        loss_vals = loss_fn(data)
+
+        # Assertions: Check output type and structure
+
+        assert isinstance(
+            loss_vals, CISPOLossOutput
+        ), f"Expected CISPOLossOutput, got {type(loss_vals)}"
+
+        # Check that all expected keys are present (same as GRPO)
+        assert hasattr(loss_vals, "loss_objective"), "Missing loss_objective"
+        assert hasattr(loss_vals, "clip_fraction"), "Missing clip_fraction"
+        assert hasattr(loss_vals, "kl_approx"), "Missing kl_approx"
+        assert hasattr(loss_vals, "ESS"), "Missing ESS"
+        assert hasattr(loss_vals, "entropy"), "Missing entropy"
+        assert hasattr(loss_vals, "loss_entropy"), "Missing loss_entropy"
+
+        # Check tensor shapes (all losses should be scalars after reduction)
+        assert (
+            loss_vals.loss_objective.shape == ()
+        ), f"loss_objective should be scalar, got {loss_vals.loss_objective.shape}"
+        assert (
+            loss_vals.clip_fraction.shape == ()
+        ), f"clip_fraction should be scalar, got {loss_vals.clip_fraction.shape}"
+        assert (
+            loss_vals.kl_approx.shape == ()
+        ), f"kl_approx should be scalar, got {loss_vals.kl_approx.shape}"
+        assert (
+            loss_vals.ESS.shape == ()
+        ), f"ESS should be scalar, got {loss_vals.ESS.shape}"
+
+        # Check that losses are finite
+        assert torch.isfinite(loss_vals.loss_objective), "loss_objective is not finite"
+        assert torch.isfinite(loss_vals.ESS), "ESS is not finite"
+
+        # Check that clip_fraction is in valid range [0, 1]
+        assert (
+            0 <= loss_vals.clip_fraction <= 1
+        ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
+

 class TestSFT:
     @pytest.fixture(scope="class")

test/llm/test_vllm.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@ class TestAsyncVLLMIntegration:
     @pytest.mark.slow
     def test_vllm_api_compatibility(self, sampling_params):
         """Test that AsyncVLLM supports the same inputs as vLLM.LLM.generate()."""
-        from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+        from torchrl.modules.llm.backends import AsyncVLLM

         # Create AsyncVLLM service
         service = AsyncVLLM.from_pretrained(
@@ -113,7 +113,7 @@ def test_vllm_api_compatibility(self, sampling_params):
     def test_weight_updates_with_transformer(self, sampling_params):
         """Test weight updates using vLLMUpdater with a real transformer model."""
        from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater
-        from torchrl.modules.llm.backends.vllm_async import AsyncVLLM
+        from torchrl.modules.llm.backends import AsyncVLLM
         from torchrl.modules.llm.policies.transformers_wrapper import (
             TransformersWrapper,
         )
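
Both tests now import AsyncVLLM from the package-level torchrl.modules.llm.backends rather than the vllm_async submodule. A tiny sketch of the relocated import; the model id passed to from_pretrained is an assumption, not taken from this diff:

from torchrl.modules.llm.backends import AsyncVLLM  # previously ...backends.vllm_async

# Hypothetical usage; from_pretrained appears in the test context above.
service = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")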
