1+ from typing import Tuple
2+ import os
3+
4+ import numpy as np
5+ from scipy import optimize
6+ import torch
7+ from torch .autograd .functional import jacobian , hessian
8+
9+ from examples .plotting import pairplot
10+
11+ import deep_tensor as dt
12+
13+
def read_credit_data(fname: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Read the German credit dataset from a whitespace-delimited file.

    The predictor columns are standardised (zero mean, unit standard
    deviation), and the response column — coded 1/2 in the raw file —
    is shifted so that it takes values in {0, 1}.
    """

    rows = []
    with open(fname, "r") as f:
        for line in f:
            rows.append([float(tok) for tok in line.strip().split()])

    table = torch.tensor(rows)
    xs, ys = table[:, :-1], table[:, -1]

    # Standardise each predictor column.
    xs = (xs - torch.mean(xs, dim=0)) / torch.std(xs, dim=0)
    # Map the raw 1/2 labels onto {0, 1}.
    ys = ys - 1.0

    return xs, ys
35+
# Load the standardised predictors and {0, 1} responses.
fname = os.path.join("examples", "credit", "german.data-numeric")
xs, ys = read_credit_data(fname)

# Number of regression coefficients: one intercept plus one per predictor.
n_beta = 1 + xs.shape[1]

# Isotropic Gaussian prior over the coefficients, N(0, sd_pri^2 I).
# NOTE(review): mean_pri and cov_pri are never referenced below —
# neglogpri uses sd_pri directly; confirm they are intentional.
mean_pri = torch.zeros((n_beta,))
sd_pri = 10.0
cov_pri = sd_pri ** 2 * torch.eye(n_beta)
44+
def negloglik(bs: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of logistic-regression coefficients.

    Accepts a single coefficient vector or a batch of them (one per
    row); returns one value per row. A constant of 500 is subtracted
    for numerical stability — this does not affect density ratios.
    """

    bs = torch.atleast_2d(bs)

    # Logit for every (coefficient sample, observation) pair:
    # intercept plus the dot product of the remaining coefficients
    # with each observation's predictors.
    logits = bs[:, :1] + torch.sum(bs[:, 1:, None] * xs.T[None, ...], dim=1)
    ps = 1.0 / (1.0 + torch.exp(-logits))

    # Sum the per-observation terms separately over the two classes.
    nll_zeros = -torch.log(1.0 - ps)[:, ys < 0.5].sum(dim=1)
    nll_ones = -torch.log(ps)[:, ys > 0.5].sum(dim=1)
    return nll_zeros + nll_ones - 500  # numerical stability
56+
def neglogpri(bs: torch.Tensor) -> torch.Tensor:
    """Negative log-density (up to an additive constant) of the
    zero-mean isotropic Gaussian prior with standard deviation sd_pri.
    Accepts a single coefficient vector or a batch (one per row).
    """

    bs = torch.atleast_2d(bs)

    scaled = bs / sd_pri
    return 0.5 * scaled.square().sum(dim=1)
63+
def neglogpost(bs: torch.Tensor) -> torch.Tensor:
    """Unnormalised negative log-posterior: likelihood term plus prior term."""
    return neglogpri(bs) + negloglik(bs)
66+
def compute_laplace_approx() -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes a Laplace approximation to the posterior.

    Returns the MAP estimate of the coefficients and the inverse
    Hessian of the negative log-posterior at the MAP (the covariance
    of the Gaussian approximation).

    Raises RuntimeError if the optimiser fails to converge.
    """

    def fun(_bs: np.ndarray) -> float:
        # SciPy expects a plain Python scalar; returning the raw
        # shape-(1,) tensor relies on an implicit array->scalar
        # coercion that newer NumPy/SciPy versions reject.
        return neglogpost(torch.from_numpy(_bs)).item()

    def jac(_bs: np.ndarray) -> np.ndarray:
        # Gradient of the negative log-posterior, as a 1-d NumPy array.
        bs = torch.from_numpy(_bs)
        return jacobian(lambda x: neglogpost(x[None, :]), bs).flatten().numpy()

    res = optimize.minimize(
        fun=fun,
        x0=np.zeros(n_beta),
        jac=jac
    )

    if not res.success:
        msg = "MAP optimisation failed to converge."
        raise RuntimeError(msg)

    bs_map = torch.from_numpy(res.x)
    # Curvature of the negative log-posterior at the MAP; its inverse
    # is the covariance of the Laplace (Gaussian) approximation.
    H = hessian(lambda x: neglogpost(x[None, :]), bs_map)
    H_inv = torch.linalg.inv(H)
    return bs_map, H_inv
88+
# MAP estimate and inverse Hessian (Laplace-approximation covariance).
bs_map, cov_map = compute_laplace_approx()

# Build the DIRT approximation: a Gaussian reference on a bounded
# domain, preconditioned with the Laplace approximation.
domain = dt.BoundedDomain(torch.tensor([-6.0, 6.0]))
reference = dt.GaussianReference(domain=domain)
preconditioner = dt.GaussianPreconditioner(bs_map, cov_map, reference)

# Presumably a (piecewise-linear) Lagrange basis with 20 elements per
# dimension — confirm against the deep_tensor documentation.
bases = dt.Lagrange1(num_elems=20)

dirt = dt.DIRT(
    negloglik,
    neglogpri,
    preconditioner,
    bases,
    tt_options=dt.TTOptions(verbose=2, init_rank=10, max_rank=12)
)
104+
# Diagnostic: independence sampler using the Laplace (Gaussian)
# approximation as the proposal distribution.
n_steps = 10_000

norm = torch.distributions.MultivariateNormal(bs_map.flatten(), cov_map)
samples = norm.sample((n_steps,))
potentials_norm = -norm.log_prob(samples)  # proposal potentials
potentials_true = negloglik(samples) + neglogpri(samples)  # target potentials

res = dt.run_independence_sampler(samples, potentials_norm, potentials_true)
print(res.acceptance_rate)
print(res.iacts.max())  # presumably worst integrated autocorrelation time
print(res.ess.min())    # presumably smallest effective sample size
116+
# Same diagnostic, now using the DIRT approximation as the proposal.
rs = dirt.reference.random(d=dirt.dim, n=n_steps)
samples, potentials_dirt = dirt.eval_irt(rs)
potentials_true = negloglik(samples) + neglogpri(samples)

res = dt.run_independence_sampler(samples, potentials_dirt, potentials_true)
print(res.acceptance_rate)
print(res.iacts.max())
print(res.ess.min())

# Pair plot of a thinned subset of the first six coefficients.
rs = dirt.reference.random(d=dirt.dim, n=1000)
samples = preconditioner.Q(rs, "first")
# NOTE(review): `samples` computed on the line above is unused — the
# plot uses the sampler chain `res.xs`; confirm whether `samples`
# (or `samples[::5, :6]`) was intended here instead.
pairplot(res.xs[::5, :6])