
Commit 5def626

precommit hooks
1 parent 830b1dd commit 5def626

File tree

3 files changed: +51 -32 lines changed

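The commit message only says "precommit hooks", and the hook configuration itself is not part of this diff. The changes below are the kind produced by isort (imports wrapped one-per-line in parentheses) and black / black-jupyter (normalized spacing, trailing commas, long lines reflowed, notebook code cells reformatted), so a plausible .pre-commit-config.yaml would look roughly like the sketch below. The hook selection and rev pins are assumptions, not taken from the repository.

# Hypothetical .pre-commit-config.yaml -- not shown in this commit.
repos:
  - repo: https://github.com/psf/black
    rev: 23.3.0            # assumed pin
    hooks:
      - id: black
      - id: black-jupyter  # formats code cells inside .ipynb files
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0            # assumed pin
    hooks:
      - id: isort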

examples/images/conditional_mnist_noninteger.ipynb

Lines changed: 22 additions & 10 deletions
@@ -97,7 +97,11 @@
 "from torchvision.utils import make_grid\n",
 "from tqdm import tqdm\n",
 "\n",
-"from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, SchrodingerBridgeConditionalFlowMatcher\n",
+"from torchcfm.conditional_flow_matching import (\n",
+" ConditionalFlowMatcher,\n",
+" ExactOptimalTransportConditionalFlowMatcher,\n",
+" SchrodingerBridgeConditionalFlowMatcher,\n",
+")\n",
 "from torchcfm.models.unet import UNetModel\n",
 "\n",
 "savedir = \"models/cond_mnist\"\n",

@@ -154,10 +158,11 @@
 "# Float Conditional CFM\n",
 "#################################\n",
 "\n",
+"\n",
 "class embed_condition(torch.nn.Module):\n",
-" \"\"\" simple network to embed the condition, other architectures can be used too \"\"\"\n",
+" \"\"\"simple network to embed the condition, other architectures can be used too\"\"\"\n",
 "\n",
-" def __init__(self, input_dim=1, target_dim = 128):\n",
+" def __init__(self, input_dim=1, target_dim=128):\n",
 " super().__init__()\n",
 "\n",
 " self.model = torch.nn.Sequential(\n",

@@ -171,7 +176,6 @@
 " return self.model(label)\n",
 "\n",
 "\n",
-"\n",
 "sigma = 0.0\n",
 "model = UNetModel(\n",
 " dim=(1, 28, 28), num_channels=32, num_res_blocks=1, embedding_net=embed_condition\n",

@@ -11920,7 +11924,9 @@
 " for i, data in enumerate(train_loader):\n",
 " optimizer.zero_grad()\n",
 " x1 = data[0].to(device)\n",
-" y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n",
+" y = (\n",
+" data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n",
+" ) # just to have a floating point label\n",
 " x0 = torch.randn_like(x1)\n",
 " t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
 " vt = model(t, xt, y)\n",

@@ -11956,8 +11962,10 @@
 ],
 "source": [
 "USE_TORCH_DIFFEQ = True\n",
-"ntest = 10*10\n",
-"generated_class_list = torch.arange(10, device=device).repeat(10).reshape((ntest, 1)).float() / 2. #TODO: reshape\n",
+"ntest = 10 * 10\n",
+"generated_class_list = (\n",
+" torch.arange(10, device=device).repeat(10).reshape((ntest, 1)).float() / 2.0\n",
+") # TODO: reshape\n",
 "with torch.no_grad():\n",
 " if USE_TORCH_DIFFEQ:\n",
 " traj = torchdiffeq.odeint(\n",

@@ -11978,7 +11986,7 @@
 ")\n",
 "img = ToPILImage()(grid)\n",
 "plt.imshow(img)\n",
-"cond_values = \", \".join([ f\"{float(item):.2f}\" for item in generated_class_list[0:10,0] ])\n",
+"cond_values = \", \".join([f\"{float(item):.2f}\" for item in generated_class_list[0:10, 0]])\n",
 "plt.title(f\"float conditional cfm\\nlabels: {cond_values}\")\n",
 "plt.savefig(\"floatconditional-cfm_noninteger.svg\")"
 ]

@@ -26069,7 +26077,9 @@
 " for i, data in enumerate(train_loader):\n",
 " optimizer.zero_grad()\n",
 " x1 = data[0].to(device)\n",
-" y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n",
+" y = (\n",
+" data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n",
+" ) # just to have a floating point label\n",
 " x0 = torch.randn_like(x1)\n",
 " t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)\n",
 " vt = model(t, xt, y1)\n",

@@ -82572,7 +82582,9 @@
 " for i, data in tqdm(enumerate(train_loader)):\n",
 " optimizer.zero_grad()\n",
 " x1 = data[0].to(device)\n",
-" y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n",
+" y = (\n",
+" data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n",
+" ) # just to have a floating point label\n",
 " x0 = torch.randn_like(x1)\n",
 " t, xt, ut, _, y1, eps = FM.guided_sample_location_and_conditional_flow(\n",
 " x0, x1, y1=y, return_noise=True\n",

tests/test_models.py

Lines changed: 7 additions & 11 deletions
@@ -17,34 +17,31 @@ def test_initialize_unet():
 
 batch = torch.zeros((8, 1, 28, 28), dtype=torch.float32)
 label = torch.ones((8,), dtype=torch.long)
-timesteps = torch.linspace(0,1,steps=8)
+timesteps = torch.linspace(0, 1, steps=8)
 
-_ = model(t= timesteps, x=batch, y=label)
+_ = model(t=timesteps, x=batch, y=label)
 
 
 def test_initialize_mlp():
-
 model1 = MLP(dim=2, time_varying=True, w=64)
 batch = torch.ones((8, 3), dtype=torch.float32)
 output1 = model1(x=batch)
 
-assert output1.shape == (8,2)
+assert output1.shape == (8, 2)
 
 model2 = MLP(dim=2, w=64)
 batch = torch.ones((8, 2), dtype=torch.float32)
 output2 = model2(x=batch)
 
-assert output2.shape == (8,2)
+assert output2.shape == (8, 2)
 
 
 class mock_embedding(torch.nn.Module):
-
 def __init__(self, outdim=128):
 super().__init__()
 self.outdim = outdim
 
 def forward(self, inputs):
-
 batchsize = inputs.size(0)
 if len(inputs.shape) == 1:
 inputs = inputs.reshape((batchsize, 1))

@@ -53,21 +50,20 @@ def forward(self, inputs):
 
 
 def test_conditional_model_without_integer_labels():
-
 model_channels = 32
 model = UNetModel(
 dim=(1, 28, 28),
 num_channels=model_channels,
 num_res_blocks=1,
 class_cond=False,
-embedding_net = mock_embedding
+embedding_net=mock_embedding,
 )
 
 x1 = torch.ones((8, 1, 28, 28), dtype=torch.float32)
 x0 = torch.randn_like(x1)
-FM = ConditionalFlowMatcher(sigma=0.)
+FM = ConditionalFlowMatcher(sigma=0.0)
 t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
 
-label = 42.1*torch.ones((8,)).float()
+label = 42.1 * torch.ones((8,)).float()
 
 vt = model(t=t, x=xt, y=label)

torchcfm/models/unet/unet.py

Lines changed: 22 additions & 11 deletions
@@ -8,8 +8,15 @@
 import torch.nn.functional as F
 
 from .fp16_util import convert_module_to_f16, convert_module_to_f32
-from .nn import (avg_pool_nd, checkpoint, conv_nd, linear, normalization,
-timestep_embedding, zero_module)
+from .nn import (
+avg_pool_nd,
+checkpoint,
+conv_nd,
+linear,
+normalization,
+timestep_embedding,
+zero_module,
+)
 
 
 class AttentionPool2d(nn.Module):

@@ -411,7 +418,7 @@ def __init__(
 use_scale_shift_norm=False,
 resblock_updown=False,
 use_new_attention_order=False,
-embedding_net = nn.Identity
+embedding_net=nn.Identity,
 ):
 super().__init__()
 

@@ -445,8 +452,9 @@ def __init__(
 if self.num_classes is not None:
 embedding_net = nn.Embedding if embedding_net == nn.Identity else embedding_net
 self.label_emb = embedding_net(num_classes, time_embed_dim)
-assert not isinstance(self.label_emb, nn.Identity), f"for class-conditional networks, provide an embedding please!"
-
+assert not isinstance(
+self.label_emb, nn.Identity
+), f"for class-conditional networks, provide an embedding please!"
 
 ch = input_ch = int(channel_mult[0] * model_channels)
 self.input_blocks = nn.ModuleList(

@@ -605,9 +613,9 @@ def forward(self, t, x, y=None):
 :return: an [N x C x ...] Tensor of outputs.
 """
 timesteps = t
-#assert (y is not None) == (
+# assert (y is not None) == (
 # self.num_classes is not None
-#), "must specify y if and only if the model is class-conditional"
+# ), "must specify y if and only if the model is class-conditional"
 while timesteps.dim() > 1:
 print(timesteps.shape)
 timesteps = timesteps[:, 0]

@@ -623,8 +631,10 @@ def forward(self, t, x, y=None):
 # (batchsize, 4*self.model_channels)
 emb = self.time_embed(embedded_timesteps)
 
-if (y is not None):
-assert y.shape[0] == x.shape[0], f"batch dimension of y ({y.shape[0]}) does not match x ({x.shape[0]})"
+if y is not None:
+assert (
+y.shape[0] == x.shape[0]
+), f"batch dimension of y ({y.shape[0]}) does not match x ({x.shape[0]})"
 labels = self.label_emb(y)
 emb = emb + labels
 

@@ -862,6 +872,7 @@ def forward(self, x, timesteps):
 
 NUM_CLASSES = 1000
 
+
 # this overwrites UNetModel in __init__.py
 class UNetModelWrapper(UNetModel):
 def __init__(

@@ -883,7 +894,7 @@ def __init__(
 resblock_updown=False,
 use_fp16=False,
 use_new_attention_order=False,
-embedding_net = nn.Identity
+embedding_net=nn.Identity,
 ):
 """Dim (tuple): (C, H, W)"""
 image_size = dim[-1]

@@ -927,7 +938,7 @@ def __init__(
 use_scale_shift_norm=use_scale_shift_norm,
 resblock_updown=resblock_updown,
 use_new_attention_order=use_new_attention_order,
-embedding_net = embedding_net
+embedding_net=embedding_net,
 )
 
 def forward(self, t, x, y=None, *args, **kwargs):
