11import collections
2+ import math
23from enum import IntEnum
34
45import cv2
78from core import imagelib
89from facelib import FaceType , LandmarksProcessor
910
11+
1012class SampleProcessor (object ):
1113 class SampleType (IntEnum ):
1214 NONE = 0
@@ -114,8 +116,8 @@ def get_eyes_mask():
114116 if sample_type == SPST .FACE_IMAGE or sample_type == SPST .FACE_MASK :
115117 if not is_face_sample :
116118 raise ValueError ("face_samples should be provided for sample_type FACE_*" )
117-
118- if is_face_sample :
119+
120+ if sample_type == SPST . FACE_IMAGE or sample_type == SPST . FACE_MASK :
119121 face_type = opts .get ('face_type' , None )
120122 face_mask_type = opts .get ('face_mask_type' , SPFMT .NONE )
121123
@@ -125,7 +127,6 @@ def get_eyes_mask():
125127 if face_type > sample .face_type :
126128 raise Exception ('sample %s type %s does not match model requirement %s. Consider extract necessary type of faces.' % (sample .filename , sample .face_type , face_type ) )
127129
128- if sample_type == SPST .FACE_IMAGE or sample_type == SPST .FACE_MASK :
129130
130131 if sample_type == SPST .FACE_MASK :
131132
@@ -156,7 +157,7 @@ def get_eyes_mask():
156157 img = cv2 .resize ( img , (resolution , resolution ), cv2 .INTER_CUBIC )
157158
158159 img = imagelib .warp_by_params (params_per_resolution [resolution ], img , warp , transform , can_flip = True , border_replicate = border_replicate , cv2_inter = cv2 .INTER_LINEAR )
159-
160+
160161 if len (img .shape ) == 2 :
161162 img = img [...,None ]
162163
@@ -175,11 +176,11 @@ def get_eyes_mask():
175176 else :
176177 if w != resolution :
177178 img = cv2 .resize ( img , (resolution , resolution ), cv2 .INTER_CUBIC )
178- img = imagelib .warp_by_params (params_per_resolution [resolution ], img , warp , transform , can_flip = True , border_replicate = border_replicate )
179179
180+ img = imagelib .warp_by_params (params_per_resolution [resolution ], img , warp , transform , can_flip = True , border_replicate = border_replicate )
181+
180182 img = np .clip (img .astype (np .float32 ), 0 , 1 )
181-
182-
183+
183184
184185 # Apply random color transfer
185186 if ct_mode is not None and ct_sample is not None :
@@ -273,17 +274,16 @@ def get_eyes_mask():
273274 l = np .clip (l , 0.0 , 1.0 )
274275 out_sample = l
275276 elif sample_type == SPST .PITCH_YAW_ROLL or sample_type == SPST .PITCH_YAW_ROLL_SIGMOID :
276- pitch_yaw_roll = sample .get_pitch_yaw_roll ()
277-
278- if params ['flip' ]:
277+ pitch ,yaw ,roll = sample .get_pitch_yaw_roll ()
278+ if params_per_resolution [resolution ]['flip' ]:
279279 yaw = - yaw
280280
281281 if sample_type == SPST .PITCH_YAW_ROLL_SIGMOID :
282282 pitch = np .clip ( (pitch / math .pi ) / 2.0 + 0.5 , 0 , 1 )
283283 yaw = np .clip ( (yaw / math .pi ) / 2.0 + 0.5 , 0 , 1 )
284284 roll = np .clip ( (roll / math .pi ) / 2.0 + 0.5 , 0 , 1 )
285285
286- out_sample = (pitch , yaw , roll )
286+ out_sample = (pitch , yaw )
287287 else :
288288 raise ValueError ('expected sample_type' )
289289
0 commit comments