Skip to content

Commit d61df07

Browse files
author
Tim Joseph
committed
fix(tensor_distribution): replace RuntimeError with ValueError for parameter validation
Replace RuntimeError with more semantically appropriate ValueError for parameter validation errors across tensor distribution classes.
1 parent 54939cc commit d61df07

File tree

8 files changed

+10
-10
lines changed

8 files changed

+10
-10
lines changed

src/tensorcontainer/tensor_distribution/categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(
2828
# check here to safely derive shape and device from the data tensor
2929
# before calling the parent constructor
3030
if data is None:
31-
raise RuntimeError("Either 'probs' or 'logits' must be provided.")
31+
raise ValueError("Either 'probs' or 'logits' must be provided.")
3232

3333
# Store the parameters in annotated attributes before calling super().__init__()
3434
# This is required because super().__init__() calls self.dist() which needs these attributes

src/tensorcontainer/tensor_distribution/multinomial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
):
2626
data = probs if probs is not None else logits
2727
if data is None:
28-
raise RuntimeError("Either 'probs' or 'logits' must be provided.")
28+
raise ValueError("Either 'probs' or 'logits' must be provided.")
2929

3030
self._total_count = total_count
3131
self._probs = probs
@@ -83,4 +83,4 @@ def param_shape(self) -> Size:
8383
elif self._logits is not None:
8484
return self._logits.shape
8585
else:
86-
raise RuntimeError("Neither probs nor logits are available.")
86+
raise ValueError("Neither probs nor logits are available.")

src/tensorcontainer/tensor_distribution/one_hot_categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(
2626
# check here to safely derive shape and device from the data tensor
2727
# before calling the parent constructor
2828
if data is None:
29-
raise RuntimeError("Either 'probs' or 'logits' must be provided.")
29+
raise ValueError("Either 'probs' or 'logits' must be provided.")
3030

3131
# Store the parameters in annotated attributes before calling super().__init__()
3232
# This is required because super().__init__() calls self.dist() which needs these attributes

src/tensorcontainer/tensor_distribution/relaxed_one_hot_categorical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@ def __init__(
2727
That is why we only allowed scalar temperatures for now.
2828
"""
2929
if temperature.ndim > 0:
30-
raise RuntimeError(
30+
raise ValueError(
3131
"Expected scalar temperature tensor. This is because of a bug in torch: https://github.com/pytorch/pytorch/issues/37162"
3232
)
3333

3434
data = probs if probs is not None else logits
3535
if data is None:
36-
raise RuntimeError("Either 'probs' or 'logits' must be provided.")
36+
raise ValueError("Either 'probs' or 'logits' must be provided.")
3737

3838
# Determine shape and device from data (probs or logits)
3939
shape = data.shape[:-1]

tests/tensor_distribution/test_categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class TestTensorCategoricalInitialization:
1515
def test_init_no_params_raises_error(self):
1616
"""A ValueError should be raised when neither probs nor logits are provided."""
1717
with pytest.raises(
18-
RuntimeError, match="Either 'probs' or 'logits' must be provided."
18+
ValueError, match="Either 'probs' or 'logits' must be provided."
1919
):
2020
TensorCategorical()
2121

tests/tensor_distribution/test_one_hot_categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TestTensorOneHotCategoricalInitialization:
1717
def test_init_no_params_raises_error(self):
1818
"""A ValueError should be raised when neither probs nor logits are provided."""
1919
with pytest.raises(
20-
RuntimeError, match="Either 'probs' or 'logits' must be provided."
20+
ValueError, match="Either 'probs' or 'logits' must be provided."
2121
):
2222
TensorOneHotCategorical()
2323

tests/tensor_distribution/test_one_hot_categorical_straight_through.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TestTensorOneHotCategoricalStraightThroughInitialization:
1717
def test_init_no_params_raises_error(self):
1818
"""A ValueError should be raised when neither probs nor logits are provided."""
1919
with pytest.raises(
20-
RuntimeError, match="Either 'probs' or 'logits' must be provided."
20+
ValueError, match="Either 'probs' or 'logits' must be provided."
2121
):
2222
TensorOneHotCategoricalStraightThrough()
2323

tests/tensor_distribution/test_relaxed_one_hot_categorical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_valid_initialization(
4545
def test_init_no_params_raises_error(self):
4646
"""A RuntimeError should be raised when neither probs nor logits are provided."""
4747
with pytest.raises(
48-
RuntimeError, match="Either 'probs' or 'logits' must be provided."
48+
ValueError, match="Either 'probs' or 'logits' must be provided."
4949
):
5050
TensorRelaxedOneHotCategorical(temperature=torch.tensor(0.5))
5151

0 commit comments

Comments
 (0)