Skip to content

Commit 60e7179

Browse files
author
Vincent Moens
committed
[BE] No warning if user sets the log_prob_key explicitly and only one variable is sampled from the ProbTDMod
ghstack-source-id: ccc966d Pull Request resolved: #1209
1 parent bae04ce commit 60e7179

File tree

3 files changed: +80 −5 lines

tensordict/nn/distributions/composite.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ def __init__(
130130
dist_params = params.get(name)
131131
kwargs = extra_kwargs.get(name, {})
132132
if dist_params is None:
133-
raise KeyError(f"no param {name} found in params with keys {params.keys(True, True)}")
133+
raise KeyError(
134+
f"no param {name} found in params with keys {params.keys(True, True)}"
135+
)
134136
dist = dist_class(**dist_params, **kwargs)
135137
dists[write_name] = dist
136138
self.dists = dists

tensordict/nn/probabilistic.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,9 @@ class ProbabilisticTensorDictModule(TensorDictModuleBase):
328328
329329
"""
330330

331+
# To be removed in v0.9
332+
_trigger_warning_lpk: bool = False
333+
331334
def __init__(
332335
self,
333336
in_keys: NestedKey | List[NestedKey] | Dict[str, NestedKey],
@@ -396,9 +399,11 @@ def __init__(
396399
"composite_lp_aggregate is set to True but log_prob_keys were passed. "
397400
"When composite_lp_aggregate() returns ``True``, log_prob_key must be used instead."
398401
)
402+
self._trigger_warning_lpk = len(self._out_keys) > 1
399403
if log_prob_key is None:
400404
if composite_lp_aggregate(nowarn=True):
401405
log_prob_key = "sample_log_prob"
406+
self._trigger_warning_lpk = True
402407
elif len(out_keys) == 1:
403408
log_prob_key = _add_suffix(out_keys[0], "_log_prob")
404409
elif len(out_keys) > 1 and not composite_lp_aggregate(nowarn=True):
@@ -451,13 +456,15 @@ def log_prob_key(self):
451456
f"unless there is one and only one element in log_prob_keys (got log_prob_keys={self.log_prob_keys}). "
452457
f"When composite_lp_aggregate() returns ``False``, try to use {type(self).__name__}.log_prob_keys instead."
453458
)
454-
if _composite_lp_aggregate.get_mode() is None:
459+
if _composite_lp_aggregate.get_mode() is None and self._trigger_warning_lpk:
455460
warnings.warn(
456461
f"You are querying the log-probability key of a {type(self).__name__} where the "
457-
f"composite_lp_aggregate has not been set. "
462+
f"composite_lp_aggregate has not been set and the log-prob key has not been chosen. "
458463
f"Currently, it is assumed that composite_lp_aggregate() will return True: the log-probs will be aggregated "
459-
f"in a {self._log_prob_key} entry. From v0.9, this behaviour will be changed and individual log-probs will "
460-
f"be written in `('path', 'to', 'leaf', '<sample_name>_log_prob')`. To prepare for this change, "
464+
f"in a {self._log_prob_key} entry. "
465+
f"From v0.9, this behaviour will be changed and individual log-probs will "
466+
f"be written in `('path', 'to', 'leaf', '<sample_name>_log_prob')`. "
467+
f"To prepare for this change, "
461468
f"call `set_composite_lp_aggregate(mode: bool).set()` at the beginning of your script (or set the "
462469
f"COMPOSITE_LP_AGGREGATE env variable). Use mode=True "
463470
f"to keep the current behaviour, and mode=False to use per-leaf log-probs.",

test/test_nn.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2270,6 +2270,72 @@ def test_index_prob_seq(self):
22702270
assert isinstance(seq[:2], ProbabilisticTensorDictSequential)
22712271
assert isinstance(seq[-2:], ProbabilisticTensorDictSequential)
22722272

2273+
def test_no_warning_single_key(self):
    """Log-prob-key deprecation warning: silent when the key is set explicitly, raised otherwise.

    Three scenarios under ``set_composite_lp_aggregate(None)`` (i.e. aggregation
    mode deliberately left unset):
      1. one out-key AND an explicit ``log_prob_key`` -> no warning expected;
      2. one out-key but NO explicit ``log_prob_key`` -> DeprecationWarning on use;
      3. a composite distribution with several sampled leaves -> both the
         log-prob-key warning and the aggregation warning fire.
    """
    torch.manual_seed(0)
    with set_composite_lp_aggregate(None):
        # Scenario 1: key given explicitly, single sampled variable -- must stay quiet.
        module = ProbabilisticTensorDictModule(
            in_keys=["loc", "scale"],
            distribution_class=torch.distributions.Normal,
            out_keys=[("an", "action")],
            log_prob_key="sample_log_prob",
            return_log_prob=True,
        )
        data = TensorDict(loc=torch.randn(()), scale=torch.rand(()))
        module(data.copy())
        module.log_prob(module(data.copy()))
        module.log_prob_key

        # Scenario 2: same module but without an explicit key -- querying it warns.
        module = ProbabilisticTensorDictModule(
            in_keys=["loc", "scale"],
            distribution_class=torch.distributions.Normal,
            out_keys=[("an", "action")],
            return_log_prob=True,
        )
        with pytest.warns(
            DeprecationWarning, match="You are querying the log-probability key"
        ):
            module(data.copy())
            module.log_prob(module(data.copy()))
            module.log_prob_key

        # Scenario 3: composite distribution with two leaves -- both warnings fire.
        module = ProbabilisticTensorDictModule(
            in_keys=["params"],
            distribution_class=CompositeDistribution,
            distribution_kwargs={
                "distribution_map": {
                    "dirich": torch.distributions.Dirichlet,
                    "categ": torch.distributions.Categorical,
                }
            },
            out_keys=[("dirich", "categ")],
            return_log_prob=True,
        )
        with pytest.warns(
            DeprecationWarning, match="You are querying the log-probability key"
        ), pytest.warns(
            DeprecationWarning,
            match="Composite log-prob aggregation wasn't defined explicitly",
        ):
            data = TensorDict(
                params=TensorDict(
                    dirich=TensorDict(concentration=torch.rand((10, 11))),
                    categ=TensorDict(logits=torch.rand((5,))),
                )
            )
            module(data.copy())
            module.log_prob(module(data.copy()))
            module.log_prob_key
2338+
22732339

22742340
class TestEnsembleModule:
22752341
def test_init(self):

Comments (0)