Skip to content

Commit 6ab799e

Browse files
committed
adding support for void label to unet segmentation
1 parent 3dc96e0 commit 6ab799e

File tree

3 files changed

+125
-79
lines changed

3 files changed

+125
-79
lines changed

scenarios/segmentation/01_training_introduction.ipynb

Lines changed: 117 additions & 77 deletions
Large diffs are not rendered by default.

utils_cv/segmentation/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def ratio_correct(void_id, input, target):
1818
""" Helper function to compute the ratio of correctly classified pixels. """
1919
target = target.squeeze(1)
20-
if void_id:
20+
if void_id != None:
2121
mask = target != void_id
2222
ratio_correct = (
2323
(input.argmax(dim=1)[mask] == target[mask]).float().mean()

utils_cv/segmentation/plot.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def plot_segmentation(
5757
show: bool = True,
5858
figsize: Tuple[int, int] = (16, 4),
5959
cmap: ListedColormap = cm.get_cmap("Set3"),
60+
ignore_background_label = True
6061
) -> None:
6162
""" Plot an image, its predicted mask with associated scores, and optionally the ground truth mask.
6263
@@ -68,10 +69,15 @@ def plot_segmentation(
6869
show: set to true to call matplotlib's show()
6970
figsize: figure size
7071
cmap: mask color map.
72+
ignore_background_label: set to True to ignore the 0 label.
7173
"""
7274
im = load_im(im_or_path)
7375
pred_mask = pil2tensor(pred_mask, np.float32)
74-
max_scores = np.max(np.array(pred_scores[1:]), axis=0)
76+
if ignore_background_label:
77+
start_label = 1
78+
else:
79+
start_label = 0
80+
max_scores = np.max(np.array(pred_scores[start_label:]), axis=0)
7581
max_scores = pil2tensor(max_scores, np.float32)
7682

7783
# Plot groud truth mask if provided

0 commit comments

Comments
 (0)