Conversation

@aditvenk (Contributor)

A PyTorch DTensor sharding strategy can legitimately populate a DTensorSpec without tensor_meta. In such cases we attempt fake tensor propagation to populate tensor_meta, but for some ops one or more outputs can legitimately be None depending on the inputs (e.g., convolution.backward with certain output_mask values).

When that happens, fake tensor propagation cannot resolve the output tensor_meta, and validation currently throws an error. This PR switches validation to emit a warning instead. If tensor_meta is unknown and that tensor is subsequently an input to a downstream op, we will fail during input_spec validation there.
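For illustration, a minimal sketch of the warn-instead-of-raise behavior (the function and logger names are hypothetical, not the actual code in this PR):

```python
import logging

logger = logging.getLogger(__name__)

def validate_output_spec(op_name: str, out_spec) -> None:
    # Hypothetical validation hook. tensor_meta can legitimately be
    # missing when fake tensor propagation returns None for an output
    # (e.g. convolution.backward with an output_mask that disables some
    # gradients), so warn instead of raising.
    if out_spec is not None and out_spec.tensor_meta is None:
        logger.warning(
            "Output of %s has a DTensorSpec without tensor_meta; if this "
            "output feeds another op, input_spec validation will fail there.",
            op_name,
        )
```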

Testing: Added a convolution test that revealed this issue.

@fmassa (Contributor) left a comment

Thanks for the diff @aditvenk!

So, one thing I'd like to clarify with DTensor folks (@wconstab @zpcore) is the expected outputs of DTensor specs.

My understanding is that there is some logic in DTensor that converts None placements to Replicate() somewhere down the road during DTensor sharding propagation (a shortcut which IMO should just be cleaned up, as we are querying the placements prior to that logic).
One case was in SDPA, which I worked around in:

```python
# This is overcoming some limitations of the lack of
# tensor_meta for sdpa, which returns None.
# We should just fix this all across the board.
if ospec.tensor_meta is None:
    ospec.tensor_meta = tm
else:
    assert tm is None
```
so that we can assume tensor_meta is always present.

IMO we should maintain the invariant that tensor_meta / redistribute_cost / etc. are always present and populated in DTensorSpec, so that we can consistently rely on them.
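As a rough sketch, enforcing that invariant could look like the check below (the helper is an assumption for illustration, not existing PyTorch API; the import path is assumed to match recent PyTorch main):

```python
from torch.distributed.tensor._dtensor_spec import DTensorSpec

def assert_spec_complete(spec: DTensorSpec) -> None:
    # Hypothetical invariant check: every spec produced by sharding
    # propagation should carry tensor_meta, so downstream consumers can
    # rely on shape/stride/dtype without re-running fake tensor prop.
    assert spec.tensor_meta is not None, (
        f"DTensorSpec {spec} is missing tensor_meta; the strategy (or a "
        "fake-prop backfill) should populate it before validation"
    )
```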

Thoughts?

@aditvenk (Contributor, Author)

> IMO we should maintain the invariant that tensor_meta / redistribute_cost / etc. are always present and populated in DTensorSpec, so that we can consistently rely on them.

The specific place where I saw tensor_meta not populated comes from here (there is a TODO there too):
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/_ops/utils.py#L275
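For context, tensor_meta is an optional field on DTensorSpec, so strategy code can build output specs without it and rely on a later backfill. Below is a minimal sketch of what backfilling from fake tensor propagation might look like; the helper itself is hypothetical, while FakeTensorMode and TensorMeta are real PyTorch internals (import paths assumed from recent main):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed.tensor._dtensor_spec import TensorMeta

def metas_from_fake_prop(op, *args, **kwargs):
    # Hypothetical helper: run the op on fake tensors and build one
    # TensorMeta per output. None outputs (e.g. gradients disabled via
    # output_mask in convolution.backward) yield None metas, which is
    # exactly the case where a DTensorSpec legitimately lacks tensor_meta.
    with FakeTensorMode():
        fake_args = [
            torch.empty(a.shape, dtype=a.dtype)
            if isinstance(a, torch.Tensor) else a
            for a in args
        ]
        outs = op(*fake_args, **kwargs)
    if not isinstance(outs, (tuple, list)):
        outs = (outs,)
    return [
        TensorMeta(shape=o.shape, stride=o.stride(), dtype=o.dtype)
        if o is not None else None
        for o in outs
    ]
```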

@fmassa (Contributor) commented Dec 1, 2025

@wconstab wdyt? Should we just enforce that we always return tensor_meta?

@wconstab (Contributor) commented Dec 1, 2025

> Should we just enforce that we always return tensor_meta?

I like the idea of enforcing this.

Annoyingly, I realized that my single-dim rules currently rely on not populating the output tensor_meta, because it's awkward to do shape inference inside the rule-expansion infra, yet I'm forced to create a new output spec. I still think we should figure out how to make this change, even if it requires a bit more of a refactor.
