File tree Expand file tree Collapse file tree 8 files changed +10
-10
lines changed
src/tensorcontainer/tensor_distribution
tests/tensor_distribution Expand file tree Collapse file tree 8 files changed +10
-10
lines changed Original file line number Diff line number Diff line change @@ -28,7 +28,7 @@ def __init__(
2828 # check here to safely derive shape and device from the data tensor
2929 # before calling the parent constructor
3030 if data is None :
31- raise RuntimeError ("Either 'probs' or 'logits' must be provided." )
31+ raise ValueError ("Either 'probs' or 'logits' must be provided." )
3232
3333 # Store the parameters in annotated attributes before calling super().__init__()
3434 # This is required because super().__init__() calls self.dist() which needs these attributes
Original file line number Diff line number Diff line change @@ -25,7 +25,7 @@ def __init__(
2525 ):
2626 data = probs if probs is not None else logits
2727 if data is None :
28- raise RuntimeError ("Either 'probs' or 'logits' must be provided." )
28+ raise ValueError ("Either 'probs' or 'logits' must be provided." )
2929
3030 self ._total_count = total_count
3131 self ._probs = probs
@@ -83,4 +83,4 @@ def param_shape(self) -> Size:
8383 elif self ._logits is not None :
8484 return self ._logits .shape
8585 else :
86- raise RuntimeError ("Neither probs nor logits are available." )
86+ raise ValueError ("Neither probs nor logits are available." )
Original file line number Diff line number Diff line change @@ -26,7 +26,7 @@ def __init__(
2626 # check here to safely derive shape and device from the data tensor
2727 # before calling the parent constructor
2828 if data is None :
29- raise RuntimeError ("Either 'probs' or 'logits' must be provided." )
29+ raise ValueError ("Either 'probs' or 'logits' must be provided." )
3030
3131 # Store the parameters in annotated attributes before calling super().__init__()
3232 # This is required because super().__init__() calls self.dist() which needs these attributes
Original file line number Diff line number Diff line change @@ -27,13 +27,13 @@ def __init__(
2727 That is why we only allowed scalar temperatures for now.
2828 """
2929 if temperature .ndim > 0 :
30- raise RuntimeError (
30+ raise ValueError (
3131 "Expected scalar temperature tensor. This is because of a bug in torch: https://github.com/pytorch/pytorch/issues/37162"
3232 )
3333
3434 data = probs if probs is not None else logits
3535 if data is None :
36- raise RuntimeError ("Either 'probs' or 'logits' must be provided." )
36+ raise ValueError ("Either 'probs' or 'logits' must be provided." )
3737
3838 # Determine shape and device from data (probs or logits)
3939 shape = data .shape [:- 1 ]
Original file line number Diff line number Diff line change @@ -15,7 +15,7 @@ class TestTensorCategoricalInitialization:
1515 def test_init_no_params_raises_error (self ):
1616 """A ValueError should be raised when neither probs nor logits are provided."""
1717 with pytest .raises (
18- RuntimeError , match = "Either 'probs' or 'logits' must be provided."
18+ ValueError , match = "Either 'probs' or 'logits' must be provided."
1919 ):
2020 TensorCategorical ()
2121
Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ class TestTensorOneHotCategoricalInitialization:
1717 def test_init_no_params_raises_error (self ):
1818 """A ValueError should be raised when neither probs nor logits are provided."""
1919 with pytest .raises (
20- RuntimeError , match = "Either 'probs' or 'logits' must be provided."
20+ ValueError , match = "Either 'probs' or 'logits' must be provided."
2121 ):
2222 TensorOneHotCategorical ()
2323
Original file line number Diff line number Diff line change @@ -17,7 +17,7 @@ class TestTensorOneHotCategoricalStraightThroughInitialization:
1717 def test_init_no_params_raises_error (self ):
1818 """A ValueError should be raised when neither probs nor logits are provided."""
1919 with pytest .raises (
20- RuntimeError , match = "Either 'probs' or 'logits' must be provided."
20+ ValueError , match = "Either 'probs' or 'logits' must be provided."
2121 ):
2222 TensorOneHotCategoricalStraightThrough ()
2323
Original file line number Diff line number Diff line change @@ -45,7 +45,7 @@ def test_valid_initialization(
4545 def test_init_no_params_raises_error (self ):
4646 """A RuntimeError should be raised when neither probs nor logits are provided."""
4747 with pytest .raises (
48- RuntimeError , match = "Either 'probs' or 'logits' must be provided."
48+ ValueError , match = "Either 'probs' or 'logits' must be provided."
4949 ):
5050 TensorRelaxedOneHotCategorical (temperature = torch .tensor (0.5 ))
5151
You can’t perform that action at this time.
0 commit comments