Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit e40c827

Browse files
authored
Add tensor conversion to flatnonzero, nonzero_values, tile, inverse_permutation, and diag
1 parent c2ed818 commit e40c827

File tree

2 files changed

+63
-21
lines changed

2 files changed

+63
-21
lines changed

aesara/tensor/basic.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -940,9 +940,10 @@ def flatnonzero(a):
940940
nonzero_values : Return the non-zero elements of the input array
941941
942942
"""
943-
if a.ndim == 0:
943+
_a = as_tensor_variable(a)
944+
if _a.ndim == 0:
944945
raise ValueError("Nonzero only supports non-scalar arrays.")
945-
return nonzero(a.flatten(), return_matrix=False)[0]
946+
return nonzero(_a.flatten(), return_matrix=False)[0]
946947

947948

948949
def nonzero_values(a):
@@ -1324,9 +1325,10 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
13241325
tensor
13251326
tensor the shape of x with ones on main diagonal and zeroes elsewhere of type of dtype.
13261327
"""
1328+
_x = as_tensor_variable(x)
13271329
if dtype is None:
1328-
dtype = x.dtype
1329-
return eye(x.shape[0], x.shape[1], k=0, dtype=dtype)
1330+
dtype = _x.dtype
1331+
return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)
13301332

13311333

13321334
def infer_broadcastable(shape):
@@ -2773,8 +2775,9 @@ def tile(x, reps, ndim=None):
27732775
"""
27742776
from aesara.tensor.math import ge
27752777

2776-
if ndim is not None and ndim < x.ndim:
2777-
raise ValueError("ndim should be equal or larger than x.ndim")
2778+
_x = as_tensor_variable(x)
2779+
if ndim is not None and ndim < _x.ndim:
2780+
raise ValueError("ndim should be equal or larger than _x.ndim")
27782781

27792782
# If reps is a scalar, integer or vector, we convert it to a list.
27802783
if not isinstance(reps, (list, tuple)):
@@ -2799,8 +2802,8 @@ def tile(x, reps, ndim=None):
27992802
# assert that reps.shape[0] does not exceed ndim
28002803
offset = assert_op(offset, ge(offset, 0))
28012804

2802-
# if reps.ndim is less than x.ndim, we pad the reps with
2803-
# "1" so that reps will have the same ndim as x.
2805+
# if reps.ndim is less than _x.ndim, we pad the reps with
2806+
# "1" so that reps will have the same ndim as _x.
28042807
reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)]
28052808
reps = reps_
28062809

@@ -2817,17 +2820,17 @@ def tile(x, reps, ndim=None):
28172820
):
28182821
raise ValueError("elements of reps must be scalars of integer dtype")
28192822

2820-
# If reps.ndim is less than x.ndim, we pad the reps with
2821-
# "1" so that reps will have the same ndim as x
2823+
# If reps.ndim is less than _x.ndim, we pad the reps with
2824+
# "1" so that reps will have the same ndim as _x
28222825
reps = list(reps)
28232826
if ndim is None:
2824-
ndim = builtins.max(len(reps), x.ndim)
2827+
ndim = builtins.max(len(reps), _x.ndim)
28252828
if len(reps) < ndim:
28262829
reps = [1] * (ndim - len(reps)) + reps
28272830

2828-
_shape = [1] * (ndim - x.ndim) + [x.shape[i] for i in range(x.ndim)]
2831+
_shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)]
28292832
alloc_shape = reps + _shape
2830-
y = alloc(x, *alloc_shape)
2833+
y = alloc(_x, *alloc_shape)
28312834
shuffle_ind = np.arange(ndim * 2).reshape(2, ndim)
28322835
shuffle_ind = shuffle_ind.transpose().flatten()
28332836
y = y.dimshuffle(*shuffle_ind)
@@ -3288,8 +3291,9 @@ def inverse_permutation(perm):
32883291
Each row of input should contain a permutation of the first integers.
32893292
32903293
"""
3294+
_perm = as_tensor_variable(perm)
32913295
return permute_row_elements(
3292-
arange(perm.shape[-1], dtype=perm.dtype), perm, inverse=True
3296+
arange(_perm.shape[-1], dtype=_perm.dtype), _perm, inverse=True
32933297
)
32943298

32953299

@@ -3575,12 +3579,14 @@ def diag(v, k=0):
35753579
35763580
"""
35773581

3578-
if v.ndim == 1:
3579-
return AllocDiag(k)(v)
3580-
elif v.ndim >= 2:
3581-
return diagonal(v, offset=k)
3582+
_v = as_tensor_variable(v)
3583+
3584+
if _v.ndim == 1:
3585+
return AllocDiag(k)(_v)
3586+
elif _v.ndim >= 2:
3587+
return diagonal(_v, offset=k)
35823588
else:
3583-
raise ValueError("Input must has v.ndim >= 1.")
3589+
raise ValueError("Number of dimensions of `v` must be greater than one.")
35843590

35853591

35863592
def stacklists(arg):

tests/tensor/test_basic.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,12 @@ def check(m):
10231023
rand2d[:4] = 0
10241024
check(rand2d)
10251025

1026+
# Test passing a list
1027+
m = [1, 2, 0]
1028+
out = flatnonzero(m)
1029+
f = function([], out)
1030+
assert np.array_equal(f(), np.flatnonzero(m))
1031+
10261032
@config.change_flags(compute_test_value="raise")
10271033
def test_nonzero_values(self):
10281034
def check(m):
@@ -1449,8 +1455,6 @@ def test_roll(self):
14491455

14501456
assert (out == want).all()
14511457

1452-
# Pass a list to make sure `a` is converted to a
1453-
# TensorVariable by roll
14541458
a = [1, 2, 3, 4, 5, 6]
14551459
b = roll(a, get_shift(2))
14561460
want = np.array([5, 6, 1, 2, 3, 4])
@@ -2221,6 +2225,20 @@ def run_tile(x, x_, reps, use_symbolic_reps):
22212225
== np.tile(x_, (2, 3, 4, 6))
22222226
)
22232227

2228+
# Test passing a float
2229+
x = scalar()
2230+
x_val = 1.0
2231+
assert np.array_equal(
2232+
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
2233+
)
2234+
2235+
# Test when x is a list
2236+
x = matrix()
2237+
x_val = [[1.0, 2.0], [3.0, 4.0]]
2238+
assert np.array_equal(
2239+
run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,))
2240+
)
2241+
22242242
# Test when reps is integer, scalar or vector.
22252243
# Test 1,2,3,4-dimensional cases.
22262244
# Test input x has the shape [2], [2, 4], [2, 4, 3], [2, 4, 3, 5].
@@ -2794,6 +2812,12 @@ def test_dim1(self):
27942812
assert np.all(p_val[inv_val] == np.arange(10))
27952813
assert np.all(inv_val[p_val] == np.arange(10))
27962814

2815+
# Test passing a list
2816+
p = [2, 4, 3, 0, 1]
2817+
inv = at.inverse_permutation(p)
2818+
f = aesara.function([], inv)
2819+
assert np.array_equal(f(), np.array([3, 4, 0, 2, 1]))
2820+
27972821
def test_dim2(self):
27982822
# Test the inversion of several permutations at a time
27992823
# Each row of p is a different permutation to inverse
@@ -3449,6 +3473,12 @@ def test_diag(self):
34493473
with pytest.raises(ValueError):
34503474
diag(xx)
34513475

3476+
# Test passing a list
3477+
xx = [[1, 2], [3, 4]]
3478+
g = diag(xx)
3479+
f = function([], g)
3480+
assert np.array_equal(f(), np.diag(xx))
3481+
34523482
def test_infer_shape(self):
34533483
rng = np.random.default_rng(utt.fetch_seed())
34543484

@@ -4136,6 +4166,12 @@ def test_identity_like_dtype():
41364166
m_out_float = identity_like(m, dtype=np.float64)
41374167
assert m_out_float.dtype == "float64"
41384168

4169+
# Test passing list
4170+
m = [[0, 1], [1, 3]]
4171+
out = at.identity_like(m)
4172+
f = aesara.function([], out)
4173+
assert np.array_equal(f(), np.eye(2))
4174+
41394175

41404176
def test_atleast_Nd():
41414177
ary1 = dscalar()

0 commit comments

Comments
 (0)