Skip to content

Commit 55fd89e

Browse files
authored
Merge pull request #17 from mctigger/tensor-symlog
feat(tensor_distribution): add TensorSymLog distribution
2 parents 01babde + 05c9d4b commit 55fd89e

File tree

4 files changed

+378
-1
lines changed

4 files changed

+378
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "tensorcontainer"
7-
version = "0.7.0"
7+
version = "0.7.1"
88
description = "TensorDict-like functionality for PyTorch with PyTree compatibility and torch.compile support"
99
authors = [{name="Tim Joseph", email="tim@mctigger.com"}]
1010
license = {text = "MIT"}

src/tensorcontainer/tensor_distribution/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from .relaxed_one_hot_categorical import TensorRelaxedOneHotCategorical
3636
from .soft_bernoulli import TensorSoftBernoulli
3737
from .student_t import TensorStudentT
38+
from .symlog import TensorSymLog
3839
from .tanh_normal import TensorTanhNormal
3940
from .transformed_distribution import TransformedDistribution
4041
from .truncated_normal import TensorTruncatedNormal
@@ -81,6 +82,7 @@
8182
"TensorRelaxedOneHotCategorical",
8283
"TensorSoftBernoulli",
8384
"TensorStudentT",
85+
"TensorSymLog",
8486
"TensorTanhNormal",
8587
"TransformedDistribution",
8688
"TensorTruncatedNormal",
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from __future__ import annotations
2+
3+
from typing import Any
4+
5+
from torch import Tensor
6+
from torch.distributions import Distribution
7+
from torch.types import Number
8+
9+
from .base import TensorDistribution
10+
from .utils import broadcast_all
11+
12+
13+
class TensorSymLog(TensorDistribution):
14+
"""Tensor-aware SymLog distribution.
15+
16+
Creates a SymLog distribution parameterized by `loc` (mean) and `scale` (standard deviation).
17+
This distribution transforms a Normal distribution with a symexp transform, which is useful
18+
for modeling data with a wide dynamic range where the data can be both positive and negative.
19+
20+
Args:
21+
loc: Mean of the base Normal distribution.
22+
scale: Standard deviation of the base Normal distribution. Must be positive.
23+
validate_args: Whether to validate the arguments. Defaults to None.
24+
25+
Note:
26+
The SymLog distribution is useful for modeling data with a wide dynamic range,
27+
where the data can be both positive and negative, and can have values close to zero.
28+
The symlog transform compresses large values and expands small values, making the
29+
distribution more stable for optimization.
30+
"""
31+
32+
# Annotated tensor parameters
33+
_loc: Tensor
34+
_scale: Tensor
35+
36+
def __init__(
37+
self,
38+
loc: Number | Tensor,
39+
scale: Number | Tensor,
40+
validate_args: bool | None = None,
41+
) -> None:
42+
self._loc, self._scale = broadcast_all(loc, scale)
43+
44+
shape = self._loc.shape
45+
device = self._loc.device
46+
47+
super().__init__(shape, device, validate_args)
48+
49+
@classmethod
50+
def _unflatten_distribution(
51+
cls,
52+
attributes: dict[str, Any],
53+
) -> TensorSymLog:
54+
return cls(
55+
loc=attributes["_loc"],
56+
scale=attributes["_scale"],
57+
validate_args=attributes.get("_validate_args"),
58+
)
59+
60+
def dist(self) -> Distribution:
61+
"""Return the underlying SymLogDistribution instance."""
62+
from tensorcontainer.distributions.symlog import SymLogDistribution
63+
64+
return SymLogDistribution(
65+
loc=self._loc,
66+
scale=self._scale,
67+
validate_args=self._validate_args,
68+
)
69+
70+
@property
71+
def loc(self) -> Tensor:
72+
"""Returns the location parameter of the distribution."""
73+
return self._loc
74+
75+
@property
76+
def scale(self) -> Tensor:
77+
"""Returns the scale parameter of the distribution."""
78+
return self._scale

0 commit comments

Comments
 (0)