Releases: mctigger/tensorcontainer
0.8.2: Fixes issue with __setitem__ and non-tuple indices
0.8.1: Enhanced PyTree Error Diagnostics and Refactoring
This update introduces significant enhancements to error reporting for PyTree operations (e.g., torch.stack, torch.cat) on TensorContainer subclasses such as TensorDict and TensorDataClass. The primary objective is to provide developers with clear, actionable diagnostics when structural mismatches occur, thereby improving the debugging workflow. This is accompanied by a refactoring of the underlying PyTree context implementation for improved modularity and extensibility.
Key Changes
1. Detailed Mismatch Diagnosis
A new utility, diagnose_pytree_structure_mismatch, has been implemented to provide detailed diagnostics for structural differences. This replaces generic RuntimeError or ValueError exceptions with specific, informative messages that identify the root cause of an issue. The diagnostics cover:
- Key/Field Mismatches: Differences in keys for TensorDict or fields for TensorDataClass.
- Type Mismatches: Incompatible container types (e.g., TensorDict vs. list).
- Device Mismatches: Containers located on different devices (e.g., cpu vs. cuda).
- Nesting Mismatches: Discrepancies in the nested structure of containers.
2. Example Error Message
When attempting to stack TensorContainer instances with incompatible keys, the new error message provides a precise description of the problem:
Structure mismatch at container: containers have incompatible layouts.
Container 0 at container: TensorDict(keys=['a', 'b'], device=cpu)
Container 1 at container: TensorDict(keys=['a', 'c'], device=cpu)
Fix: Key mismatch detected. Missing keys in container 1: ['b']. Extra keys in container 1: ['c'].
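A minimal sketch of how such a mismatch can be triggered and inspected is shown below; the TensorDict import path and constructor keywords (shape) are assumptions and may differ from the actual API.

```python
import torch
from tensorcontainer import TensorDict  # import path assumed

# Two batch-compatible containers whose keys differ ('b' vs. 'c').
td0 = TensorDict({"a": torch.zeros(4, 3), "b": torch.zeros(4, 1)}, shape=(4,))
td1 = TensorDict({"a": torch.zeros(4, 3), "c": torch.zeros(4, 1)}, shape=(4,))

try:
    torch.stack([td0, td1])  # PyTree flattening compares the two structures
except Exception as err:
    # Expected to surface the "Key mismatch detected" diagnosis shown above.
    print(err)
```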
3. Internal Refactoring
To support these improvements, the following architectural changes were made:
- Structured Exceptions: A set of specific exception classes (TypeMismatch, ContextMismatch, KeyPathMismatch) was introduced to represent distinct error conditions.
- PyTree Context Abstraction: A new ContextWithAnalysis abstract base class standardizes structural analysis, and TensorContainerPytreeContext centralizes common logic to reduce code duplication.
- Dataclass Conversion: PyTree context classes were converted from NamedTuple to dataclass to improve readability and extensibility (see the sketch after this list).
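The following is an illustrative sketch of the NamedTuple-to-dataclass conversion pattern described above; the class and field names are hypothetical and do not correspond to the library's actual context classes.

```python
from dataclasses import dataclass
from typing import NamedTuple, Tuple

# Before: a NamedTuple-style PyTree context is positional and hard to extend.
class LegacyContext(NamedTuple):
    keys: Tuple[str, ...]
    device: str

# After: a frozen dataclass keeps the value semantics but allows defaults,
# extra fields, and helper methods without breaking existing call sites.
@dataclass(frozen=True)
class DataclassContext:
    keys: Tuple[str, ...]
    device: str

    def describe(self) -> str:
        return f"keys={list(self.keys)}, device={self.device}"
```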
What's Changed
- Improved PyTree Error Messages and Refactoring by @mctigger in #21
- fix(build): bump project version to 0.8.1 by @mctigger in #22
Full Changelog: 0.8.0...0.8.1
0.8.0: Better documentation, error handling, and torch.cat fixes
What's Changed
- Add documentation for the package by @mctigger in #15
- Documentation by @mctigger in #18
- fix(container): improve 'too many indices' error handling by @mctigger in #19
- refactor(container): centralize pytree mapping and fixes operations o… by @mctigger in #20
Full Changelog: 0.7.1...0.8.0
Add TensorSymLog
Adds TensorSymLog based on SymLogDistribution
Validation streamlined, unified types, added documentation
What's Changed
- Tensor dict types by @mctigger in #13
- Adds more sophisticated tests for TensorContainer and subclassing
- Adds documentation for development rationale
- Removes validate_args from TensorContainer; validation is now controlled via a context manager instead.
Full Changelog: 0.6.3...0.7
Unified TensorDistribution
Overview
TensorContainer 0.6.3 introduces significant improvements to the tensor distribution module, featuring unified implementations, enhanced performance, and better developer experience. This release focuses on simplifying distribution classes while maintaining full compatibility with existing code.
What's New
Major Improvements
Unified TensorDistribution Implementation
- Comprehensive Refactoring: All tensor distribution classes have been unified under a consistent implementation pattern
- Code Reduction: Removed over 1000 lines of redundant code while adding new functionality
- Better Inheritance: Enhanced class hierarchy to properly leverage parent class functionality
- Consistent APIs: All distributions now follow the same patterns for parameter handling and validation
Enhanced Parameter Handling
- Unified Broadcasting: Implemented consistent parameter broadcasting across all distributions using broadcast_all (see the sketch after this list)
- Improved Validation: Better error messages and validation logic for distribution parameters
- Scalar Parameter Support: Enhanced handling of scalar parameters in distribution initialization
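broadcast_all is the torch.distributions utility named in the list above; the sketch below only demonstrates its broadcasting behavior, not the library's internal use of it.

```python
import torch
from torch.distributions.utils import broadcast_all

# Python scalars are promoted to tensors and all arguments are broadcast
# to a common shape, so mixed scalar/tensor parameters line up.
loc, scale = broadcast_all(torch.zeros(3, 1), 2.0)
print(loc.shape, scale.shape)  # torch.Size([3, 1]) torch.Size([3, 1])

loc, scale = broadcast_all(torch.zeros(3, 1), torch.ones(4))
print(loc.shape, scale.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```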
New Features
- unflatten Distribution Method: New functionality in the base distribution class for handling flattened parameters
- Enhanced validate_args Support: Improved argument validation across all distribution classes
- TensorOneHotCategoricalStraightThrough: New distribution implementation for straight-through gradient estimation (see the example after this list)
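As a reference point for the straight-through estimator, the sketch below uses torch.distributions.OneHotCategoricalStraightThrough; whether TensorOneHotCategoricalStraightThrough mirrors this interface exactly is an assumption.

```python
import torch
from torch.distributions import OneHotCategoricalStraightThrough

logits = torch.randn(4, 6, requires_grad=True)
dist = OneHotCategoricalStraightThrough(logits=logits)

# rsample() returns discrete one-hot samples in the forward pass while
# routing gradients through the underlying probabilities.
sample = dist.rsample()
sample.sum().backward()
print(sample.shape, logits.grad is not None)  # torch.Size([4, 6]) True
```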
Performance & Quality Improvements
Code Quality
- Modern Type Hints: Updated to use Python 3.10+ union syntax (| instead of Union)
- Better Error Handling: Replaced generic RuntimeError with more specific ValueError for parameter validation
- Consistent Code Style: Unified formatting and style across all distribution implementations
Testing Enhancements
- Expanded Test Coverage: Added comprehensive tests for pytree integration and parameter validation
- Better Error Testing: Improved test cases for error handling and edge cases
- Parameter Factory Fixtures: Added reusable test fixtures for distribution parameter testing
Technical Details
Changed Components
- 55 files modified across the tensor distribution module
- All 30+ distribution classes refactored for consistency:
- Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2
- ContinuousBernoulli, Dirichlet, Exponential, FisherSnedecor
- Gamma, Geometric, Gumbel, HalfCauchy, HalfNormal
- InverseGamma, Kumaraswamy, Laplace, LogisticNormal
- Multinomial, MultivariateNormal, NegativeBinomial, Normal
- OneHotCategorical, Pareto, Poisson, RelaxedBernoulli
- RelaxedOneHotCategorical, StudentT, TanhNormal
- TruncatedNormal, Uniform, VonMises, Weibull, Wishart
Breaking Changes
While most user code should continue to work without changes, there are some minor breaking changes:
- Some RuntimeError instances replaced with ValueError for better parameter validation
- Simplified initialization signatures in some distribution classes
Migration Guide
Most existing code will continue to work without modification. For cases where changes are needed:
- Update error handling code to catch ValueError instead of RuntimeError for parameter validation (see the example after this list)
- Review distribution initialization if using advanced parameter configurations
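A sketch of the first migration step, shown with a plain torch.distributions class for clarity; the same catch pattern applies to the refactored TensorDistribution classes.

```python
import torch
from torch.distributions import Normal

# Catching both exception types keeps the handler working across versions
# that raise RuntimeError and versions that raise ValueError.
try:
    Normal(torch.zeros(3), torch.full((3,), -1.0), validate_args=True)  # invalid: negative scale
except (ValueError, RuntimeError) as err:
    print(f"invalid distribution parameters rejected: {err}")
```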
Compatibility
- Python: 3.9+ (dropped Python 3.8 support)
- PyTorch: 2.0+
Fix: TensorTanhNormal and TensorIndependent
Key Fixes and Improvements
- TensorIndependent Shape Bug: A bug where reinterpreted_batch_ndims=0 resulted in an incorrect shape has been fixed. The shape calculation now correctly returns the full shape of the base distribution in this case.
- TensorTanhNormal Simplification: The reinterpreted_batch_ndims parameter has been removed from TensorTanhNormal to enforce a clearer separation of concerns. To achieve this functionality, users should now wrap TensorTanhNormal with TensorIndependent. The class also now includes new statistical properties like mean, variance, and standard deviation.
- SamplingDistribution Enhancements: This class has been rewritten for better performance and reliability. It now uses __slots__ for memory efficiency, includes better caching for statistical properties, improves error handling, and adds input validation.
Impact and Migration
This update introduces a breaking change: the reinterpreted_batch_ndims parameter is no longer available in TensorTanhNormal.
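A migration sketch for this change; the import path and the positional constructor arguments of TensorTanhNormal are assumptions, while the wrapping pattern follows the note above.

```python
import torch
from tensorcontainer.tensor_distribution import TensorTanhNormal, TensorIndependent  # import path assumed

loc, scale = torch.zeros(8, 3), torch.ones(8, 3)

# Before: dist = TensorTanhNormal(loc, scale, reinterpreted_batch_ndims=1)
# After: wrap the base distribution with TensorIndependent instead.
base = TensorTanhNormal(loc, scale)  # constructor arguments assumed
dist = TensorIndependent(base, reinterpreted_batch_ndims=1)

log_prob = dist.log_prob(dist.sample())  # event dimension of size 3 is summed out
print(log_prob.shape)  # expected: torch.Size([8])
```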
Fix: Device propagation
- Fixes missing device propagation to child tensor containers
v0.6.0: torch.distributions Parity with TensorAnnotated Architecture
This release represents a massive expansion of TensorContainer with complete torch.distributions parity, introducing TensorAnnotated as a new base class, dramatically expanding the distribution ecosystem, and adding comprehensive documentation.
Highlights
- Complete torch.distributions Parity: All 43+ distributions from torch.distributions now have TensorContainer equivalents with full API compatibility
- TensorAnnotated Architecture: New foundational base class that powers both TensorDataClass and TensorDistribution with unified annotation-based tensor handling
- Massive Distribution Ecosystem: Added 35+ new distribution types including complex distributions like Wishart, LKJ Cholesky, and MixtureSameFamily
- Enhanced Documentation: Comprehensive documentation system with detailed guides and practical examples
- Improved torch.compile Support: Better integration with PyTorch's compilation system and reduced graph breaks
New Features
TensorAnnotated Base Class
- New foundational class that unifies tensor annotation handling across TensorDataClass and TensorDistribution
- Automatic detection and transformation of annotated tensor attributes
- Improved inheritance patterns and subclassing support
- Enhanced PyTree registration and serialization capabilities
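A minimal sketch of the annotation-based tensor handling described above; the import path and the constructor keywords (shape, device) are assumptions and may differ from the actual API.

```python
import torch
from tensorcontainer import TensorDataClass  # import path assumed

# Annotated tensor attributes are detected automatically and participate in
# batched indexing, PyTree flattening, and device/shape propagation.
class Transition(TensorDataClass):
    observation: torch.Tensor
    action: torch.Tensor
    reward: torch.Tensor

batch = Transition(
    observation=torch.zeros(32, 4),
    action=torch.zeros(32, 2),
    reward=torch.zeros(32, 1),
    shape=(32,),   # constructor keywords assumed
    device="cpu",
)
first = batch[0]  # indexing is applied to every annotated tensor field
```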
Complete Distribution Coverage
Added support for all remaining torch.distributions types:
- Statistical Distributions: Beta, Gamma, Chi2, StudentT, FisherSnedecor, InverseGamma, Pareto, Weibull
- Multivariate Distributions: Wishart, LKJCholesky, LowRankMultivariateNormal, LogisticNormal
- Discrete Distributions: Geometric, NegativeBinomial, Multinomial, Poisson
- Specialized Distributions: MixtureSameFamily, Independent, TransformedDistribution
- Advanced Distributions: RelaxedBernoulli, RelaxedOneHotCategorical, VonMises, Kumaraswamy
Enhanced Distribution Features
- KL divergence support with automatic registration
- Advanced tensor operations (stacking, copying, view operations)
- Comprehensive compile compatibility for all distributions
- Improved parameter validation and error handling
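"Automatic registration" for KL divergence refers to torch's KL registry; the sketch below shows the dispatch mechanism with plain torch.distributions, on the assumption that the TensorDistribution equivalents register themselves the same way.

```python
import torch
from torch.distributions import Normal, kl_divergence

p = Normal(torch.zeros(3), torch.ones(3))
q = Normal(torch.ones(3), torch.ones(3))

# kl_divergence dispatches on the registered (type(p), type(q)) pair.
print(kl_divergence(p, q))  # tensor([0.5000, 0.5000, 0.5000])
```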
Documentation System
- Complete API Documentation: Individual guides for TensorAnnotated, TensorDataClass, and TensorDistribution
- Practical Examples: 6 comprehensive examples demonstrating flexibility and chaining operations
- Testing Guide: Detailed testing philosophy and standards documentation
- Compatibility Guide: Python version compatibility and best practices
Improvements
torch.compile Enhancements
- Better graph break detection and prevention
- Improved recompilation tracking
- Enhanced compatibility testing across all distribution types
- Reduced compilation overhead for tensor operations
Performance Optimizations
- More efficient parameter handling in distributions
- Optimized annotation processing in TensorAnnotated
- Better memory management for large tensor operations
- Improved tensor transformation pipelines
Developer Experience
- Enhanced CI/CD pipeline with nightly builds and comprehensive testing
- Improved error messages and debugging support
- Better IDE integration with enhanced type annotations
- Comprehensive test coverage across all new features
Bug Fixes
- Fixed Python 3.9 compatibility issues with annotation handling
- Resolved subclassing behavior inconsistencies
- Improved device consistency across tensor operations
- Fixed memory leaks in TensorDataClass operations
- Enhanced shape validation and error reporting
Testing & Quality
- Expanded Test Suite: Grew from ~643 tests to 900+ tests with 42 new distribution test files
- Enhanced Coverage: Comprehensive testing for all distribution types and tensor operations
- Compile Testing: Extensive torch.compile compatibility verification
- Cross-Platform Testing: Support for Python 3.9-3.12 across multiple platforms
Breaking Changes
- TensorDistribution now inherits from TensorAnnotated instead of directly from TensorContainer
- Some internal APIs have changed to support the new annotation system
- Distribution parameter validation now defers to torch.distributions for consistency
Documentation
- Added comprehensive documentation in the docs/ directory
- Six practical examples demonstrating real-world usage patterns
- Detailed compatibility and testing guides
- Enhanced README with expanded use cases and examples
Initial release with Python 3.9+ support
Release Notes
This release introduces major new features, including TensorDataClass and TensorDistribution, as well as significant improvements to torch.compile compatibility and a more comprehensive test suite.
Highlights
- TensorDataClass: A new, type-safe, dataclass-based container for tensors that provides a strongly-typed alternative to TensorDict. It automatically converts annotated class definitions into dataclasses with optimized settings for tensor operations.
- TensorDistribution: A new set of classes for representing and manipulating probability distributions of tensors, including TensorNormal, TensorBernoulli, TensorCategorical, TensorTruncatedNormal, and TensorTanhNormal.
- Enhanced torch.compile Compatibility: This release includes significant improvements to torch.compile compatibility, with a focus on reducing graph breaks and recompilations. The test suite now includes extensive tests for torch.compile compatibility, including graph break detection and recompilation tracking (see the sketch after this list).
- Comprehensive Test Suite: The test suite has been expanded to over 643 tests, with 86% code coverage, ensuring the reliability and stability of the library.
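An illustrative sketch of the kind of code the torch.compile work targets, as referenced in the list above; the TensorDict import path and constructor keywords are assumptions, and graph-break behavior may vary by PyTorch version.

```python
import torch
from tensorcontainer import TensorDict  # import path assumed

@torch.compile
def discounted(td):
    # Key access and element-wise math on the container should trace cleanly.
    return td["reward"] * 0.99 + td["value"]

td = TensorDict(
    {"reward": torch.randn(64, 1), "value": torch.randn(64, 1)},
    shape=(64,),  # constructor keyword assumed
)
print(discounted(td).shape)  # torch.Size([64, 1])
```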
New Features
- TensorDataClass:
  - Provides a strongly-typed, dataclass-based container for tensors with automatic field generation and batch semantics.
  - Supports static typing with IDE support and autocomplete, natural inheritance patterns, and memory-efficient slots=True layout.
  - Seamlessly integrates with PyTree for tree operations and is compatible with torch.compile and JIT compilation.
- TensorDistribution:
  - A new set of classes for representing and manipulating probability distributions of tensors, including TensorNormal, TensorBernoulli, TensorCategorical, TensorTruncatedNormal, and TensorTanhNormal.
  - Supports standard distribution operations like sample, rsample, log_prob, entropy, mean, stddev, and mode.
  - Includes a ClampedTanhTransform and a SamplingDistribution for more complex transformations and sampling strategies.
- Tensor Manipulation:
  - Added support for expand, squeeze, unsqueeze, permute, and transpose operations on TensorDict and TensorDataClass instances.
- unsafe_construction:
  - Added a new context manager, TensorContainer.unsafe_construction(), to disable validation during construction for performance-critical scenarios.
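A usage sketch for the unsafe_construction context manager listed above; the import paths and the TensorDict constructor keywords are assumptions.

```python
import torch
from tensorcontainer import TensorContainer, TensorDict  # import paths assumed

data = {"obs": torch.zeros(1024, 8), "act": torch.zeros(1024, 2)}

# Skip construction-time validation in a hot path where the caller already
# guarantees consistent shapes and devices.
with TensorContainer.unsafe_construction():
    td = TensorDict(data, shape=(1024,))  # constructor keyword assumed
```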
Bug Fixes and Improvements
- Improved handling of various indexing scenarios in __getitem__ and __setitem__, including basic, advanced, ellipsis, newaxis, and boolean mask indexing.
- Ensured correct handling of device and shape consistency in various operations, including to, cpu, cuda, and casting methods.
- Improved the __repr__ method for TensorDict and TensorDataClass to provide more informative and correctly formatted output.
- Fixed memory leaks in TensorDataClass.
Other Changes
- The CI/CD pipeline now tests against Python 3.9, 3.10, and 3.11, and uses ruff for linting and formatting.
- The documentation has been improved with more detailed docstrings and a comprehensive README.md file.