-
Notifications
You must be signed in to change notification settings - Fork 146
Rewrite concatenate([x, x]) as repeat(x, 2) #1714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Let's try with |
| return | ||
|
|
||
| # Replace with repeat operation | ||
| result = repeat(tensors[0], len(tensors), axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one thing you'll need to handle is that join axis is symbolic but repeat must be a python integer. You should do something like:
if not isinstance(axis, Constant):
return None # rewrite only applies to constant axis
axis = axis.data| # Check optimization applied | ||
| ops = f.maker.fgraph.toposort() | ||
| assert len([n for n in ops if isinstance(n.op, Join)]) == 0 | ||
| assert len([n for n in ops if isinstance(n.op, Repeat)]) == 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The RepeatOp will not actually be used. It's only used when there are vector repeats. For scalar Repeats we end up with Alloc instead (PyTensor version of BroadcastTo). Did you try to run the tests locally?
|
Btw would be nice to get rid of this join (and split) symbolic axis if you would like to work on that after this PR. relevant issue: #1528 |
Add optimization for Join → Repeat when concatenating identical tensors
Description
This PR introduces a graph rewrite optimization in pytensor/tensor/rewriting/basic.py that replaces redundant Join operations with an equivalent and more efficient Repeat operation when all concatenated tensors are identical.
Example:
join(0, x, x, x) → repeat(x, 3, axis=0)
Key additions:
Related Issue
Checklist
Type of change