Skip to content

Commit 24b2d20

Browse files
committed
Update examples to work on GPU (where applicable)
1 parent 00d71ad commit 24b2d20

File tree

14 files changed

+323
-547
lines changed

14 files changed

+323
-547
lines changed

docs/__init__.py

Whitespace-only changes.

docs/examples/heat/example_inference.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,11 +79,10 @@
7979
},
8080
{
8181
"cell_type": "code",
82-
"execution_count": 2,
82+
"execution_count": null,
8383
"metadata": {},
8484
"outputs": [],
8585
"source": [
86-
"torch.set_default_device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
8786
"torch.manual_seed(1)\n",
8887
"set_plot_style()"
8988
]

docs/examples/logistic/__init__.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,18 @@
1313
Data = namedtuple("Data", ["X", "y"])
1414

1515

16-
def read_credit_data(fname: Path) -> Data:
16+
def load_credit_data(device: torch.device = torch.device("cpu")) -> Data:
1717
"""Reads in the German credit dataset, then shifts and scales the
1818
predictors such that each has a mean of zero and standard deviation
1919
of 1, and scales the response variable such that it takes values in
2020
{0, 1}.
2121
"""
2222

23-
with open(fname, "r") as f:
23+
with open(DATA_PATH, "r") as f:
2424
data = [[float(l) for l in line.strip().split()]
2525
for line in f.readlines()]
2626

27-
data = torch.tensor(data)
27+
data = torch.tensor(data, device=device)
2828
X, y = data[:, :-1], data[:, -1]
2929

3030
mean_X = torch.mean(X, dim=0)
@@ -33,7 +33,4 @@ def read_credit_data(fname: Path) -> Data:
3333
X = (X - mean_X) / std_X
3434
y -= 1.0
3535

36-
return Data(X, y)
37-
38-
39-
credit_data = read_credit_data(DATA_PATH)
36+
return Data(X, y)

docs/examples/logistic/example_inference.ipynb

Lines changed: 60 additions & 73 deletions
Large diffs are not rendered by default.

docs/examples/plotting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,14 @@ def pairplot(
4242
if labels is None:
4343
labels = [f"$x_{i+1}$" for i in range(dim)]
4444

45+
xs = xs.to("cpu")
4546
if ys is not None:
47+
ys = ys.to("cpu")
4648
xys = torch.vstack((xs, ys))
4749
else:
4850
xys = xs.clone()
51+
if truth is not None:
52+
truth.to("cpu")
4953

5054
if bounds is None:
5155
samples_min = xys.min(dim=0).values

docs/examples/poisson/example_amortised.ipynb

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,11 @@
7777
},
7878
{
7979
"cell_type": "code",
80-
"execution_count": 2,
80+
"execution_count": null,
8181
"metadata": {},
8282
"outputs": [],
8383
"source": [
8484
"torch.manual_seed(0)\n",
85-
"torch.set_default_device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
8685
"set_plot_style()"
8786
]
8887
},
@@ -95,7 +94,7 @@
9594
},
9695
{
9796
"cell_type": "code",
98-
"execution_count": 3,
97+
"execution_count": null,
9998
"metadata": {},
10099
"outputs": [],
101100
"source": [
@@ -111,7 +110,7 @@
111110
},
112111
{
113112
"cell_type": "code",
114-
"execution_count": 4,
113+
"execution_count": null,
115114
"metadata": {},
116115
"outputs": [
117116
{

docs/examples/shock/__init__.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,23 @@
1+
from collections import namedtuple
2+
from pathlib import Path
3+
14
import torch
25

3-
from .preconditioner import GammaNormalMapping
4-
5-
6-
# Define failure distances (km)
7-
failure_dists = torch.tensor([
8-
6700, 6950, 7820, 8790, 9120,
9-
9660, 9820, 11310, 11690, 11850,
10-
11880, 12140, 12200, 12870, 13150,
11-
13330, 13470, 14040, 14300, 17520,
12-
17540, 17890, 18420, 18960, 18980,
13-
19410, 20100, 20100, 20150, 20320,
14-
20900, 22700, 23490, 26510, 27410,
15-
27490, 27890, 28100
16-
])
17-
18-
# Define whether or not each observation is right-censored
19-
censored = torch.tensor([
20-
False, True, True, True, False,
21-
True, True, True, True, True,
22-
True, True, False, True, False,
23-
True, True, True, False, False,
24-
True, True, True, True, True,
25-
True, False, True, True, True,
26-
False, False, True, False, True,
27-
False, True, True
28-
])
6+
from .preconditioner import GammaNormalMapping
7+
8+
9+
DATA_PATH = Path(__file__).resolve().parent.joinpath("data")
10+
FAILURE_DISTS_PATH = DATA_PATH.joinpath("failure_dists.pt")
11+
CENSORED_PATH = DATA_PATH.joinpath("censored.pt")
12+
13+
14+
Data = namedtuple("Data", ["failure_dists", "censored"])
15+
16+
17+
def load_shock_data(device: torch.device = torch.device("cpu")) -> Data:
18+
"""Reads in the data (failure distances (km) and censorship
19+
information) used in @Dolgov2020.
20+
"""
21+
failure_dists = torch.load(FAILURE_DISTS_PATH).to(device=device)
22+
censored = torch.load(CENSORED_PATH).to(device=device)
23+
return Data(failure_dists, censored)

docs/examples/shock/example_inference.ipynb

Lines changed: 80 additions & 185 deletions
Large diffs are not rendered by default.

docs/examples/shock/preconditioner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import deep_tensor as dt
55

6-
from examples.shock.prior import GammaDist, GaussianDist
6+
from .prior import GammaDist, GaussianDist
77

88
EPS = torch.finfo(torch.get_default_dtype()).eps
99

@@ -14,8 +14,8 @@ def __init__(
1414
self,
1515
reference: dt.Reference,
1616
bounds: Tensor,
17-
alpha: float,
18-
gamma: float,
17+
alpha: Tensor,
18+
gamma: Tensor,
1919
ms: Tensor,
2020
sds: Tensor,
2121
dim: int

docs/examples/shock/prior.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,7 @@ def newton(
8787

8888

8989
def converged(fs: Tensor, dls: Tensor) -> bool:
90-
error_fs = fs.abs()
91-
error_dls = dls.abs()
92-
return error_fs.max() < 1e-8
90+
return fs.abs().max().item() < 1e-8
9391

9492

9593
def gaussian_potential(xs: Tensor, mus: Tensor, sds: Tensor) -> Tensor:
@@ -116,10 +114,10 @@ class GammaDist():
116114
117115
"""
118116

119-
def __init__(self, alpha: float, lamb: float, bounds: Tensor):
117+
def __init__(self, alpha: Tensor, lamb: Tensor, bounds: Tensor):
120118

121-
self.alpha = torch.tensor(alpha)
122-
self.lamb = torch.tensor(lamb)
119+
self.alpha = alpha
120+
self.lamb = lamb
123121
self.bounds = bounds
124122
self.Gamma = Gamma(self.alpha, 1.0 / self.lamb)
125123

@@ -142,8 +140,8 @@ def func(xs: Tensor) -> Tuple[Tensor, Tensor]:
142140
zs -= zs_cdf
143141
return zs, dzdxs
144142

145-
l0s = torch.full_like(zs_cdf, self.bounds[0])
146-
l1s = torch.full_like(zs_cdf, self.bounds[1])
143+
l0s = torch.full_like(zs_cdf, self.bounds[0].item())
144+
l1s = torch.full_like(zs_cdf, self.bounds[1].item())
147145
return newton(func, l0s, l1s)
148146

149147

0 commit comments

Comments
 (0)