-
Notifications
You must be signed in to change notification settings - Fork 9
Don't crash if tensor_meta is not available for output spec #268
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Don't crash if tensor_meta is not available for output spec #268
Conversation
PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta. In such cases, we attempt to do fake tensor propagation to populate tensor_meta, but for some ops, one or more outputs can legtimiately be None depending on inputs (e.g., convolution.backward with certain output_mask). Switch validation to throw a warning in such case. If tensor_meta is legtimately not known, and the output of an op is subsequently an input to a downstream op, we will fail during the input_spec validation. Testing: Adding convolution test that revealed this issue. <!-- ps-id: 9c2841a8-6f6b-44c4-a07a-16fb9b32ac70 -->
fmassa
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the diff @aditvenk !
So, one thing I'd like to clarify with DTensor folks (@wconstab @zpcore ) is around the expected outputs of DTensor specs.
My understanding is that there is some logic in DTensor that converts None placements as Replicate() somewhere down the road during DTensor sharding propagation (which is a shortcut which IMO should just be cleaned up, as we are querying the placements prior to this logic).
One case was in SDPA, which I worked around in
autoparallel/autoparallel/utils.py
Lines 84 to 90 in e794cc2
| # This is overcoming some limitations of the lack of | |
| # tensor_meta for sdpa which returns None | |
| # we should just fix this all across the board | |
| if ospec.tensor_meta is None: | |
| ospec.tensor_meta = tm | |
| else: | |
| assert tm is None |
tensor_meta is always present.
IMO we should always have the invariant that tensor_meta / redistribute_cost / etc are always present and populated in DTensorSpec, so that we can consistently rely on them.
Thoughts?
The specific place where I saw the tensor meta as not populated came from here ( there is a TODO here too) |
|
@wconstab wdty? Should we just enforce that we always return tensor_meta ? |
I like the idea of enforcing this. Annoyingly, I realized that in my single-dim rules currently, I am relying on not populating the output tensormeta because its weird to have to do shape inference inside the rule expansion infra, but i'm forced to create a new output spec. I still think we should figure out how to make this change, even if it requires a bit more of a refactor. |
PyTorch DTensor strategy can legitimately populate DTensorSpec without tensor_meta. In such cases, we attempt to do fake tensor propagation to populate tensor_meta, but for some ops, one or more outputs can legitimately be None depending on inputs (e.g., convolution.backward with certain output_mask).
In such cases, fake tensor prop cannot resolve the output tensor_meta, and we currently throw an error in validation.
Switch validation to instead emit a warning in such case. If tensor_meta is unknown, and that tensor is subsequently an input to a downstream op, we will fail during the input_spec validation.
Testing: Adding convolution test that revealed this issue.