Skip to content

Commit 9ca88f1

Browse files
committed
fix support
1 parent 01af96b commit 9ca88f1

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

numpyro/distributions/continuous.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2706,8 +2706,11 @@ def sample(
27062706

27072707
@validate_sample
27082708
def log_prob(self, value: ArrayLike) -> ArrayLike:
2709+
log_p = -jnp.log(self.high - self.low)
2710+
is_in_support = (value >= self.low) & (value < self.high)
27092711
shape = lax.broadcast_shapes(jnp.shape(value), self.batch_shape)
2710-
return -jnp.broadcast_to(jnp.log(self.high - self.low), shape)
2712+
log_p = jnp.broadcast_to(log_p, shape)
2713+
return jnp.where(is_in_support, log_p, -jnp.inf)
27112714

27122715
def cdf(self, value: ArrayLike) -> ArrayLike:
27132716
cdf = (value - self.low) / (self.high - self.low)

0 commit comments

Comments
 (0)