@@ -971,36 +971,17 @@ def test_mul(self, dtypes):
971971 def test_truediv (self , dtypes ):
972972 import jax .numpy as jnp
973973
974- try :
975- # JAX v0.8.0 and newer
976- from jax import enable_x64
977- except ImportError :
978- # JAX v0.7.2 and older
979- from jax .experimental import enable_x64
980-
981- # We have to disable x64 for jax since jnp.true_divide doesn't respect
982- # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast
983- # the expected dtype from 64 bit to 32 bit when using jax backend.
984- with enable_x64 (False ):
985- dtype1 , dtype2 = dtypes
986- x1 = backend .Variable (
987- "ones" , shape = (1 ,), dtype = dtype1 , trainable = False
988- )
989- x2 = backend .Variable (
990- "ones" , shape = (1 ,), dtype = dtype2 , trainable = False
991- )
992- x1_jax = jnp .ones ((1 ,), dtype = dtype1 )
993- x2_jax = jnp .ones ((1 ,), dtype = dtype2 )
994- expected_dtype = standardize_dtype (
995- jnp .true_divide (x1_jax , x2_jax ).dtype
996- )
997- if "float64" in (dtype1 , dtype2 ):
998- expected_dtype = "float64"
999- if backend .backend () == "jax" :
1000- expected_dtype = expected_dtype .replace ("64" , "32" )
974+ dtype1 , dtype2 = dtypes
975+ x1 = backend .Variable ("ones" , shape = (1 ,), dtype = dtype1 , trainable = False )
976+ x2 = backend .Variable ("ones" , shape = (1 ,), dtype = dtype2 , trainable = False )
977+ x1_jax = jnp .ones ((1 ,), dtype = dtype1 )
978+ x2_jax = jnp .ones ((1 ,), dtype = dtype2 )
979+ expected_dtype = standardize_dtype (
980+ jnp .true_divide (x1_jax , x2_jax ).dtype
981+ )
1001982
1002- self .assertDType (x1 / x2 , expected_dtype )
1003- self .assertDType (x1 .__rtruediv__ (x2 ), expected_dtype )
983+ self .assertDType (x1 / x2 , expected_dtype )
984+ self .assertDType (x1 .__rtruediv__ (x2 ), expected_dtype )
1004985
1005986 @parameterized .named_parameters (
1006987 named_product (dtypes = itertools .combinations (NON_COMPLEX_DTYPES , 2 ))
0 commit comments