
Commit b6a22f4

precommit hooks
1 parent 85cc180 commit b6a22f4

3 files changed: 51 additions & 32 deletions

examples/images/conditional_mnist_noninteger.ipynb

Lines changed: 22 additions & 10 deletions

@@ -97,7 +97,11 @@
 "from torchvision.utils import make_grid\n",
 "from tqdm import tqdm\n",
 "\n",
-"from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, SchrodingerBridgeConditionalFlowMatcher\n",
+"from torchcfm.conditional_flow_matching import (\n",
+"    ConditionalFlowMatcher,\n",
+"    ExactOptimalTransportConditionalFlowMatcher,\n",
+"    SchrodingerBridgeConditionalFlowMatcher,\n",
+")\n",
 "from torchcfm.models.unet import UNetModel\n",
 "\n",
 "savedir = \"models/cond_mnist\"\n",
@@ -154,10 +158,11 @@
 "# Float Conditional CFM\n",
 "#################################\n",
 "\n",
+"\n",
 "class embed_condition(torch.nn.Module):\n",
-"    \"\"\" simple network to embed the condition, other architectures can be used too \"\"\"\n",
+"    \"\"\"simple network to embed the condition, other architectures can be used too\"\"\"\n",
 "\n",
-"    def __init__(self, input_dim=1, target_dim = 128):\n",
+"    def __init__(self, input_dim=1, target_dim=128):\n",
 "        super().__init__()\n",
 "\n",
 "        self.model = torch.nn.Sequential(\n",
@@ -171,7 +176,6 @@
 "        return self.model(label)\n",
 "\n",
 "\n",
-"\n",
 "sigma = 0.0\n",
 "model = UNetModel(\n",
 "    dim=(1, 28, 28), num_channels=32, num_res_blocks=1, embedding_net=embed_condition\n",
@@ -11920,7 +11924,9 @@
 "    for i, data in enumerate(train_loader):\n",
 "        optimizer.zero_grad()\n",
 "        x1 = data[0].to(device)\n",
-"        y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n",
+"        y = (\n",
+"            data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n",
+"        )  # just to have a floating point label\n",
 "        x0 = torch.randn_like(x1)\n",
 "        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
 "        vt = model(t, xt, y)\n",
@@ -11956,8 +11962,10 @@
 ],
 "source": [
 "USE_TORCH_DIFFEQ = True\n",
-"ntest = 10*10\n",
-"generated_class_list = torch.arange(10, device=device).repeat(10).reshape((ntest, 1)).float() / 2. #TODO: reshape\n",
+"ntest = 10 * 10\n",
+"generated_class_list = (\n",
+"    torch.arange(10, device=device).repeat(10).reshape((ntest, 1)).float() / 2.0\n",
+")  # TODO: reshape\n",
 "with torch.no_grad():\n",
 "    if USE_TORCH_DIFFEQ:\n",
 "        traj = torchdiffeq.odeint(\n",
@@ -11978,7 +11986,7 @@
 ")\n",
 "img = ToPILImage()(grid)\n",
 "plt.imshow(img)\n",
-"cond_values = \", \".join([ f\"{float(item):.2f}\" for item in generated_class_list[0:10,0] ])\n",
+"cond_values = \", \".join([f\"{float(item):.2f}\" for item in generated_class_list[0:10, 0]])\n",
 "plt.title(f\"float conditional cfm\\nlabels: {cond_values}\")\n",
 "plt.savefig(\"floatconditional-cfm_noninteger.svg\")"
 ]
@@ -26069,7 +26077,9 @@
 "    for i, data in enumerate(train_loader):\n",
 "        optimizer.zero_grad()\n",
 "        x1 = data[0].to(device)\n",
-"        y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n",
+"        y = (\n",
+"            data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n",
+"        )  # just to have a floating point label\n",
 "        x0 = torch.randn_like(x1)\n",
 "        t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)\n",
 "        vt = model(t, xt, y1)\n",
@@ -82572,7 +82582,9 @@
 "    for i, data in tqdm(enumerate(train_loader)):\n",
 "        optimizer.zero_grad()\n",
 "        x1 = data[0].to(device)\n",
-"        y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n",
+"        y = (\n",
+"            data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n",
+"        )  # just to have a floating point label\n",
 "        x0 = torch.randn_like(x1)\n",
 "        t, xt, ut, _, y1, eps = FM.guided_sample_location_and_conditional_flow(\n",
 "            x0, x1, y1=y, return_noise=True\n",

tests/test_models.py

Lines changed: 7 additions & 11 deletions

@@ -15,34 +15,31 @@ def test_initialize_unet():
 
     batch = torch.zeros((8, 1, 28, 28), dtype=torch.float32)
     label = torch.ones((8,), dtype=torch.long)
-    timesteps = torch.linspace(0,1,steps=8)
+    timesteps = torch.linspace(0, 1, steps=8)
 
-    _ = model(t= timesteps, x=batch, y=label)
+    _ = model(t=timesteps, x=batch, y=label)
 
 
 def test_initialize_mlp():
-
     model1 = MLP(dim=2, time_varying=True, w=64)
     batch = torch.ones((8, 3), dtype=torch.float32)
     output1 = model1(x=batch)
 
-    assert output1.shape == (8,2)
+    assert output1.shape == (8, 2)
 
     model2 = MLP(dim=2, w=64)
     batch = torch.ones((8, 2), dtype=torch.float32)
     output2 = model2(x=batch)
 
-    assert output2.shape == (8,2)
+    assert output2.shape == (8, 2)
 
 
 class mock_embedding(torch.nn.Module):
-
     def __init__(self, outdim=128):
         super().__init__()
         self.outdim = outdim
 
     def forward(self, inputs):
-
         batchsize = inputs.size(0)
         if len(inputs.shape) == 1:
             inputs = inputs.reshape((batchsize, 1))
@@ -51,22 +48,21 @@ def forward(self, inputs):
 
 
 def test_conditional_model_without_integer_labels():
-
     model_channels = 32
     model = UNetModel(
         dim=(1, 28, 28),
         num_channels=model_channels,
         num_res_blocks=1,
         class_cond=False,
-        embedding_net = mock_embedding
+        embedding_net=mock_embedding,
     )
 
     x1 = torch.ones((8, 1, 28, 28), dtype=torch.float32)
     x0 = torch.randn_like(x1)
-    FM = ConditionalFlowMatcher(sigma=0.)
+    FM = ConditionalFlowMatcher(sigma=0.0)
     t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
 
-    label = 42.1*torch.ones((8,)).float()
+    label = 42.1 * torch.ones((8,)).float()
 
     vt = model(t=t, x=xt, y=label)
 >>>>>>> 1dbfbd0 (more tests and floating point conditions)
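
To rerun just this reformatted test module after pulling the commit, a small runner like the following works, assuming pytest is installed in the environment:

# Runs the touched test module; equivalent to `pytest tests/test_models.py -q`.
import pytest

raise SystemExit(pytest.main(["tests/test_models.py", "-q"]))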

torchcfm/models/unet/unet.py

Lines changed: 22 additions & 11 deletions

@@ -9,8 +9,15 @@
 import torch.nn.functional as F
 
 from .fp16_util import convert_module_to_f16, convert_module_to_f32
-from .nn import (avg_pool_nd, checkpoint, conv_nd, linear, normalization,
-                 timestep_embedding, zero_module)
+from .nn import (
+    avg_pool_nd,
+    checkpoint,
+    conv_nd,
+    linear,
+    normalization,
+    timestep_embedding,
+    zero_module,
+)
 
 
 class AttentionPool2d(nn.Module):
@@ -410,7 +417,7 @@ def __init__(
         use_scale_shift_norm=False,
         resblock_updown=False,
         use_new_attention_order=False,
-        embedding_net = nn.Identity
+        embedding_net=nn.Identity,
     ):
         super().__init__()
 
@@ -444,8 +451,9 @@ def __init__(
         if self.num_classes is not None:
             embedding_net = nn.Embedding if embedding_net == nn.Identity else embedding_net
             self.label_emb = embedding_net(num_classes, time_embed_dim)
-            assert not isinstance(self.label_emb, nn.Identity), f"for class-conditional networks, provide an embedding please!"
-
+            assert not isinstance(
+                self.label_emb, nn.Identity
+            ), f"for class-conditional networks, provide an embedding please!"
 
         ch = input_ch = int(channel_mult[0] * model_channels)
         self.input_blocks = nn.ModuleList(
@@ -604,9 +612,9 @@ def forward(self, t, x, y=None):
         :return: an [N x C x ...] Tensor of outputs.
         """
         timesteps = t
-        #assert (y is not None) == (
+        # assert (y is not None) == (
         #    self.num_classes is not None
-        #), "must specify y if and only if the model is class-conditional"
+        # ), "must specify y if and only if the model is class-conditional"
         while timesteps.dim() > 1:
             print(timesteps.shape)
             timesteps = timesteps[:, 0]
@@ -622,8 +630,10 @@
         # (batchsize, 4*self.model_channels)
         emb = self.time_embed(embedded_timesteps)
 
-        if (y is not None):
-            assert y.shape[0] == x.shape[0], f"batch dimension of y ({y.shape[0]}) does not match x ({x.shape[0]})"
+        if y is not None:
+            assert (
+                y.shape[0] == x.shape[0]
+            ), f"batch dimension of y ({y.shape[0]}) does not match x ({x.shape[0]})"
             labels = self.label_emb(y)
             emb = emb + labels
 
@@ -861,6 +871,7 @@ def forward(self, x, timesteps):
 
 NUM_CLASSES = 1000
 
+
 # this overwrites UNetModel in __init__.py
 class UNetModelWrapper(UNetModel):
     def __init__(
@@ -882,7 +893,7 @@ def __init__(
         resblock_updown=False,
         use_fp16=False,
         use_new_attention_order=False,
-        embedding_net = nn.Identity
+        embedding_net=nn.Identity,
     ):
         """Dim (tuple): (C, H, W)"""
         image_size = dim[-1]
@@ -926,7 +937,7 @@ def __init__(
             use_scale_shift_norm=use_scale_shift_norm,
             resblock_updown=resblock_updown,
             use_new_attention_order=use_new_attention_order,
-            embedding_net = embedding_net
+            embedding_net=embedding_net,
         )
 
     def forward(self, t, x, y=None, *args, **kwargs):
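
The unet.py changes above are formatting-only, but they touch the embedding_net hook that the notebook and the new test rely on. Below is a minimal sketch of that hook, following the usage in tests/test_models.py; scalar_embedding is a hypothetical module, constrained only by having its output width match the UNet's time embedding (4 * num_channels = 128 here).

# Minimal sketch of passing a custom embedding_net to UNetModel, mirroring the
# usage in the new test above. `scalar_embedding` is a hypothetical module; only
# its 128-dimensional output (time_embed_dim = 4 * num_channels) is required.
import torch
import torch.nn as nn

from torchcfm.models.unet import UNetModel


class scalar_embedding(nn.Module):
    def __init__(self, outdim=128):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(1, outdim), nn.SiLU(), nn.Linear(outdim, outdim))

    def forward(self, y):
        # Accept (batch,) or (batch, 1) float conditions, as mock_embedding does.
        if y.dim() == 1:
            y = y.reshape((-1, 1))
        return self.net(y)


model = UNetModel(
    dim=(1, 28, 28),
    num_channels=32,
    num_res_blocks=1,
    class_cond=False,
    embedding_net=scalar_embedding,
)

t = torch.rand(8)
x = torch.randn(8, 1, 28, 28)
y = 42.1 * torch.ones((8,)).float()  # floating-point condition, as in the new test
vt = model(t=t, x=x, y=y)  # forward(t, x, y) embeds y via embedding_net and adds it to emb

For class-conditional use, the assertion reformatted above still applies: embedding_net defaults to nn.Embedding when left as nn.Identity, and a plain nn.Identity label embedding is rejected.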
