Releases: mctigger/tensorcontainer

0.8.2: Fixes issue with __setitem__ and non-tuple indices

12 Sep 13:41
e5c3c69

What's Changed

  • Modernize type hints to PEP 604 and collections.abc by @mctigger in #23
  • Rename publish environment to nightly by @mctigger in #24
  • Fix setitem in TensorContainer by @mctigger in #25
  • fix(build): bump version for setitem fix by @mctigger in #26
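
The notes do not describe the fix itself. As a hedged illustration, a common source of such bugs is code in `__setitem__` that assumes the index is always a tuple, while `x[0] = v` passes a bare int. The sketch below normalizes the index first; `normalize_index` and `RecordingContainer` are hypothetical names, not TensorContainer's implementation.

```python
from typing import Any, Optional, Tuple


def normalize_index(index: Any) -> Tuple[Any, ...]:
    """Wrap a non-tuple index in a 1-tuple so later code can always iterate over it.

    x[0], x[1:3] and x[...] pass a bare int, slice or Ellipsis to __setitem__,
    whereas x[0, 1] passes the tuple (0, 1).
    """
    return index if isinstance(index, tuple) else (index,)


class RecordingContainer:
    """Toy stand-in (hypothetical) that records the normalized index it receives."""

    def __init__(self) -> None:
        self.last_index: Optional[Tuple[Any, ...]] = None

    def __setitem__(self, index: Any, value: Any) -> None:
        index = normalize_index(index)
        # From here on, code can safely take len(index) or loop over its entries.
        self.last_index = index


c = RecordingContainer()
c[0] = "a"     # c.last_index == (0,)
c[0, 1] = "b"  # c.last_index == (0, 1)
```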

Full Changelog: 0.8.1...0.8.2

0.8.1: Enhanced PyTree Error Diagnostics and Refactoring

10 Sep 11:08
f93849f

Enhanced PyTree Error Diagnostics and Refactoring

This update introduces significant enhancements to error reporting for PyTree operations (e.g., torch.stack, torch.cat) on TensorContainer subclasses such as TensorDict and TensorDataClass. The primary objective is to provide developers with clear, actionable diagnostics when structural mismatches occur, thereby improving the debugging workflow. This is accompanied by a refactoring of the underlying PyTree context implementation for improved modularity and extensibility.


Key Changes

1. Detailed Mismatch Diagnosis

A new utility, diagnose_pytree_structure_mismatch, has been implemented to provide detailed diagnostics for structural differences. This replaces generic RuntimeError or ValueError exceptions with specific, informative messages that identify the root cause of an issue. The diagnostics cover:

  • Key/Field Mismatches: Differences in keys for TensorDict or fields for TensorDataClass.
  • Type Mismatches: Incompatible container types (e.g., TensorDict vs. list).
  • Device Mismatches: Containers located on different devices (e.g., cpu vs. cuda).
  • Nesting Mismatches: Discrepancies in the nested structure of containers.

2. Example Error Message

When attempting to stack TensorContainer instances with incompatible keys, the new error message provides a precise description of the problem:

Structure mismatch at container: containers have incompatible layouts.

Container 0 at container: TensorDict(keys=['a', 'b'], device=cpu)
Container 1 at container: TensorDict(keys=['a', 'c'], device=cpu)

Fix: Key mismatch detected. Missing keys in container 1: ['b']. Extra keys in container 1: ['c'].
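
The key-mismatch line of that message can be reproduced with plain set arithmetic. The following is a hypothetical re-creation of the "Fix:" text from two key lists; it is not the library's diagnose_pytree_structure_mismatch, whose signature is not given in these notes.

```python
from typing import Optional, Sequence


def describe_key_mismatch(
    reference: Sequence[str], other: Sequence[str], index: int = 1
) -> Optional[str]:
    """Explain how the keys of container `index` differ from container 0 (hypothetical helper)."""
    missing = sorted(set(reference) - set(other))  # in container 0 but not in `other`
    extra = sorted(set(other) - set(reference))    # in `other` but not in container 0
    if not missing and not extra:
        return None
    parts = ["Key mismatch detected."]
    if missing:
        parts.append(f"Missing keys in container {index}: {missing}.")
    if extra:
        parts.append(f"Extra keys in container {index}: {extra}.")
    return " ".join(parts)


print(describe_key_mismatch(["a", "b"], ["a", "c"]))
# Key mismatch detected. Missing keys in container 1: ['b']. Extra keys in container 1: ['c'].
```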

3. Internal Refactoring

To support these improvements, the following architectural changes were made:

  • Structured Exceptions: A set of specific exception classes (TypeMismatch, ContextMismatch, KeyPathMismatch) was introduced to represent distinct error conditions.
  • PyTree Context Abstraction: A new ContextWithAnalysis abstract base class standardizes structural analysis, and TensorContainerPytreeContext centralizes common logic to reduce code duplication.
  • Dataclass Conversion: PyTree context classes were converted from NamedTuple to dataclass to improve readability and extensibility.
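
A minimal sketch of how these three pieces can fit together: an abstract context class with an analysis hook, a frozen dataclass implementing it, and specific exception types instead of a generic RuntimeError. All class and method names here (AnalyzableContext, DictContext, analyze_mismatch, KeyMismatchError, and so on) are hypothetical and only loosely modeled on ContextWithAnalysis and the exception classes named above; they are not the library's real signatures.

```python
import dataclasses
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple


class StructureMismatch(Exception):
    """Hypothetical base class for structural errors."""


class TypeMismatchError(StructureMismatch):
    """Containers of different types."""


class KeyMismatchError(StructureMismatch):
    """Dict-like containers with different keys."""


class DeviceMismatchError(StructureMismatch):
    """Containers on different devices."""


class AnalyzableContext(ABC):
    """Hypothetical ABC: every PyTree context can explain how it differs from another."""

    @abstractmethod
    def analyze_mismatch(self, other: "AnalyzableContext") -> Optional[StructureMismatch]:
        """Return an exception describing the difference, or None if compatible."""


@dataclasses.dataclass(frozen=True)
class DictContext(AnalyzableContext):
    """Context of a dict-like container, stored as a frozen dataclass instead of a NamedTuple."""

    keys: Tuple[str, ...]
    device: str

    def analyze_mismatch(self, other: AnalyzableContext) -> Optional[StructureMismatch]:
        if not isinstance(other, DictContext):
            return TypeMismatchError(f"{type(self).__name__} vs {type(other).__name__}")
        if self.keys != other.keys:  # order-sensitive, like a flattened key list
            return KeyMismatchError(f"keys {list(self.keys)} vs {list(other.keys)}")
        if self.device != other.device:
            return DeviceMismatchError(f"device {self.device} vs {other.device}")
        return None


def check_compatible(contexts: List[AnalyzableContext]) -> None:
    """Raise the first specific mismatch found relative to container 0."""
    for i, ctx in enumerate(contexts[1:], start=1):
        error = contexts[0].analyze_mismatch(ctx)
        if error is not None:
            # Re-raise with the same specific type, adding the container position.
            raise type(error)(f"container {i}: {error}")
```

Because each mismatch has its own exception type, callers (and tests) can catch exactly the condition they care about.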

What's Changed

  • Improved PyTree Error Messages and Refactoring by @mctigger in #21
  • fix(build): bump project version to 0.8.1 by @mctigger in #22

Full Changelog: 0.8.0...0.8.1

0.8.0: Better documentation and error handling; fixes torch.cat

01 Sep 15:19
233da76

What's Changed

  • Add documentation for the package by @mctigger in #15
  • Documentation by @mctigger in #18
  • fix(container): improve 'too many indices' error handling by @mctigger in #19
  • refactor(container): centralize pytree mapping and fixes operations o… by @mctigger in #20
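
The notes do not show the improved message. As a hedged illustration, detecting "too many indices" amounts to counting the index entries that consume a batch dimension and comparing that count with the number of batch dimensions. The helper check_index and its message text are hypothetical.

```python
from typing import Any


def check_index(index: Any, batch_ndim: int) -> None:
    """Raise IndexError if `index` addresses more dims than the container has (hypothetical).

    None (newaxis) and Ellipsis consume no existing dimension; ints, slices and
    1-D index lists consume one each. Multi-dimensional boolean masks, which
    consume several dimensions, are not handled in this sketch.
    """
    if not isinstance(index, tuple):
        index = (index,)
    used = sum(1 for item in index if item is not None and item is not Ellipsis)
    if used > batch_ndim:
        raise IndexError(
            f"too many indices: container has {batch_ndim} batch dimension(s), "
            f"but {used} were indexed"
        )
```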

Full Changelog: 0.7.1...0.8.0

Add TensorSymLog

20 Aug 12:18
55fd89e

Adds TensorSymLog based on SymLogDistribution

Validation streamlined, unified types, added documentation

14 Aug 09:15
01babde

What's Changed

  • Tensor dict types by @mctigger in #13
  • Adds more sophisticated tests for TensorContainer and subclassing
  • Adds documentation for development rationale
  • Removes validate_args from TensorContainer; validation is now controlled only through a context manager.

Full Changelog: 0.6.3...0.7

Unified TensorDistribution

05 Aug 06:38
1d7d2d2

Overview

TensorContainer 0.6.3 reworks the tensor distribution module around a unified implementation, with improved performance and developer experience. The release simplifies the distribution classes while keeping most existing code working; the Breaking Changes section below lists the exceptions.

What's New

Major Improvements

Unified TensorDistribution Implementation

  • Comprehensive Refactoring: All tensor distribution classes have been unified under a consistent implementation pattern
  • Code Reduction: Removed over 1000 lines of redundant code while adding new functionality
  • Better Inheritance: Enhanced class hierarchy to properly leverage parent class functionality
  • Consistent APIs: All distributions now follow the same patterns for parameter handling and validation

Enhanced Parameter Handling

  • Unified Broadcasting: Implemented consistent parameter broadcasting across all distributions using broadcast_all
  • Improved Validation: Better error messages and validation logic for distribution parameters
  • Scalar Parameter Support: Enhanced handling of scalar parameters in distribution initialization
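
broadcast_all (from torch.distributions.utils) brings all parameters to a common shape using PyTorch's broadcasting rules, which is what lets a scalar parameter be mixed with batched ones. The stdlib-only sketch below shows just the shape rule involved; broadcast_shapes_sketch is a hypothetical name, and in real code torch.broadcast_shapes performs this computation.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shapes_sketch(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """Compute the broadcast shape of several shapes (NumPy/PyTorch rules).

    Shapes are aligned from the right; in each position all sizes must match
    or be 1. A scalar parameter has shape ().
    """
    result = []
    # Walk dimensions from the trailing end; missing leading dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


print(broadcast_shapes_sketch((), (3, 1), (4,)))  # (3, 4)
```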

New Features

  • unflatten Distribution Method: New functionality in the base distribution class for handling flattened parameters
  • Enhanced validate_args Support: Improved argument validation across all distribution classes
  • TensorOneHotCategoricalStraightThrough: New distribution implementation for straight-through gradient estimation

Performance & Quality Improvements

Code Quality

  • Modern Type Hints: Updated to use Python 3.10+ union syntax (| instead of Union)
  • Better Error Handling: Replaced generic RuntimeError with more specific ValueError for parameter validation
  • Consistent Code Style: Unified formatting and style across all distribution implementations

Testing Enhancements

  • Expanded Test Coverage: Added comprehensive tests for pytree integration and parameter validation
  • Better Error Testing: Improved test cases for error handling and edge cases
  • Parameter Factory Fixtures: Added reusable test fixtures for distribution parameter testing

Technical Details

Changed Components

  • 55 files modified across the tensor distribution module
  • All 30+ distribution classes refactored for consistency:
    • Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2
    • ContinuousBernoulli, Dirichlet, Exponential, FisherSnedecor
    • Gamma, Geometric, Gumbel, HalfCauchy, HalfNormal
    • InverseGamma, Kumaraswamy, Laplace, LogisticNormal
    • Multinomial, MultivariateNormal, NegativeBinomial, Normal
    • OneHotCategorical, Pareto, Poisson, RelaxedBernoulli
    • RelaxedOneHotCategorical, StudentT, TanhNormal
    • TruncatedNormal, Uniform, VonMises, Weibull, Wishart

Breaking Changes

While most user code should continue to work without changes, there are some minor breaking changes:

  • Some RuntimeError instances replaced with ValueError for better parameter validation
  • Simplified initialization signatures in some distribution classes

Migration Guide

Most existing code will continue to work without modification. For cases where changes are needed:

  1. Update error handling code to catch ValueError instead of RuntimeError for parameter validation
  2. Review distribution initialization if using advanced parameter configurations
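
For code that must run against releases on both sides of this change, one option is to catch both exception types. In this sketch build_distribution is a hypothetical stand-in for any distribution constructor that validates its parameters.

```python
import math
from typing import Optional


def build_distribution(loc: float, scale: float) -> dict:
    """Hypothetical constructor: newer releases raise ValueError on bad parameters."""
    if not (math.isfinite(scale) and scale > 0):
        raise ValueError(f"scale must be positive, got {scale}")
    return {"loc": loc, "scale": scale}


def try_build(loc: float, scale: float) -> Optional[dict]:
    try:
        return build_distribution(loc, scale)
    # ValueError is what newer releases raise for parameter validation;
    # RuntimeError is caught only for compatibility with older releases.
    except (ValueError, RuntimeError) as err:
        print(f"invalid parameters: {err}")
        return None
```

Once older releases no longer need to be supported, the RuntimeError branch can be dropped.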

Compatibility

  • Python: 3.9+ (dropped Python 3.8 support)
  • PyTorch: 2.0+

Fix: TensorTanhNormal and TensorIndependent

23 Jul 17:07
e1cec97

Key Fixes and Improvements

  • TensorIndependent Shape Bug: A bug where reinterpreted_batch_ndims=0 resulted in an incorrect shape has been fixed. The shape calculation now correctly returns the full shape of the base distribution in this case.
  • TensorTanhNormal Simplification: The reinterpreted_batch_ndims parameter has been removed from TensorTanhNormal to enforce a clearer separation of concerns. To achieve this functionality, users should now wrap TensorTanhNormal with TensorIndependent. The class also now includes new statistical properties like mean, variance, and standard deviation.
  • SamplingDistribution Enhancements: This class has been rewritten for better performance and reliability. It now uses __slots__ for memory efficiency, includes better caching for statistical properties, improves error handling, and adds input validation.
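
The notes do not say what caused the TensorIndependent bug, but a classic pitfall with reinterpreted_batch_ndims=0 is negative slicing: shape[:-0] is shape[:0], which is empty. The hypothetical helper below splits batch and event shapes following the torch.distributions.Independent convention (the rightmost n batch dimensions move into the event shape) and handles the zero case correctly.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def split_independent_shape(batch_shape: Shape, event_shape: Shape, n: int) -> Tuple[Shape, Shape]:
    """Move the rightmost `n` batch dims into the event shape (hypothetical helper)."""
    if not 0 <= n <= len(batch_shape):
        raise ValueError(f"reinterpreted_batch_ndims={n} out of range for {batch_shape}")
    # Buggy variant: batch_shape[:-n] returns () when n == 0, because -0 == 0.
    cut = len(batch_shape) - n
    return batch_shape[:cut], batch_shape[cut:] + event_shape


print(split_independent_shape((4, 3), (2,), 0))  # ((4, 3), (2,))
print(split_independent_shape((4, 3), (2,), 1))  # ((4,), (3, 2))
```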

Impact and Migration

This update introduces a breaking change: the reinterpreted_batch_ndims parameter is no longer available in TensorTanhNormal.

Fix: Device propagation

19 Jul 22:06
80763ca

  • Fixes missing device propagation to child tensor containers
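
A minimal sketch of the idea behind this fix, assuming a container that records its device and may hold nested child containers: moving the parent must also update every child. Node and to_device are hypothetical names, and leaves are plain strings standing in for tensors.

```python
from dataclasses import dataclass, field
from typing import Dict, Union


@dataclass
class Node:
    """Hypothetical container: a device plus leaves or nested child containers."""

    device: str
    children: Dict[str, Union["Node", str]] = field(default_factory=dict)

    def to_device(self, device: str) -> "Node":
        """Return a copy on `device`, recursing into nested containers.

        Leaves are left unchanged here; real tensor leaves would be moved too.
        """
        moved = {
            key: child.to_device(device) if isinstance(child, Node) else child
            for key, child in self.children.items()
        }
        return Node(device=device, children=moved)
```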

v0.6.0: torch.distributions Parity with TensorAnnotated Architecture

19 Jul 19:40
a8eca59

This release represents a massive expansion of TensorContainer with complete torch.distributions parity, introducing TensorAnnotated as a new base class, dramatically expanding the distribution ecosystem, and adding comprehensive documentation.

Highlights

  • Complete torch.distributions Parity: All 43+ distributions from torch.distributions now have TensorContainer equivalents with full API compatibility
  • TensorAnnotated Architecture: New foundational base class that powers both TensorDataClass and TensorDistribution with unified annotation-based tensor handling
  • Massive Distribution Ecosystem: Added 35+ new distribution types including complex distributions like Wishart, LKJ Cholesky, and MixtureSameFamily
  • Enhanced Documentation: Comprehensive documentation system with detailed guides and practical examples
  • Improved torch.compile Support: Better integration with PyTorch's compilation system and reduced graph breaks

New Features

TensorAnnotated Base Class

  • New foundational class that unifies tensor annotation handling across TensorDataClass and TensorDistribution
  • Automatic detection and transformation of annotated tensor attributes
  • Improved inheritance patterns and subclassing support
  • Enhanced PyTree registration and serialization capabilities
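
One way to detect annotated tensor attributes automatically, including inherited ones, is to inspect a class's type hints. The stdlib-only sketch below illustrates this; FakeTensor stands in for torch.Tensor and tensor_fields is a hypothetical helper, not TensorAnnotated's actual mechanism.

```python
import typing


class FakeTensor:
    """Stand-in for torch.Tensor so the sketch runs without PyTorch."""


def tensor_fields(cls: type) -> typing.List[str]:
    """Return names of attributes annotated as FakeTensor, base classes first."""
    hints = typing.get_type_hints(cls)  # merges annotations along the MRO
    return [name for name, hint in hints.items() if hint is FakeTensor]


class Base:
    obs: FakeTensor
    step: int


class Child(Base):
    reward: FakeTensor


print(tensor_fields(Child))  # ['obs', 'reward']
```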

Complete Distribution Coverage

Added support for all remaining torch.distributions types:

  • Statistical Distributions: Beta, Gamma, Chi2, StudentT, FisherSnedecor, InverseGamma, Pareto, Weibull
  • Multivariate Distributions: Wishart, LKJCholesky, LowRankMultivariateNormal, LogisticNormal
  • Discrete Distributions: Geometric, NegativeBinomial, Multinomial, Poisson
  • Specialized Distributions: MixtureSameFamily, Independent, TransformedDistribution
  • Advanced Distributions: RelaxedBernoulli, RelaxedOneHotCategorical, VonMises, Kumaraswamy

Enhanced Distribution Features

  • KL divergence support with automatic registration
  • Advanced tensor operations (stacking, copying, view operations)
  • Comprehensive compile compatibility for all distributions
  • Improved parameter validation and error handling
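
PyTorch registers KL divergences with the torch.distributions.kl.register_kl(type_p, type_q) decorator and dispatches on the pair of distribution types. The stdlib-only sketch below mimics that dispatch-table pattern; register_kl_sketch, kl_divergence_sketch and the toy Normal class are hypothetical, and the notes do not say how TensorContainer hooks into the real registry.

```python
import math
from typing import Callable, Dict, Tuple

_KL_REGISTRY: Dict[Tuple[type, type], Callable] = {}


def register_kl_sketch(type_p: type, type_q: type) -> Callable:
    """Decorator that records a KL implementation for a pair of types."""
    def decorator(fn: Callable) -> Callable:
        _KL_REGISTRY[(type_p, type_q)] = fn
        return fn
    return decorator


def kl_divergence_sketch(p: object, q: object) -> float:
    """Look up an implementation, falling back to parent classes via the MRO."""
    for cls_p in type(p).__mro__:
        for cls_q in type(q).__mro__:
            fn = _KL_REGISTRY.get((cls_p, cls_q))
            if fn is not None:
                return fn(p, q)
    raise NotImplementedError(f"no KL registered for {type(p).__name__}, {type(q).__name__}")


class Normal:
    def __init__(self, loc: float, scale: float) -> None:
        self.loc, self.scale = loc, scale


@register_kl_sketch(Normal, Normal)
def _kl_normal_normal(p: Normal, q: Normal) -> float:
    # Closed form: log(s_q / s_p) + (s_p^2 + (m_p - m_q)^2) / (2 s_q^2) - 1/2
    var_ratio = (p.scale / q.scale) ** 2
    mean_term = ((p.loc - q.loc) / q.scale) ** 2
    return 0.5 * (var_ratio + mean_term - 1.0 - math.log(var_ratio))


print(kl_divergence_sketch(Normal(0.0, 1.0), Normal(1.0, 1.0)))  # 0.5
```

Unlike PyTorch, this sketch simply takes the first match in MRO order and does not check for ambiguous registrations.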

Documentation System

  • Complete API Documentation: Individual guides for TensorAnnotated, TensorDataClass, and TensorDistribution
  • Practical Examples: 6 comprehensive examples demonstrating flexibility and chaining operations
  • Testing Guide: Detailed testing philosophy and standards documentation
  • Compatibility Guide: Python version compatibility and best practices

Improvements

torch.compile Enhancements

  • Better graph break detection and prevention
  • Improved recompilation tracking
  • Enhanced compatibility testing across all distribution types
  • Reduced compilation overhead for tensor operations

Performance Optimizations

  • More efficient parameter handling in distributions
  • Optimized annotation processing in TensorAnnotated
  • Better memory management for large tensor operations
  • Improved tensor transformation pipelines

Developer Experience

  • Enhanced CI/CD pipeline with nightly builds and comprehensive testing
  • Improved error messages and debugging support
  • Better IDE integration with enhanced type annotations
  • Comprehensive test coverage across all new features

Bug Fixes

  • Fixed Python 3.9 compatibility issues with annotation handling
  • Resolved subclassing behavior inconsistencies
  • Improved device consistency across tensor operations
  • Fixed memory leaks in TensorDataClass operations
  • Enhanced shape validation and error reporting

Testing & Quality

  • Expanded Test Suite: Grew from ~643 tests to 900+ tests with 42 new distribution test files
  • Enhanced Coverage: Comprehensive testing for all distribution types and tensor operations
  • Compile Testing: Extensive torch.compile compatibility verification
  • Cross-Platform Testing: Support for Python 3.9-3.12 across multiple platforms

Breaking Changes

  • TensorDistribution now inherits from TensorAnnotated instead of directly from TensorContainer
  • Some internal APIs have changed to support the new annotation system
  • Distribution parameter validation now defers to torch.distributions for consistency

Documentation

  • Added comprehensive documentation in docs/ directory
  • Six practical examples demonstrating real-world usage patterns
  • Detailed compatibility and testing guides
  • Enhanced README with expanded use cases and examples

Initial release with Python 3.9+ support

07 Jul 16:19
f429ec8

Release Notes

This release introduces major new features, including TensorDataClass and TensorDistribution, as well as significant improvements to torch.compile compatibility and a more comprehensive test suite.

Highlights

  • TensorDataClass: A new, type-safe, dataclass-based container for tensors that provides a strongly-typed alternative to TensorDict. It automatically converts annotated class definitions into dataclasses with optimized settings for tensor operations.
  • TensorDistribution: A new set of classes for representing and manipulating probability distributions of tensors, including TensorNormal, TensorBernoulli, TensorCategorical, TensorTruncatedNormal, and TensorTanhNormal.
  • Enhanced torch.compile Compatibility: This release includes significant improvements to torch.compile compatibility, with a focus on reducing graph breaks and recompilations. The test suite now includes extensive tests for torch.compile compatibility, including graph break detection and recompilation tracking.
  • Comprehensive Test Suite: The test suite has been expanded to over 643 tests, with 86% code coverage, ensuring the reliability and stability of the library.
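
TensorDataClass converts an annotated class definition into a dataclass automatically. A minimal stdlib sketch of that idea applies dataclasses.dataclass to every subclass from __init_subclass__; AutoDataClass is a hypothetical name, and the sketch omits the slots=True layout, batch semantics and PyTree integration described for the real class.

```python
import dataclasses


class AutoDataClass:
    """Hypothetical base: every subclass is turned into a dataclass when it is defined."""

    def __init_subclass__(cls, **kwargs) -> None:
        super().__init_subclass__(**kwargs)
        # dataclass() without slots=True modifies cls in place, so no class swap is needed.
        dataclasses.dataclass(cls)


class Observation(AutoDataClass):
    pixels: list
    reward: float = 0.0


class LabeledObservation(Observation):
    label: str = "none"


print(Observation([1, 2]))  # Observation(pixels=[1, 2], reward=0.0)
```

Inherited fields come first, so LabeledObservation takes pixels, reward and label in that order.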

New Features

  • TensorDataClass:
    • Provides a strongly-typed, dataclass-based container for tensors with automatic field generation and batch semantics.
    • Supports static typing with IDE support and autocomplete, natural inheritance patterns, and memory-efficient slots=True layout.
    • Seamlessly integrates with PyTree for tree operations and is compatible with torch.compile and JIT compilation.
  • TensorDistribution:
    • A new set of classes for representing and manipulating probability distributions of tensors, including TensorNormal, TensorBernoulli, TensorCategorical, TensorTruncatedNormal, and TensorTanhNormal.
    • Supports standard distribution operations like sample, rsample, log_prob, entropy, mean, stddev, and mode.
    • Includes a ClampedTanhTransform and a SamplingDistribution for more complex transformations and sampling strategies.
  • Tensor Manipulation:
    • Added support for expand, squeeze, unsqueeze, permute, and transpose operations on TensorDict and TensorDataClass instances.
  • unsafe_construction:
    • Added a new context manager, TensorContainer.unsafe_construction(), to disable validation during construction for performance-critical scenarios.
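
A context manager that disables validation is typically a scoped flag that is restored on exit, even after an exception. The sketch below shows that pattern with a hypothetical Container class and a hypothetical _VALIDATE flag; it is not TensorContainer's implementation. A ContextVar keeps the setting local to the current thread or async task.

```python
import contextlib
from contextvars import ContextVar
from typing import Iterator, List

# Hypothetical flag controlling whether constructors validate their inputs.
_VALIDATE = ContextVar("validate", default=True)


class Container:
    """Hypothetical container that checks all leaves share the same batch size."""

    def __init__(self, batch_size: int, leaves: List[list]) -> None:
        if _VALIDATE.get():
            for leaf in leaves:
                if len(leaf) != batch_size:
                    raise ValueError(f"leaf of length {len(leaf)} != batch size {batch_size}")
        self.batch_size = batch_size
        self.leaves = leaves

    @staticmethod
    @contextlib.contextmanager
    def unsafe_construction() -> Iterator[None]:
        """Skip validation inside the `with` block; restore the previous setting afterwards."""
        token = _VALIDATE.set(False)
        try:
            yield
        finally:
            _VALIDATE.reset(token)


with Container.unsafe_construction():
    fast = Container(2, [[1, 2], [3, 4]])  # constructed without the length checks
```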

Bug Fixes and Improvements

  • Improved handling of various indexing scenarios in __getitem__ and __setitem__, including basic, advanced, ellipsis, newaxis, and boolean mask indexing.
  • Ensured correct handling of device and shape consistency in various operations, including to, cpu, cuda, and casting methods.
  • Improved the __repr__ method for TensorDict and TensorDataClass to provide more informative and correctly formatted output.
  • Fixed memory leaks in TensorDataClass.

Other Changes

  • The CI/CD pipeline now tests against Python 3.9, 3.10, and 3.11, and uses ruff for linting and formatting.
  • The documentation has been improved with more detailed docstrings and a comprehensive README.md file.