@@ -165,15 +165,20 @@ def compute_output_shape(self, x_shape, attention_mask_shape):
165165 self .assertEqual (output .shape , (2 , 3 ))
166166
167167 # Test compute_output_spec as well
168- output_spec = layer .compute_output_spec (x , attention_mask = attention_mask )
168+ output_spec = layer .compute_output_spec (
169+ x , attention_mask = attention_mask
170+ )
169171 self .assertEqual (output_spec .shape , (2 , 3 ))
170172
171173 def test_mask_parameter_exclusions (self ):
172174 # Test that only 'mask' parameter is excluded from shapes_dict,
173175 # not all parameters ending with '_mask'. Issue #21154.
174176
175177 class LayerWithMultipleMasks (layers .Layer ):
176- def call (self , x , mask = None , attention_mask = None , padding_mask = None ):
178+ def call (
179+ self , x , mask = None , attention_mask = None ,
180+ padding_mask = None
181+ ):
177182 result = x
178183 if mask is not None :
179184 result = result * mask
@@ -198,12 +203,16 @@ def compute_output_shape(
198203 padding_mask = backend .KerasTensor ((2 , 3 ))
199204
200205 # This should work without errors
201- output = layer (x , mask = mask , attention_mask = attention_mask , padding_mask = padding_mask )
206+ output = layer (
207+ x , mask = mask , attention_mask = attention_mask ,
208+ padding_mask = padding_mask
209+ )
202210 self .assertEqual (output .shape , (2 , 3 ))
203211
204212 # Test compute_output_spec as well
205213 output_spec = layer .compute_output_spec (
206- x , mask = mask , attention_mask = attention_mask , padding_mask = padding_mask
214+ x , mask = mask , attention_mask = attention_mask ,
215+ padding_mask = padding_mask
207216 )
208217 self .assertEqual (output_spec .shape , (2 , 3 ))
209218
0 commit comments