Commits (showing changes from 25 of 67 commits)
51afb9f
changed laplace approx to return MvNormal
Michal-Novomestsky Jul 2, 2025
c326525
added seperate line for evaluating Q-hess
Michal-Novomestsky Jul 2, 2025
61d4d89
WIP: minor refactor
Michal-Novomestsky Jul 4, 2025
1960cb9
started writing fit_INLA routine
Michal-Novomestsky Jul 6, 2025
6a1d523
changed minimizer tol to 1e-8
Michal-Novomestsky Jul 6, 2025
674d813
WIP: MarginalLaplaceRV
Michal-Novomestsky Jul 16, 2025
3b5d49c
WIP: Minimize inside logp
Michal-Novomestsky Jul 19, 2025
22d2ef1
tidied up MarginalLaplaceRV
Michal-Novomestsky Aug 9, 2025
c49de10
refactor: variable name change
Michal-Novomestsky Aug 9, 2025
54e394d
jesse minimize testing
Michal-Novomestsky Aug 10, 2025
f02e652
end-to-end implementation
Michal-Novomestsky Aug 11, 2025
9fb860e
refactor: changed boolean logic
Michal-Novomestsky Aug 11, 2025
68b87ee
refactor: changed distributions in test case
Michal-Novomestsky Aug 12, 2025
de2d1fc
removed jesse's debug notebook
Michal-Novomestsky Aug 12, 2025
787a39e
added WIP warning to pmx.fit
Michal-Novomestsky Aug 12, 2025
71f8642
refactor: added TODO
Michal-Novomestsky Aug 12, 2025
18747d5
refactor: re-ran notebook
Michal-Novomestsky Aug 12, 2025
c6010f3
refactor: temporarily changed gitignore
Michal-Novomestsky Aug 12, 2025
a473e87
refactor: rolled gitignore back to default
Michal-Novomestsky Aug 12, 2025
31072ef
refactor: reworded list comprehension in log_likelihood
Michal-Novomestsky Aug 12, 2025
6630675
refactor: uncommented import
Michal-Novomestsky Aug 12, 2025
b275e10
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 13, 2025
a92ba8f
removed legacy code
Michal-Novomestsky Aug 13, 2025
f077250
refactor: restored missing assert
Michal-Novomestsky Aug 13, 2025
57a7935
refactor: changed test_inla.py location
Michal-Novomestsky Aug 13, 2025
8b94a99
refactor: moved _precision_mv_normal_logp into pmx
Michal-Novomestsky Aug 16, 2025
bfa4e12
Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginal…
Michal-Novomestsky Aug 16, 2025
dd54a37
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 16, 2025
34dfdfa
set d automatically
Michal-Novomestsky Aug 17, 2025
8cb19d9
refactor: removed inccorect laplace.py and moved inla into seperate f…
Michal-Novomestsky Aug 17, 2025
0ee1ec9
bugfix: laplace import/file location
Michal-Novomestsky Aug 17, 2025
b82c6a4
refactor: folder name change
Michal-Novomestsky Aug 17, 2025
c5f2bd8
bugfix: removed erroneous test case
Michal-Novomestsky Aug 17, 2025
9d7342d
bugfix: typo in INLA
Michal-Novomestsky Aug 17, 2025
12b109f
refactor: added more __init__s
Michal-Novomestsky Aug 17, 2025
92f6a0f
removed temp_kwargs, made Q amenable to RVs, removed dependency on Mv…
Michal-Novomestsky Aug 26, 2025
296ca39
removed checking for MvNormal
Michal-Novomestsky Aug 26, 2025
dccd9a6
error message reworded
Michal-Novomestsky Aug 26, 2025
d0aaae5
added comments explaining logp bottleneck
Michal-Novomestsky Aug 26, 2025
af61cf7
removed None default for minimizer_kwargs
Michal-Novomestsky Aug 26, 2025
0779b6e
added docstring for _precision_mv_normal_logp
Michal-Novomestsky Aug 26, 2025
d7b198a
added more documentation
Michal-Novomestsky Aug 26, 2025
2198465
added example 1 to example notebook
Michal-Novomestsky Aug 27, 2025
0c4fcd5
refactor: default return_latent_posteriors to false
Michal-Novomestsky Aug 27, 2025
d031008
Merge branch 'pymc-devs:main' into implement-pmx.fit-option-for-INLA-…
Michal-Novomestsky Aug 27, 2025
e7ccfe2
refactor: moved sample step inside if-block
Michal-Novomestsky Aug 27, 2025
ece57b1
added docstring
Michal-Novomestsky Aug 27, 2025
a675b37
added latex to docstring
Michal-Novomestsky Aug 27, 2025
3636e98
refactored unittest
Michal-Novomestsky Aug 27, 2025
47b8dae
refactor: moved laplace approx into seperate function + more docstrings
Michal-Novomestsky Aug 27, 2025
fb39764
refactor: TensorLike typehint
Michal-Novomestsky Aug 27, 2025
065c6b2
refactor: labelling of p(x|y,params)
Michal-Novomestsky Aug 27, 2025
59b623d
refactor: text in example notebook
Michal-Novomestsky Aug 27, 2025
6cea8ba
removed old INLA notebook
Michal-Novomestsky Aug 27, 2025
609156e
refactor: local import
Michal-Novomestsky Aug 27, 2025
bc3f1c3
latex-friendly formatting
Michal-Novomestsky Aug 28, 2025
e032c25
getting Q as RV
Michal-Novomestsky Aug 28, 2025
e367958
updated inla docstring
Michal-Novomestsky Aug 28, 2025
8ed64fd
added warning (INLA experimental)
Michal-Novomestsky Aug 28, 2025
7ca496b
added AR1 testcase
Michal-Novomestsky Aug 28, 2025
154cc2c
added normals to notebook
Michal-Novomestsky Aug 28, 2025
fda71d6
refactor: changed test case atol to 0.2
Michal-Novomestsky Aug 28, 2025
04db7c3
refactor: add warning to d calculation
Michal-Novomestsky Aug 28, 2025
934d740
refactor: warning message
Michal-Novomestsky Aug 28, 2025
b3a3351
set vectorized jac flag to true
Michal-Novomestsky Sep 21, 2025
19bc44d
Merge branch 'main' into implement-pmx.fit-option-for-INLA-+-marginal…
Michal-Novomestsky Sep 21, 2025
176ca6b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 21, 2025
782 changes: 782 additions & 0 deletions notebooks/INLA_testing.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pymc_extras/inference/__init__.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.inla import fit_INLA
 from pymc_extras.inference.laplace_approx.find_map import find_MAP
 from pymc_extras.inference.laplace_approx.laplace import fit_laplace
 from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
-__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
+__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_INLA"]
14 changes: 12 additions & 2 deletions pymc_extras/inference/fit.py
@@ -36,7 +36,17 @@ def fit(method: str, **kwargs) -> az.InferenceData:

         return fit_pathfinder(**kwargs)
 
-    if method == "laplace":
-        from pymc_extras.inference import fit_laplace
+    elif method == "laplace":
+        from pymc_extras.inference.laplace import fit_laplace
 
         return fit_laplace(**kwargs)
 
+    elif method == "INLA":
+        from pymc_extras.inference.inla import fit_INLA
+
+        return fit_INLA(**kwargs)
+
+    else:
+        raise ValueError(
+            f"method '{method}' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'."
+        )
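The dispatch added to `fit` above can be exercised in isolation. Below is a minimal standalone sketch of the same string-dispatch pattern, with stub return values standing in for the real `fit_pathfinder`, `fit_laplace`, and `fit_INLA` backends (the stubs are hypothetical; the real function forwards `**kwargs` to the selected backend and returns an `az.InferenceData`):

```python
def fit(method: str, **kwargs):
    """String-dispatch mirroring pymc_extras.inference.fit (stub backends)."""
    if method == "pathfinder":
        return ("pathfinder", kwargs)  # stands in for fit_pathfinder(**kwargs)
    elif method == "laplace":
        return ("laplace", kwargs)  # stands in for fit_laplace(**kwargs)
    elif method == "INLA":
        return ("INLA", kwargs)  # stands in for fit_INLA(**kwargs)
    else:
        # New in this PR: unknown methods fail loudly instead of returning None
        raise ValueError(
            f"method '{method}' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'."
        )
```

With the `else` branch the PR introduces, a misspelled method name raises immediately rather than silently falling through.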
54 changes: 54 additions & 0 deletions pymc_extras/inference/inla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import warnings

import arviz as az
import pymc as pm

from pymc.distributions.multivariate import MvNormal
from pytensor.tensor import TensorVariable
from pytensor.tensor.linalg import inv as matrix_inverse

from pymc_extras.model.marginal.marginal_model import marginalize


def fit_INLA(
x: TensorVariable,
temp_kwargs=None, # TODO REMOVE. DEBUGGING TOOL
model: pm.Model | None = None,
minimizer_kwargs: dict | None = None,
return_latent_posteriors: bool = True,
**sampler_kwargs,
) -> az.InferenceData:
warnings.warn("Currently only valid for a nested normal model. WIP.", UserWarning)
# TODO ADD CHECK FOR NESTED NORMAL

model = pm.modelcontext(model)

# Check if latent field is Gaussian
if not isinstance(x.owner.op, MvNormal):
raise ValueError(
f"Latent field {x} is not instance of MvNormal. Has distribution {x.owner.op}."
)

_, _, _, tau = x.owner.inputs

# Latent field should use precision rather than covariance
if not (tau.owner and tau.owner.op == matrix_inverse):
raise ValueError(
f"Latent field {x} is not in precision matrix form. Use MvNormal(tau=Q) instead."
)

Q = tau.owner.inputs[0]

# Marginalize out the latent field
minimizer_kwargs = {"method": "L-BFGS-B", "optimizer_kwargs": {"tol": 1e-8}}
marginalize_kwargs = {"Q": Q, "temp_kwargs": temp_kwargs, "minimizer_kwargs": minimizer_kwargs}
marginal_model = marginalize(model, x, use_laplace=True, **marginalize_kwargs)

# Sample over the hyperparameters
idata = pm.sample(model=marginal_model, **sampler_kwargs)

if not return_latent_posteriors:
return idata

# TODO Unmarginalize stuff
raise NotImplementedError("Latent posteriors not supported yet, WIP.")
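Under the hood, `fit_INLA` hands the latent field to `marginalize(..., use_laplace=True)`, which integrates it out via a Laplace approximation: the conditional density is replaced by a Gaussian centred at its mode, with precision equal to the negative second derivative (in general, the negative Hessian) of the log-density at that mode. A minimal one-dimensional sketch of that idea in plain Python follows — illustrative only, not the pymc-extras code path; the step size `h` and iteration count are arbitrary choices for the sketch:

```python
def laplace_approx(logp, x0, n_iter=50, h=1e-4):
    """Gaussian approximation of exp(logp): locate the mode with Newton's
    method, then use -logp'' at the mode as the Gaussian's precision.
    Derivatives are central finite differences with step h."""
    x = x0
    for _ in range(n_iter):
        g = (logp(x + h) - logp(x - h)) / (2 * h)  # first derivative
        H = (logp(x + h) - 2 * logp(x) + logp(x - h)) / h**2  # second derivative
        x = x - g / H  # Newton step towards the mode
    return x, -H  # mode (mean) and precision of the approximating Gaussian

# For an exactly Gaussian log-density the approximation recovers it exactly:
mu, tau = 2.0, 4.0
mode, prec = laplace_approx(lambda x: -0.5 * tau * (x - mu) ** 2, x0=0.0)
```

For a quadratic log-density a single Newton step lands on the mode, so `mode` recovers `mu` and `prec` recovers `tau`; for non-Gaussian conditionals the result is the approximation INLA integrates against.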