@@ -94,6 +94,8 @@ class MeanOfAbsoluteDifference(BaseBackgroundMattingMetrics):
9494 def update (self , annotation , prediction ):
9595 pred = self .get_prediction (prediction )
9696 gt = self .get_annotation (annotation )
97+ if pred .shape [- 1 ] == 1 and pred .shape [- 1 ] != gt .shape [- 1 ]:
98+ gt = cv2 .cvtColor (gt , cv2 .COLOR_RGB2GRAY )
9799 value = np .mean (abs (pred - gt )) * 1e3
98100 self .results .append (value )
99101 return value
@@ -105,6 +107,8 @@ class SpatialGradient(BaseBackgroundMattingMetrics):
105107 def update (self , annotation , prediction ):
106108 pred = self .get_prediction (prediction )
107109 gt = self .get_annotation (annotation )
110+ if pred .shape [- 1 ] == 1 and pred .shape [- 1 ] != gt .shape [- 1 ]:
111+ gt = cv2 .cvtColor (gt , cv2 .COLOR_RGB2GRAY )
108112 gt_grad = self .gauss_gradient (gt )
109113 pred_grad = self .gauss_gradient (pred )
110114 value = np .sum ((gt_grad - pred_grad ) ** 2 ) / 1000
@@ -152,6 +156,8 @@ class MeanSquaredErrorWithMask(BaseBackgroundMattingMetrics):
152156 def update (self , annotation , prediction ):
153157 pred = self .get_prediction (prediction )
154158 gt = self .get_annotation (annotation )
159+ if pred .shape [- 1 ] == 1 and pred .shape [- 1 ] != gt .shape [- 1 ]:
160+ gt = cv2 .cvtColor (gt , cv2 .COLOR_RGB2GRAY )
155161 if self .use_mask :
156162 mask = self .prepare_pha (annotation .value ) > 0
157163 pred = pred [mask ]
0 commit comments