Skip to content

Commit 54939cc

Browse files
author
Tim Joseph
committed
refactor(tensor_distribution): remove redundant method implementations from TensorFisherSnedecor
Remove duplicate method implementations that are already provided by the base TensorDistribution class, including mean, variance, mode, support, has_rsample, entropy, log_prob, cdf, icdf, sample, rsample, enumerate_support, __repr__, and __eq__. Also update import to use local utils module instead of torch.distributions.utils.
1 parent 393605f commit 54939cc

File tree

1 file changed

+1
-49
lines changed

1 file changed

+1
-49
lines changed

src/tensorcontainer/tensor_distribution/fisher_snedecor.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
from typing import Any
44

5-
import torch
65
from torch import Tensor
76
from torch.distributions import FisherSnedecor as TorchFisherSnedecor
8-
from torch.distributions.utils import broadcast_all
7+
from .utils import broadcast_all
98

109
from .base import TensorDistribution
1110

@@ -32,50 +31,3 @@ def _unflatten_distribution(
3231
df2=attributes["_df2"],
3332
validate_args=attributes.get("_validate_args"),
3433
)
35-
36-
@property
37-
def mean(self) -> Tensor:
38-
return self.dist().mean
39-
40-
@property
41-
def variance(self) -> Tensor:
42-
return self.dist().variance
43-
44-
@property
45-
def mode(self) -> Tensor:
46-
return self.dist().mode
47-
48-
@property
49-
def support(self):
50-
return self.dist().support
51-
52-
@property
53-
def has_rsample(self):
54-
return self.dist().has_rsample
55-
56-
def entropy(self) -> Tensor:
57-
return self.dist().entropy()
58-
59-
def log_prob(self, value: Tensor) -> Tensor:
60-
return self.dist().log_prob(value)
61-
62-
def cdf(self, value: Tensor) -> Tensor:
63-
return self.dist().cdf(value)
64-
65-
def icdf(self, value: Tensor) -> Tensor:
66-
return self.dist().icdf(value)
67-
68-
def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
69-
return self.dist().sample(sample_shape)
70-
71-
def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
72-
return self.dist().rsample(sample_shape)
73-
74-
def enumerate_support(self, expand: bool = True) -> Tensor:
75-
return self.dist().enumerate_support(expand)
76-
77-
def __repr__(self):
78-
return self.dist().__repr__()
79-
80-
def __eq__(self, other):
81-
return self.dist().__eq__(other)

0 commit comments

Comments
 (0)