-
Notifications
You must be signed in to change notification settings - Fork 108
[Feature] Add TensorClassModuleBase for type-safe TensorClass modules #1473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…ensorDictModuleBase
|
Hi @az0uz! Thank you for your pull request and welcome to our community. Action RequiredIn order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this!
Overall great feature!
I left some comments here and there but they're pretty minor.
Can you sign the CLA? It's required to run the tests and merge
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
new classes need to be added to the doc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do I need to write anything else than adding them to docs/source/reference/nn.rst like above?
I've added a section with description and examples. let me know if that works for you.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's really cool thx
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this!
Overall great feature!
I left some comments here and there but they're pretty minor.
Can you sign the CLA? It's required to run the tests and merge
|
The person responsible for the existing cla in my company has asked cla@ to add me, but haven't got any answer yet, I'll ping them again. Do I need to also sign it personally? |
No the company one should be fine. If we need to ping anyone on the Meta side LMK! |
Yes that would be great, no answer for a week, we've tried to ping cla@meta.com again yesterday but still no news. |
|
GHA seems down -- we'll need to wait till tests run for this |
vmoens
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking good!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's really cool thx
Description
This PR introduces
TensorClassModuleBase, a new base class that provides a type-safe way to define PyTorch modules that operate onTensorClassinstances. This implementation offers compile-time type checking and seamless integration with TensorDict-based workflows.Key Features:
TensorClassModuleBase[InputClass, OutputClass]with automatic type extractionTensorClassModuleWrapperenables conversion toTensorDictModuleviaas_td_module()methodImplementation Details:
tensordict/nn/tensorclass_module.pywithTensorClassModuleBaseandTensorClassModuleWrapper_tensor_class_keys()__init_subclass__for clean APItest/test_tensorclass_module.pycovering forward pass, TensorDict conversion, ONNX export, and edge casesAdditional Changes:
tensordict/__init__.pyandtensordict/nn/__init__.pyMotivation and Context
implements #1355
This change addresses the need for type-safe module definitions when working with TensorClass instances. Previously, developers had to work directly with TensorDict or manually handle type conversions, which was error-prone and lacked compile-time type checking.
With
TensorClassModuleBase, developers can:Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an
xin all the boxes that apply.If you are unsure about any of these, don't hesitate to ask. We are here to help!