Skip to content

Commit da66d2f

Browse files
committed
New bench function
1 parent 1329890 commit da66d2f

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed

tests/link/numba/test_compile.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import numpy as np
2+
3+
import pytensor.tensor as pt
4+
from pytensor import function
5+
from pytensor.graph import rewrite_graph
6+
from pytensor.graph.traversal import explicit_graph_inputs
7+
8+
9+
def test_radon_model_logp_dlogp():
10+
def halfnormal(name, *, sigma=1.0, model_logp):
11+
log_value = pt.scalar(f"{name}_log")
12+
value = pt.exp(log_value)
13+
14+
logp = (
15+
-0.5 * ((value / sigma) ** 2) + pt.log(pt.sqrt(2.0 / np.pi)) - pt.log(sigma)
16+
)
17+
logp = pt.switch(value >= 0, logp, -np.inf)
18+
model_logp.append(logp + value)
19+
return value
20+
21+
def normal(name, *, mu=0.0, sigma=1.0, model_logp, observed=None):
22+
value = pt.scalar(name) if observed is None else pt.as_tensor(observed)
23+
24+
logp = (
25+
-0.5 * (((value - mu) / sigma) ** 2)
26+
- pt.log(pt.sqrt(2.0 * np.pi))
27+
- pt.log(sigma)
28+
)
29+
model_logp.append(logp)
30+
return value
31+
32+
def zerosumnormal(name, *, sigma=1.0, size, model_logp):
33+
raw_value = pt.vector(f"{name}_zerosum", shape=(size - 1,))
34+
n = raw_value.shape[0] + 1
35+
sum_vals = raw_value.sum(0, keepdims=True)
36+
norm = sum_vals / (pt.sqrt(n) + n)
37+
fill_value = norm - sum_vals / pt.sqrt(n)
38+
value = pt.concatenate([raw_value, fill_value]) - norm
39+
40+
shape = value.shape
41+
_full_size = pt.prod(shape)
42+
_degrees_of_freedom = pt.prod(shape[-1:].inc(-1))
43+
logp = pt.sum(
44+
-0.5 * ((value / sigma) ** 2)
45+
- (pt.log(pt.sqrt(2.0 * np.pi)) + pt.log(sigma))
46+
* (_degrees_of_freedom / _full_size)
47+
)
48+
model_logp.append(logp)
49+
return value
50+
51+
rng = np.random.default_rng(1)
52+
n_counties = 85
53+
county_idx = rng.integers(n_counties, size=919)
54+
county_idx.sort()
55+
floor = rng.binomial(n=1, p=0.5, size=919).astype(np.float64)
56+
log_radon = rng.normal(size=919)
57+
58+
# joined_inputs = pt.vector("joined_inputs")
59+
60+
model_logp = []
61+
intercept = normal("intercept", sigma=10, model_logp=model_logp)
62+
63+
# County effects
64+
county_raw = zerosumnormal("county_raw", size=n_counties, model_logp=model_logp)
65+
county_sd = halfnormal("county_sd", model_logp=model_logp)
66+
county_effect = county_raw * county_sd
67+
68+
# Global floor effect
69+
floor_effect = normal("floor_effect", sigma=2, model_logp=model_logp)
70+
71+
county_floor_raw = zerosumnormal(
72+
"county_floor_raw", size=n_counties, model_logp=model_logp
73+
)
74+
county_floor_sd = halfnormal("county_floor_sd", model_logp=model_logp)
75+
county_floor_effect = county_floor_raw * county_floor_sd
76+
77+
mu = (
78+
intercept
79+
+ county_effect[county_idx]
80+
+ floor_effect * floor
81+
+ county_floor_effect[county_idx] * floor
82+
)
83+
84+
sigma = halfnormal("sigma", model_logp=model_logp)
85+
_ = normal(
86+
"log_radon",
87+
mu=mu,
88+
sigma=sigma,
89+
observed=log_radon,
90+
model_logp=model_logp,
91+
)
92+
93+
model_logp = pt.sum([logp.sum() for logp in model_logp])
94+
model_logp = rewrite_graph(
95+
model_logp, include=("canonicalize", "stabilize"), clone=False
96+
)
97+
params = list(explicit_graph_inputs(model_logp))
98+
model_dlogp = pt.concatenate([term.ravel() for term in pt.grad(model_logp, params)])
99+
100+
# TODO: Replace inputs by raveled vector
101+
102+
function(params, [model_logp, model_dlogp]).dprint()

0 commit comments

Comments
 (0)