We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 01af96b commit 9ca88f1Copy full SHA for 9ca88f1
numpyro/distributions/continuous.py
@@ -2706,8 +2706,11 @@ def sample(
2706
2707
@validate_sample
2708
def log_prob(self, value: ArrayLike) -> ArrayLike:
2709
+ log_p = -jnp.log(self.high - self.low)
2710
+ is_in_support = (value >= self.low) & (value < self.high)
2711
shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
- return -jnp.broadcast_to(jnp.log(self.high - self.low), shape)
2712
+ log_p = jnp.broadcast_to(log_p, shape)
2713
+ return jnp.where(is_in_support, log_p, -jnp.inf)
2714
2715
def cdf(self, value: ArrayLike) -> ArrayLike:
2716
cdf = (value - self.low) / (self.high - self.low)
0 commit comments