1818import numpy as np
1919import tensorflow as tf
2020
21- from tensorflow_addons .image import transform_ops
2221from skimage import transform
2322
23+ from tensorflow_addons .image import transform_ops
24+ from tensorflow_addons .utils import test_utils
25+
2426_DTYPES = {
2527 tf .dtypes .uint8 ,
2628 tf .dtypes .int32 ,
@@ -322,11 +324,13 @@ def test_unknown_shape():
322324 np .testing .assert_equal (image .numpy (), fn (image ).numpy ())
323325
324326
325- # TODO: Parameterize on dtypes
326327@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
327- def test_shear_x ():
328- image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 ), dtype = np .uint8 )
329- color = tf .constant ([255 , 0 , 255 ], tf .uint8 )
328+ @pytest .mark .parametrize ("dtype" , _DTYPES - {tf .dtypes .float16 })
329+ def test_shear_x (dtype ):
330+ image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 )).astype (
331+ dtype .as_numpy_dtype
332+ )
333+ color = tf .constant ([255 , 0 , 255 ], tf .int32 )
330334 level = tf .random .uniform (shape = (), minval = 0 , maxval = 1 )
331335
332336 tf_image = tf .constant (image )
@@ -344,11 +348,13 @@ def test_shear_x():
344348 np .testing .assert_equal (sheared_img .numpy (), expected_img )
345349
346350
347- # TODO: Parameterize on dtypes
348351@pytest .mark .usefixtures ("maybe_run_functions_eagerly" )
349- def test_shear_y ():
350- image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 ), dtype = np .uint8 )
351- color = tf .constant ([255 , 0 , 255 ], tf .dtypes .uint8 )
352+ @pytest .mark .parametrize ("dtype" , _DTYPES - {tf .dtypes .float16 })
353+ def test_shear_y (dtype ):
354+ image = np .random .randint (low = 0 , high = 255 , size = (4 , 4 , 3 )).astype (
355+ dtype .as_numpy_dtype
356+ )
357+ color = tf .constant ([255 , 0 , 255 ], tf .int32 )
352358 level = tf .random .uniform (shape = (), minval = 0 , maxval = 1 )
353359
354360 tf_image = tf .constant (image )
@@ -363,4 +369,4 @@ def test_shear_y():
363369 mask = np .where (expected_img == - 1 )
364370 expected_img [mask [0 ], mask [1 ], :] = color
365371
366- np . testing . assert_equal (sheared_img .numpy (), expected_img )
372+ test_utils . assert_allclose_according_to_type (sheared_img .numpy (), expected_img )
0 commit comments