|
91 | 91 | "# Heatmap helper using GenerateHeatmap\n", |
92 | 92 | "# NOTE: GenerateHeatmap expects points in (y, x) == (row, col) order matching array indexing.\n", |
93 | 93 | "# This wrapper accepts user-friendly (x, y) and internally reorders to (y, x).\n", |
| 94 | + "# It now supports batched inputs.\n", |
94 | 95 | "\n", |
95 | 96 | "sigma = 3.0\n", |
96 | 97 | "\n", |
|
99 | 100 | " s = float(sigma_override) if sigma_override is not None else float(sigma)\n", |
100 | 101 | " tr = GenerateHeatmap(sigma=s, spatial_shape=(H, W))\n", |
101 | 102 | " # Reorder (x,y) -> (y,x) for the transform\n", |
102 | | - " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", |
103 | | - " return tr(pts_yx) # (N,H,W) where pts interpreted as (row, col)" |
| 103 | + " # Support batched and non-batched inputs\n", |
| 104 | + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", |
| 105 | + " if pts.ndim == 2:\n", |
| 106 | + " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", |
| 107 | + " pts_yx = pts[..., [1, 0]]\n", |
| 108 | + " return tr(pts_yx) # (B, N, H, W)\n" |
104 | 109 | ] |
105 | 110 | }, |
106 | 111 | { |
|
125 | 130 | " spatial_shape=None if use_ref else (H, W),\n", |
126 | 131 | " sigma=s,\n", |
127 | 132 | " )\n", |
128 | | - " pts_yx = np.array([[float(y), float(x)]], dtype=np.float32)\n", |
| 133 | + " # Support batched and non-batched inputs\n", |
| 134 | + " pts = np.array(list(zip(y, x)), dtype=np.float32)\n", |
| 135 | + " if pts.ndim == 2:\n", |
| 136 | + " pts = pts[np.newaxis, ...] # Add batch dimension: (N, 2) -> (1, N, 2)\n", |
| 137 | + " pts_yx = pts[..., [1, 0]]\n", |
129 | 138 | " data = {\"points\": pts_yx, \"ref\": ref_img}\n", |
130 | 139 | " out = tr(data)\n", |
131 | | - " return out[\"heatmap\"]" |
| 140 | + " return out[\"heatmap\"]\n" |
132 | 141 | ] |
133 | 142 | }, |
134 | 143 | { |
|
179 | 188 | "num_points = 3 # number of random landmarks\n", |
180 | 189 | "sigma_demo = 3.0 # Gaussian sigma\n", |
181 | 190 | "combine_mode = \"max\" # or 'sum'\n", |
| 191 | + "batched_input = True # Set to True to test batched input\n", |
182 | 192 | "\n", |
183 | 193 | "# Sample random (x,y) points within image bounds (user-friendly)\n", |
184 | 194 | "points_xy = np.array(\n", |
|
189 | 199 | "\n", |
190 | 200 | "# Convert to (y,x) for the transform\n", |
191 | 201 | "yx_points = points_xy[:, [1, 0]].copy()\n", |
| 202 | + "if batched_input:\n", |
| 203 | + " yx_points = yx_points[np.newaxis, ...] # Add a batch dimension\n", |
192 | 204 | "\n", |
193 | 205 | "array_tr = GenerateHeatmap(sigma=sigma_demo, spatial_shape=(H, W))\n", |
194 | 206 | "heatmaps = array_tr(yx_points) # now correct orientation\n", |
195 | 207 | "\n", |
| 208 | + "if batched_input:\n", |
| 209 | + " heatmaps = heatmaps.squeeze(0) # Remove batch dim for plotting\n", |
| 210 | + "\n", |
196 | 211 | "if combine_mode == \"max\":\n", |
197 | 212 | " combined = heatmaps.max(axis=0)\n", |
198 | 213 | "elif combine_mode == \"sum\":\n", |
|
230 | 245 | " ax.set_title(f\"Point {i}: (x={points_xy[i,0]:.1f}, y={points_xy[i,1]:.1f})\")\n", |
231 | 246 | " ax.set_axis_off()\n", |
232 | 247 | "plt.tight_layout()\n", |
233 | | - "plt.show()" |
| 248 | + "plt.show()\n" |
234 | 249 | ] |
235 | 250 | } |
236 | 251 | ], |
|
0 commit comments