Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion pytensor/tensor/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.extra_ops import broadcast_arrays, repeat
from pytensor.tensor.math import Sum, add, eq, variadic_add
from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.type import DenseTensorType, TensorType
Expand Down Expand Up @@ -910,6 +910,42 @@ def local_join_make_vector(fgraph, node):
return [ret]


@register_specialize
@register_canonicalize
@node_rewriter([Join])
def local_join_to_repeat(fgraph, node):
    """Rewrite ``Join(axis, x, x, ..., x)`` as ``repeat(x, n, axis)``.

    The two operations only agree when ``x`` has static length 1 along the
    join axis: for a longer axis, ``join`` concatenates whole blocks
    (``join(0, x, x) -> [x0, x1, x0, x1]``) whereas ``repeat`` interleaves
    elements (``repeat(x, 2, 0) -> [x0, x0, x1, x1]``). The rewrite is
    therefore restricted to the length-1 case, where both produce ``n``
    copies of the single element along that axis.

    Examples
    --------
    ``concatenate([row, row, row], axis=0) -> repeat(row, 3, axis=0)``
    where ``row`` has static shape ``(1, n)``.
    """
    # Local import: this helper may not be in the module's top-level
    # imports, and the rewrite is the only user here.
    from pytensor.tensor.basic import get_scalar_constant_value

    axis, *tensors = node.inputs

    # Nothing to collapse with fewer than two joined inputs.
    if len(tensors) < 2:
        return None

    # All joined inputs must be the very same graph variable; `is` states
    # that intent unambiguously.
    x = tensors[0]
    if not all(t is x for t in tensors[1:]):
        return None

    # `repeat` needs a concrete integer axis, and we must inspect the
    # static shape along it — so the join axis has to be a compile-time
    # constant. Bail out otherwise.
    try:
        static_axis = int(get_scalar_constant_value(axis))
    except NotScalarConstantError:
        return None

    # Join accepts negative axes; normalize so the static-shape lookup and
    # the repeat axis are canonical.
    static_axis = static_axis % x.type.ndim

    # Only valid when the joined dimension has static length 1 (see
    # docstring); unknown (None) static lengths are conservatively skipped.
    if x.type.shape[static_axis] != 1:
        return None

    result = repeat(x, len(tensors), static_axis)

    # Preserve debugging information from the replaced output.
    copy_stack_trace(node.outputs[0], result)

    return [result]


@register_specialize
@register_canonicalize
@register_useless
Expand Down
81 changes: 81 additions & 0 deletions tests/tensor/rewriting/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
tile,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.extra_ops import Repeat
from pytensor.tensor.math import (
add,
bitwise_and,
Expand Down Expand Up @@ -1247,6 +1248,86 @@ def test_local_join_1():
assert f.maker.fgraph.outputs[0].dtype == config.floatX


def test_local_join_to_repeat():
    """Join of identical tensors along a length-1 axis becomes a Repeat.

    The rewrite only applies when the joined dimension has static length 1:
    ``join`` stacks whole blocks while ``repeat`` interleaves elements, and
    the two coincide exactly when that dimension holds a single element.
    """
    # Row matrix: axis 0 has static length 1, so the rewrite applies.
    a = matrix("a", shape=(1, None))
    s = join(0, a, a, a)
    f = function([a], s, mode=rewrite_mode)

    test_mat = np.array([[1.0, 2.0, 3.0]], dtype=config.floatX)
    result = f(test_mat)
    expected = np.vstack([test_mat] * 3)
    assert np.allclose(result, expected)

    ops = f.maker.fgraph.toposort()
    assert not any(isinstance(n.op, Join) for n in ops)
    assert sum(isinstance(n.op, Repeat) for n in ops) == 1

    # Column matrix joined along axis 1 (also static length 1): applies too.
    b = matrix("b", shape=(None, 1))
    s = join(1, b, b, b, b)
    f = function([b], s, mode=rewrite_mode)

    test_col = np.array([[1.0], [2.0], [3.0]], dtype=config.floatX)
    result = f(test_col)
    expected = np.hstack([test_col] * 4)
    assert np.allclose(result, expected)

    ops = f.maker.fgraph.toposort()
    assert not any(isinstance(n.op, Join) for n in ops)
    assert sum(isinstance(n.op, Repeat) for n in ops) == 1

    # Joined axis longer than 1: repeat would reorder elements, so the
    # rewrite must NOT fire and the result must match block concatenation.
    x = vector("x")
    s = join(0, x, x, x)
    f = function([x], s, mode=rewrite_mode)

    test_vec = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
    assert np.allclose(f(test_vec), np.tile(test_vec, 3))
    ops = f.maker.fgraph.toposort()
    assert not any(isinstance(n.op, Repeat) for n in ops)

    # Different tensors: rewrite must not apply even on length-1 axes.
    c = matrix("c", shape=(1, None))
    d = matrix("d", shape=(1, None))
    s = join(0, c, d)
    f = function([c, d], s, mode=rewrite_mode)

    m1 = np.array([[1.0, 2.0]], dtype=config.floatX)
    m2 = np.array([[3.0, 4.0]], dtype=config.floatX)
    assert np.allclose(f(m1, m2), np.vstack([m1, m2]))
    ops = f.maker.fgraph.toposort()
    assert not any(isinstance(n.op, Repeat) for n in ops)


def test_local_join_empty():
# Vector case
empty_vec = np.asarray([], dtype=config.floatX)
Expand Down
Loading