|
97 | 97 | "from torchvision.utils import make_grid\n",
|
98 | 98 | "from tqdm import tqdm\n",
|
99 | 99 | "\n",
|
100 |
| - "from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, SchrodingerBridgeConditionalFlowMatcher\n", |
| 100 | + "from torchcfm.conditional_flow_matching import (\n", |
| 101 | + " ConditionalFlowMatcher,\n", |
| 102 | + " ExactOptimalTransportConditionalFlowMatcher,\n", |
| 103 | + " SchrodingerBridgeConditionalFlowMatcher,\n", |
| 104 | + ")\n", |
101 | 105 | "from torchcfm.models.unet import UNetModel\n",
|
102 | 106 | "\n",
|
103 | 107 | "savedir = \"models/cond_mnist\"\n",
|
|
154 | 158 | "# Float Conditional CFM\n",
|
155 | 159 | "#################################\n",
|
156 | 160 | "\n",
|
| 161 | + "\n", |
157 | 162 | "class embed_condition(torch.nn.Module):\n",
|
158 |
| - " \"\"\" simple network to embed the condition, other architectures can be used too \"\"\"\n", |
| 163 | + " \"\"\"simple network to embed the condition, other architectures can be used too\"\"\"\n", |
159 | 164 | "\n",
|
160 |
| - " def __init__(self, input_dim=1, target_dim = 128):\n", |
| 165 | + " def __init__(self, input_dim=1, target_dim=128):\n", |
161 | 166 | " super().__init__()\n",
|
162 | 167 | "\n",
|
163 | 168 | " self.model = torch.nn.Sequential(\n",
|
|
171 | 176 | " return self.model(label)\n",
|
172 | 177 | "\n",
|
173 | 178 | "\n",
|
174 |
| - "\n", |
175 | 179 | "sigma = 0.0\n",
|
176 | 180 | "model = UNetModel(\n",
|
177 | 181 | " dim=(1, 28, 28), num_channels=32, num_res_blocks=1, embedding_net=embed_condition\n",
|
|
11920 | 11924 | " for i, data in enumerate(train_loader):\n",
|
11921 | 11925 | " optimizer.zero_grad()\n",
|
11922 | 11926 | " x1 = data[0].to(device)\n",
|
11923 |
| - " y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n", |
| 11927 | + " y = (\n", |
| 11928 | + " data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n", |
| 11929 | + " ) # just to have a floating point label\n", |
11924 | 11930 | " x0 = torch.randn_like(x1)\n",
|
11925 | 11931 | " t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)\n",
|
11926 | 11932 | " vt = model(t, xt, y)\n",
|
|
11956 | 11962 | ],
|
11957 | 11963 | "source": [
|
11958 | 11964 | "USE_TORCH_DIFFEQ = True\n",
|
11959 |
| - "ntest = 10*10\n", |
11960 |
| - "generated_class_list = torch.arange(10, device=device).repeat(10).reshape((ntest, 1)).float() / 2. #TODO: reshape\n", |
| 11965 | + "ntest = 10 * 10\n", |
| 11966 | + "generated_class_list = (\n", |
| 11967 | + " torch.arange(10, device=device).repeat(10).reshape((ntest, 1)).float() / 2.0\n", |
| 11968 | + ") # TODO: reshape\n", |
11961 | 11969 | "with torch.no_grad():\n",
|
11962 | 11970 | " if USE_TORCH_DIFFEQ:\n",
|
11963 | 11971 | " traj = torchdiffeq.odeint(\n",
|
|
11978 | 11986 | ")\n",
|
11979 | 11987 | "img = ToPILImage()(grid)\n",
|
11980 | 11988 | "plt.imshow(img)\n",
|
11981 |
| - "cond_values = \", \".join([ f\"{float(item):.2f}\" for item in generated_class_list[0:10,0] ])\n", |
| 11989 | + "cond_values = \", \".join([f\"{float(item):.2f}\" for item in generated_class_list[0:10, 0]])\n", |
11982 | 11990 | "plt.title(f\"float conditional cfm\\nlabels: {cond_values}\")\n",
|
11983 | 11991 | "plt.savefig(\"floatconditional-cfm_noninteger.svg\")"
|
11984 | 11992 | ]
|
|
26069 | 26077 | " for i, data in enumerate(train_loader):\n",
|
26070 | 26078 | " optimizer.zero_grad()\n",
|
26071 | 26079 | " x1 = data[0].to(device)\n",
|
26072 |
| - " y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n", |
| 26080 | + " y = (\n", |
| 26081 | + " data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n", |
| 26082 | + " ) # just to have a floating point label\n", |
26073 | 26083 | " x0 = torch.randn_like(x1)\n",
|
26074 | 26084 | " t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)\n",
|
26075 | 26085 | " vt = model(t, xt, y1)\n",
|
|
82572 | 82582 | " for i, data in tqdm(enumerate(train_loader)):\n",
|
82573 | 82583 | " optimizer.zero_grad()\n",
|
82574 | 82584 | " x1 = data[0].to(device)\n",
|
82575 |
| - " y = data[1].float().to(device).reshape((batch_size, 1)) / 2. #just to have a floating point label\n", |
| 82585 | + " y = (\n", |
| 82586 | + " data[1].float().to(device).reshape((batch_size, 1)) / 2.0\n", |
| 82587 | + " ) # just to have a floating point label\n", |
82576 | 82588 | " x0 = torch.randn_like(x1)\n",
|
82577 | 82589 | " t, xt, ut, _, y1, eps = FM.guided_sample_location_and_conditional_flow(\n",
|
82578 | 82590 | " x0, x1, y1=y, return_noise=True\n",
|
|
0 commit comments