diff --git a/pytensor/scan/views.py b/pytensor/scan/views.py index b86476b330..7d9365bb47 100644 --- a/pytensor/scan/views.py +++ b/pytensor/scan/views.py @@ -170,3 +170,63 @@ def foldr(fn, sequences, outputs_info, non_sequences=None, mode=None, name=None) mode=mode, name=name, ) + + +def filter( + fn, + sequences, + non_sequences=None, + go_backwards=False, + mode=None, + name=None, +): + """Construct a `Scan` `Op` that functions like `filter`. + + Parameters + ---------- + fn : callable + Predicate function returning a boolean tensor. + sequences : list + Sequences to filter. + non_sequences : list + Non-iterated arguments passed to `fn`. + go_backwards : bool + Whether to iterate in reverse. + mode : str or None + See ``scan``. + name : str or None + See ``scan``. + + Notes + ----- + If the predicate function `fn` returns multiple boolean masks (one per sequence), + each mask will be applied to its corresponding sequence. If it returns a single mask, + that mask will be broadcast to all sequences. + """ + mask, _ = scan( + fn=fn, + sequences=sequences, + outputs_info=None, + non_sequences=non_sequences, + go_backwards=go_backwards, + mode=mode, + name=name, + ) + + if isinstance(mask, (list, tuple)): + # One mask per sequence + if not isinstance(sequences, (list, tuple)): + raise TypeError( + "If multiple masks are returned, sequences must be a list or tuple." + ) + if len(mask) != len(sequences): + raise ValueError("Number of masks must match number of sequences.") + filtered_sequences = [seq[m] for seq, m in zip(sequences, mask)] + else: + # Single mask applied to all sequences + if isinstance(sequences, (list, tuple)): + filtered_sequences = [seq[mask] for seq in sequences] + else: + filtered_sequences = sequences[mask] + + return filtered_sequences diff --git a/tests/scan/test_views.py b/tests/scan/test_views.py index 38c9b9cfcd..3002f9cd3a 100644 --- a/tests/scan/test_views.py +++ b/tests/scan/test_views.py @@ -3,6 +3,7 @@ import pytensor.tensor as pt from pytensor import config, function, grad, shared from pytensor.compile.mode import FAST_RUN +from pytensor.scan.views import filter as pt_filter from pytensor.scan.views import foldl, foldr from pytensor.scan.views import map as pt_map from pytensor.scan.views import reduce as pt_reduce @@ -133,3 +134,42 @@ def test_foldr_memory_consumption(): gx = grad(o, x) f2 = function([], gx) utt.assert_allclose(f2(), np.ones((10,))) + + +def test_filter(): + v = pt.vector("v") + + def fn(x): + return pt.eq(x % 2, 0) + + filtered = pt_filter(fn, v) + f = function([v], filtered, allow_input_downcast=True) + + rng = np.random.default_rng(utt.fetch_seed()) + vals = rng.integers(0, 10, size=(10,)) + expected = vals[vals % 2 == 0] + result = f(vals) + utt.assert_allclose(expected, result) + + +def test_filter_multiple_masks(): + v1 = pt.vector("v1") + v2 = pt.vector("v2") + + def fn(x1, x2): + # Mask v1 for even numbers, mask v2 for numbers > 5 + return pt.eq(x1 % 2, 0), pt.gt(x2, 5) + + filtered_v1, filtered_v2 = pt_filter(fn, [v1, v2]) + f = function([v1, v2], [filtered_v1, filtered_v2], allow_input_downcast=True) + + vals1 = np.arange(10) + vals2 = np.arange(10) + + expected_v1 = vals1[vals1 % 2 == 0] + expected_v2 = vals2[vals2 > 5] + + result_v1, result_v2 = f(vals1, vals2) + + utt.assert_allclose(expected_v1, result_v1) + utt.assert_allclose(expected_v2, result_v2)