@@ -940,9 +940,10 @@ def flatnonzero(a):
940940 nonzero_values : Return the non-zero elements of the input array
941941
942942 """
943- if a .ndim == 0 :
943+ _a = as_tensor_variable (a )
944+ if _a .ndim == 0 :
944945 raise ValueError ("Nonzero only supports non-scalar arrays." )
945- return nonzero (a .flatten (), return_matrix = False )[0 ]
946+ return nonzero (_a .flatten (), return_matrix = False )[0 ]
946947
947948
948949def nonzero_values (a ):
@@ -1324,9 +1325,10 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
13241325 tensor
13251326 tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype.
13261327 """
1328+ _x = as_tensor_variable (x )
13271329 if dtype is None :
1328- dtype = x .dtype
1329- return eye (x .shape [0 ], x .shape [1 ], k = 0 , dtype = dtype )
1330+ dtype = _x .dtype
1331+ return eye (_x .shape [0 ], _x .shape [1 ], k = 0 , dtype = dtype )
13301332
13311333
13321334def infer_broadcastable (shape ):
@@ -2773,8 +2775,9 @@ def tile(x, reps, ndim=None):
27732775 """
27742776 from aesara .tensor .math import ge
27752777
2776- if ndim is not None and ndim < x .ndim :
2777- raise ValueError ("ndim should be equal or larger than x.ndim" )
2778+ _x = as_tensor_variable (x )
2779+ if ndim is not None and ndim < _x .ndim :
2780+ raise ValueError ("ndim should be equal or larger than _x.ndim" )
27782781
27792782 # If reps is a scalar, integer or vector, we convert it to a list.
27802783 if not isinstance (reps , (list , tuple )):
@@ -2799,8 +2802,8 @@ def tile(x, reps, ndim=None):
27992802 # assert that reps.shape[0] does not exceed ndim
28002803 offset = assert_op (offset , ge (offset , 0 ))
28012804
2802- # if reps.ndim is less than x .ndim, we pad the reps with
2803- # "1" so that reps will have the same ndim as x .
2805+ # if reps.ndim is less than _x .ndim, we pad the reps with
2806+ # "1" so that reps will have the same ndim as _x .
28042807 reps_ = [switch (i < offset , 1 , reps [i - offset ]) for i in range (ndim )]
28052808 reps = reps_
28062809
@@ -2817,17 +2820,17 @@ def tile(x, reps, ndim=None):
28172820 ):
28182821 raise ValueError ("elements of reps must be scalars of integer dtype" )
28192822
2820- # If reps.ndim is less than x .ndim, we pad the reps with
2821- # "1" so that reps will have the same ndim as x
2823+ # If reps.ndim is less than _x .ndim, we pad the reps with
2824+ # "1" so that reps will have the same ndim as _x
28222825 reps = list (reps )
28232826 if ndim is None :
2824- ndim = builtins .max (len (reps ), x .ndim )
2827+ ndim = builtins .max (len (reps ), _x .ndim )
28252828 if len (reps ) < ndim :
28262829 reps = [1 ] * (ndim - len (reps )) + reps
28272830
2828- _shape = [1 ] * (ndim - x .ndim ) + [x .shape [i ] for i in range (x .ndim )]
2831+ _shape = [1 ] * (ndim - _x .ndim ) + [_x .shape [i ] for i in range (_x .ndim )]
28292832 alloc_shape = reps + _shape
2830- y = alloc (x , * alloc_shape )
2833+ y = alloc (_x , * alloc_shape )
28312834 shuffle_ind = np .arange (ndim * 2 ).reshape (2 , ndim )
28322835 shuffle_ind = shuffle_ind .transpose ().flatten ()
28332836 y = y .dimshuffle (* shuffle_ind )
@@ -3288,8 +3291,9 @@ def inverse_permutation(perm):
32883291 Each row of input should contain a permutation of the first integers.
32893292
32903293 """
3294+ _perm = as_tensor_variable (perm )
32913295 return permute_row_elements (
3292- arange (perm .shape [- 1 ], dtype = perm .dtype ), perm , inverse = True
3296+ arange (_perm .shape [- 1 ], dtype = _perm .dtype ), _perm , inverse = True
32933297 )
32943298
32953299
@@ -3575,12 +3579,14 @@ def diag(v, k=0):
35753579
35763580 """
35773581
3578- if v .ndim == 1 :
3579- return AllocDiag (k )(v )
3580- elif v .ndim >= 2 :
3581- return diagonal (v , offset = k )
3582+ _v = as_tensor_variable (v )
3583+
3584+ if _v .ndim == 1 :
3585+ return AllocDiag (k )(_v )
3586+ elif _v .ndim >= 2 :
3587+ return diagonal (_v , offset = k )
35823588 else :
3583- raise ValueError ("Input must has v.ndim >= 1 ." )
3589+ raise ValueError ("Number of dimensions of `v` must be greater than one ." )
35843590
35853591
35863592def stacklists (arg ):
0 commit comments