Skip to content

Commit 7f209cc

Browse files
committed
Add independent metropolis hasting
Signed-off-by: Guillaume Baudart <guillaume.baudart@inria.fr>
1 parent e7dc4b8 commit 7f209cc

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

src/pdl/pdl_infer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
infer_rejection_parallel,
2121
infer_smc,
2222
infer_smc_parallel,
23+
infer_independent_mh,
2324
)
2425
from .pdl_utils import validate_scope
2526

@@ -28,7 +29,13 @@ class PpdlConfig(TypedDict, total=False):
2829
"""Configuration parameters of the PDL interpreter."""
2930

3031
algo: Literal[
31-
"is", "parallel-is", "smc", "parallel-smc", "rejection", "parallel-rejection"
32+
"is",
33+
"parallel-is",
34+
"smc",
35+
"parallel-smc",
36+
"rejection",
37+
"parallel-rejection",
38+
"imh",
3239
]
3340
num_particles: int
3441
max_workers: int
@@ -58,7 +65,7 @@ def exec_program(
5865
config["event_loop"] = _LOOP
5966

6067
match algo:
61-
case "is" | "rejection" | "parallel-rejection":
68+
case "is" | "rejection" | "parallel-rejection" | "imh":
6269
config["with_resample"] = False
6370
case "smc" | "parallel-smc" | "parallel-is":
6471
config["with_resample"] = True
@@ -87,6 +94,8 @@ def model(replay, score):
8794
dist = infer_rejection(num_particles, model)
8895
case "parallel-rejection":
8996
dist = infer_rejection_parallel(num_particles, model, max_workers=4)
97+
case "imh":
98+
dist = infer_independent_mh(num_particles, model)
9099
case _:
91100
assert False, f"Unexpected algo: {algo}"
92101
return dist
@@ -155,6 +164,7 @@ def main():
155164
"parallel-smc",
156165
"rejection",
157166
"parallel-rejection",
167+
"imh",
158168
],
159169
help="Choose inference algorithm.",
160170
default="smc",

src/pdl/pdl_smc.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,29 @@ def gen():
201201
return Categorical(results)
202202

203203

204+
def infer_independent_mh(
205+
num_samples: int,
206+
model: Callable[[ModelStateT, float], tuple[T, ModelStateT, float]],
207+
) -> Categorical[T]:
208+
samples = []
209+
210+
new_value, _, new_score = model({}, 0.0)
211+
212+
for _ in range(num_samples):
213+
old_score = new_score # store state
214+
old_value = new_value # store current value
215+
new_value, _, new_score = model({}, 0.0) # generate a candidate
216+
alpha = math.exp(min(0, new_score - old_score))
217+
u = random.random() # nosec B311
218+
# [B311:blacklist] Standard pseudo-random generators are not suitable for security/cryptographic purposes.
219+
if not (u < alpha):
220+
new_score = old_score # rollback
221+
new_value = old_value
222+
samples.append((new_value, 0.0))
223+
224+
return Categorical(samples)
225+
226+
204227
# async def _process_particle_async(state, model, num_particles):
205228
# with ImportanceSampling(num_particles) as sampler:
206229
# try:

0 commit comments

Comments
 (0)