Skip to content

Conversation

@tchan102
Copy link

@tchan102 tchan102 commented Nov 4, 2025

Add optimization for Join → Repeat when concatenating identical tensors

Description

This PR introduces a graph rewrite optimization in pytensor/tensor/rewriting/basic.py that replaces redundant Join operations with an equivalent and more efficient Repeat operation when all concatenated tensors are identical.

Example:
join(0, x, x, x) → repeat(x, 3, axis=0)

Key additions:

  • Implemented new rewrite function local_join_to_repeat registered under both @register_canonicalize and @register_specialize.
  • Added corresponding test test_local_join_to_repeat to verify correctness, performance, and behavior for vectors and matrices.

Related Issue

Checklist

Type of change

  • [ x] New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@ricardoV94 ricardoV94 added graph rewriting enhancement New feature or request labels Nov 4, 2025
@ricardoV94
Copy link
Member

Let's try with @register_canonicalize only

return

# Replace with repeat operation
result = repeat(tensors[0], len(tensors), axis)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one thing you'll need to handle is that join axis is symbolic but repeat must be a python integer. You should do something like:

if not isinstance(axis, Constant):
    return None  # rewrite only applies to constant axis
axis = axis.data

# Check optimization applied
ops = f.maker.fgraph.toposort()
assert len([n for n in ops if isinstance(n.op, Join)]) == 0
assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The RepeatOp will not actually be used. It's only used when there are vector repeats. For scalar Repeats we end up with Alloc instead (PyTensor version of BroadcastTo). Did you try to run the tests locally?

@ricardoV94
Copy link
Member

Btw would be nice to get rid of this join (and split) symbolic axis if you would like to work on that after this PR. relevant issue: #1528

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request graph rewriting

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Rewrite concatenate([x, x]) as repeat(x, 2)

2 participants