Skip to content

Commit 73228d1

Browse files
committed
ENH: test searchsorted with x2.ndim > 1
1 parent 978db05 commit 73228d1

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

array_api_tests/test_searching_functions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,12 +255,18 @@ def test_searchsorted(data):
255255
sorter = None
256256
x1 = xp.sort(x1)
257257
note(f"{x1=}")
258+
258259
x2 = data.draw(
259260
st.lists(st.sampled_from(_x1), unique=True, min_size=1).map(
260261
lambda o: xp.asarray(o, dtype=dh.default_float)
261262
),
262263
label="x2",
263264
)
265+
# make x2.ndim > 1, if it makes sense
266+
factors = hh._factorize(x2.shape[0])
267+
if len(factors) > 1:
268+
x2 = xp.reshape(x2, tuple(factors))
269+
264270
kw = data.draw(hh.kwargs(side=st.sampled_from(["left", "right"])))
265271

266272
repro_snippet = ph.format_snippet(f"xp.searchsorted({x1!r}, {x2!r}, sorter={sorter!r})")
@@ -273,7 +279,7 @@ def test_searchsorted(data):
273279
out_dtype=out.dtype,
274280
expected=xp.__array_namespace_info__().default_dtypes()["indexing"],
275281
)
276-
# TODO: x2.ndim > 1, values testing
282+
# TODO: values testing
277283
ph.assert_shape("searchsorted", out_shape=out.shape, expected=x2.shape)
278284
except Exception as exc:
279285
exc.add_note(repro_snippet)

0 commit comments

Comments
 (0)