diff --git a/src/pdl/pdl_infer.py b/src/pdl/pdl_infer.py index c9f61c5df..8c6479770 100644 --- a/src/pdl/pdl_infer.py +++ b/src/pdl/pdl_infer.py @@ -20,6 +20,7 @@ infer_rejection_parallel, infer_smc, infer_smc_parallel, + infer_independent_mh, ) from .pdl_utils import validate_scope @@ -28,7 +29,13 @@ class PpdlConfig(TypedDict, total=False): """Configuration parameters of the PDL interpreter.""" algo: Literal[ - "is", "parallel-is", "smc", "parallel-smc", "rejection", "parallel-rejection" + "is", + "parallel-is", + "smc", + "parallel-smc", + "rejection", + "parallel-rejection", + "imh", ] num_particles: int max_workers: int @@ -58,7 +65,7 @@ def exec_program( config["event_loop"] = _LOOP match algo: - case "is" | "rejection" | "parallel-rejection": + case "is" | "rejection" | "parallel-rejection" | "imh": config["with_resample"] = False case "smc" | "parallel-smc" | "parallel-is": config["with_resample"] = True @@ -87,6 +94,8 @@ def model(replay, score): dist = infer_rejection(num_particles, model) case "parallel-rejection": dist = infer_rejection_parallel(num_particles, model, max_workers=4) + case "imh": + dist = infer_independent_mh(num_particles, model) case _: assert False, f"Unexpected algo: {algo}" return dist @@ -155,6 +164,7 @@ def main(): "parallel-smc", "rejection", "parallel-rejection", + "imh", ], help="Choose inference algorithm.", default="smc", diff --git a/src/pdl/pdl_smc.py b/src/pdl/pdl_smc.py index dba35903f..9a33569aa 100644 --- a/src/pdl/pdl_smc.py +++ b/src/pdl/pdl_smc.py @@ -201,6 +201,29 @@ def gen(): return Categorical(results) +def infer_independent_mh( + num_samples: int, + model: Callable[[ModelStateT, float], tuple[T, ModelStateT, float]], +) -> Categorical[T]: + samples = [] + + new_value, _, new_score = model({}, 0.0) + + for _ in range(num_samples): + old_score = new_score # store state + old_value = new_value # store current value + new_value, _, new_score = model({}, 0.0) # generate a candidate + alpha = math.exp(min(0, new_score - old_score)) + u = random.random() # nosec B311 + # [B311:blacklist] Standard pseudo-random generators are not suitable for security/cryptographic purposes. + if not (u < alpha): + new_score = old_score # rollback + new_value = old_value + samples.append((new_value, 0.0)) + + return Categorical(samples) + + # async def _process_particle_async(state, model, num_particles): # with ImportanceSampling(num_particles) as sampler: # try: