|
48 | 48 | _BatchedUninitializedParameter, |
49 | 49 | _check_inbuild, |
50 | 50 | _clone_value, |
| 51 | + _create_segments_from_int, |
| 52 | + _create_segments_from_list, |
51 | 53 | _get_item, |
52 | 54 | _get_leaf_tensordict, |
53 | 55 | _get_shape_from_args, |
@@ -1750,71 +1752,28 @@ def split( |
1750 | 1752 | # we must use slices to keep the storage of the tensors |
1751 | 1753 | WRONG_TYPE = "split(): argument 'split_size' must be int or list of ints" |
1752 | 1754 | batch_size = self.batch_size |
1753 | | - batch_sizes = [] |
1754 | 1755 | dim = _maybe_correct_neg_dim(dim, batch_size) |
1755 | 1756 | max_size = batch_size[dim] |
1756 | 1757 | if isinstance(split_size, int): |
1757 | | - idx0 = 0 |
1758 | | - idx1 = min(max_size, split_size) |
1759 | | - batch_sizes.append( |
1760 | | - torch.Size( |
1761 | | - tuple( |
1762 | | - d if i != dim else idx1 - idx0 for i, d in enumerate(batch_size) |
1763 | | - ) |
1764 | | - ) |
1765 | | - ) |
1766 | | - while idx1 < max_size: |
1767 | | - idx0 = idx1 |
1768 | | - idx1 = min(max_size, idx1 + split_size) |
1769 | | - batch_sizes.append( |
1770 | | - torch.Size( |
1771 | | - tuple( |
1772 | | - d if i != dim else idx1 - idx0 |
1773 | | - for i, d in enumerate(batch_size) |
1774 | | - ) |
1775 | | - ) |
1776 | | - ) |
| 1758 | + segments = _create_segments_from_int(split_size, max_size) |
| 1759 | + chunks = -(self.batch_size[dim] // -split_size) |
| 1760 | + splits = {k: v.chunk(chunks, dim) for k, v in self.items()} |
1777 | 1761 | elif isinstance(split_size, (list, tuple)): |
1778 | 1762 | if len(split_size) == 0: |
1779 | 1763 | raise RuntimeError("Insufficient number of elements in split_size.") |
1780 | | - try: |
1781 | | - idx0 = 0 |
1782 | | - idx1 = split_size[0] |
1783 | | - batch_sizes.append( |
1784 | | - torch.Size( |
1785 | | - tuple( |
1786 | | - d if i != dim else idx1 - idx0 |
1787 | | - for i, d in enumerate(batch_size) |
1788 | | - ) |
1789 | | - ) |
1790 | | - ) |
1791 | | - for idx in split_size[1:]: |
1792 | | - idx0 = idx1 |
1793 | | - idx1 = min(max_size, idx1 + idx) |
1794 | | - batch_sizes.append( |
1795 | | - torch.Size( |
1796 | | - tuple( |
1797 | | - d if i != dim else idx1 - idx0 |
1798 | | - for i, d in enumerate(batch_size) |
1799 | | - ) |
1800 | | - ) |
1801 | | - ) |
1802 | | - except TypeError: |
| 1764 | + if not all(isinstance(x, int) for x in split_size): |
1803 | 1765 | raise TypeError(WRONG_TYPE) |
1804 | | - |
1805 | | - if idx1 < batch_size[dim]: |
1806 | | - raise RuntimeError( |
1807 | | - f"Split method expects split_size to sum exactly to {self.batch_size[dim]} (tensor's size at dimension {dim}), but got split_size={split_size}" |
1808 | | - ) |
| 1766 | + splits = {k: v.split(split_size, dim) for k, v in self.items()} |
| 1767 | + segments = _create_segments_from_list(split_size, max_size) |
1809 | 1768 | else: |
1810 | 1769 | raise TypeError(WRONG_TYPE) |
1811 | 1770 | names = self._maybe_names() |
1812 | | - # Use chunk instead of split to account for nested tensors if possible |
1813 | | - if isinstance(split_size, int): |
1814 | | - chunks = -(self.batch_size[dim] // -split_size) |
1815 | | - splits = {k: v.chunk(chunks, dim) for k, v in self.items()} |
1816 | | - else: |
1817 | | - splits = {k: v.split(split_size, dim) for k, v in self.items()} |
| 1771 | + batch_sizes = [ |
| 1772 | + torch.Size( |
| 1773 | + tuple(d if i != dim else end - start for i, d in enumerate(batch_size)) |
| 1774 | + ) |
| 1775 | + for start, end in segments |
| 1776 | + ] |
1818 | 1777 | splits = [ |
1819 | 1778 | {k: v[ss] for k, v in splits.items()} for ss in range(len(batch_sizes)) |
1820 | 1779 | ] |
@@ -2184,7 +2143,6 @@ def from_dict_instance( |
2184 | 2143 | batch_dims=None, |
2185 | 2144 | names=None, |
2186 | 2145 | ): |
2187 | | - |
2188 | 2146 | if batch_dims is not None and batch_size is not None: |
2189 | 2147 | raise ValueError( |
2190 | 2148 | "Cannot pass both batch_size and batch_dims to `from_dict`." |
@@ -2274,7 +2232,7 @@ def batch_dims(self) -> int: |
2274 | 2232 | @batch_dims.setter |
2275 | 2233 | def batch_dims(self, value: int) -> None: |
2276 | 2234 | raise RuntimeError( |
2277 | | - f"Setting batch dims on {type(self).__name__} instances is " f"not allowed." |
| 2235 | + f"Setting batch dims on {type(self).__name__} instances is not allowed." |
2278 | 2236 | ) |
2279 | 2237 |
|
2280 | 2238 | def _has_names(self): |
@@ -2763,7 +2721,6 @@ def _memmap_( |
2763 | 2721 | share_non_tensor, |
2764 | 2722 | existsok, |
2765 | 2723 | ) -> T: |
2766 | | - |
2767 | 2724 | if prefix is not None: |
2768 | 2725 | prefix = Path(prefix) |
2769 | 2726 | if not prefix.exists(): |
@@ -2806,7 +2763,6 @@ def _memmap_( |
2806 | 2763 | ) |
2807 | 2764 | continue |
2808 | 2765 | else: |
2809 | | - |
2810 | 2766 | if executor is None: |
2811 | 2767 | _populate_memmap( |
2812 | 2768 | dest=dest, |
|
0 commit comments