11import warnings
22from collections .abc import Collection , Iterable , Sequence
3+ from itertools import pairwise
34from textwrap import dedent
45
56import numpy as np
4546from pytensor .tensor .math import sum as pt_sum
4647from pytensor .tensor .shape import Shape_i
4748from pytensor .tensor .subtensor import advanced_inc_subtensor1 , set_subtensor
48- from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes , vector
49+ from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes
4950from pytensor .tensor .utils import normalize_reduce_axis
5051from pytensor .tensor .variable import TensorVariable
5152from pytensor .utils import LOCAL_BITWIDTH , PYTHON_INT_BITWIDTH
@@ -2011,6 +2012,287 @@ def concat_with_broadcast(tensor_list, axis=0):
20112012 return join (axis , * bcast_tensor_inputs )
20122013
20132014
2015+ class Pack (Op ):
2016+ __props__ = ("axes" ,)
2017+
2018+ def __init__ (self , axes : int | Sequence [int ] | None ):
2019+ self .axes = tuple (axes ) if isinstance (axes , list ) else axes
2020+
2021+ def _analyze_axes_list (self ) -> tuple [int , int , int , int | None ]:
2022+ """
2023+ Analyze the provided axes list to determine how many axes are before and after the interval to be raveled, as
2024+ well as the minimum and maximum number of axes that the inputs can have.
2025+
2026+ The rules are:
2027+ - Axes must be strictly increasing in both the positive and negative parts of the list.
2028+ - Negative axes must come after positive axes.
2029+ - There can be at most one "hole" in the axes list, which can be either an implicit hole on an endpoint
2030+ (e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, 2] or [1, -1]).
2031+
2032+ Returns
2033+ -------
2034+ n_axes_before: int
2035+ The number of axes before the interval to be raveled.
2036+ n_axes_after: int
2037+ The number of axes after the interval to be raveled.
2038+ min_axes: int
2039+ The minimum number of axes that the inputs must have.
2040+ max_axes: int or None
2041+ The maximum number of axes that the inputs can have, or None if there is no strict maximum. A maximum is
2042+ only introduced when it would resolve ambiguities in the interpretation of the axes list. For example,
2043+ [2, 3] can be either interpreted as having two ravel intervals [:2] and [4:], which is illegal,
2044+ unless 3 is interpreted as -1, which is only possible if all inputs have exactly 4 axes. Likewise,
2045+ [-3, -1] can be interpreted as having two ravel intervals [:-3], [-3:], unless -3 is interpreted as 0,
2046+ which is only possible if all inputs have exactly 3 axes.
2047+ """
2048+ axes = self .axes
2049+ if axes is None :
2050+ return 0 , 0 , 0 , None
2051+
2052+ if isinstance (axes , int ):
2053+ axes = [axes ]
2054+
2055+ if len (set (axes )) != len (axes ):
2056+ raise ValueError ("axes must have no duplicates" )
2057+ if axes is not None and len (axes ) == 0 :
2058+ raise ValueError ("axes=[] is ambiguous; use None to ravel all" )
2059+
2060+ first_negative_idx = next ((i for i , a in enumerate (axes ) if a < 0 ), len (axes ))
2061+ positive_axes = list (axes [:first_negative_idx ])
2062+ negative_axes = list (axes [first_negative_idx :])
2063+
2064+ if not all (a < 0 for a in negative_axes ):
2065+ raise ValueError ("Negative axes must come after positive" )
2066+
2067+ def strictly_increasing (s ):
2068+ return all (b > a for a , b in pairwise (s ))
2069+
2070+ if (positive_axes and not strictly_increasing (positive_axes )) or (
2071+ negative_axes and not strictly_increasing (negative_axes )
2072+ ):
2073+ raise ValueError ("Axes must be strictly increasing" )
2074+
2075+ def find_gaps (s ):
2076+ return [i for i , (a , b ) in enumerate (pairwise (s )) if b - a > 1 ]
2077+
2078+ pos_gaps = find_gaps (positive_axes )
2079+ neg_gaps = find_gaps (negative_axes )
2080+ positive_only = positive_axes and not negative_axes
2081+ negative_only = negative_axes and not positive_axes
2082+ mixed_case = positive_axes and negative_axes
2083+
2084+ max_axes : int | None = None
2085+
2086+ n_explicit_holes = len (pos_gaps ) + len (neg_gaps )
2087+ if n_explicit_holes > 1 :
2088+ raise ValueError (
2089+ "Too many holes in axes list. There can be at most one hole in the axes list, "
2090+ "including implict holes resulting from omitting the 0 or -1 axis."
2091+ )
2092+
2093+ if mixed_case :
2094+ if pos_gaps or neg_gaps :
2095+ raise ValueError (
2096+ "Too many holes in axes list. There can be at most one hole in the axes list, "
2097+ "including implict holes resulting from omitting the 0 or -1 axis. Because both "
2098+ "positive and negative axes are present, there is always assume to be an explit hole "
2099+ "between them."
2100+ )
2101+ n_before = len (positive_axes )
2102+ n_after = len (negative_axes )
2103+ min_axes = n_before + n_after
2104+
2105+ if positive_only :
2106+ # There are four cases to consider when all axes are positive:
2107+ # 0. There are two implicit gaps (0 is not present) and an explicit gap (e.g. [2, 4])
2108+ # This case is always illegal, as there is no interpretation that would result in having
2109+ # 1. There is only an implicit right hole (e.g. [0, 1])
2110+ # This case is legal, and requires no special interpretation. It corresponds to 'i j *' in einops
2111+ # 2. There is an explicit internal hole (e.g. [0, 2])
2112+ # This case is legal, but requires interpreting the last axis as -1, which introduces a maximum number
2113+ # of axes. It corresponds to 'i * j' in einops, and requires at least one input to have 3 dimensions, and
2114+ # no input to have more than 3 dimensions.
2115+ # 2. The axes start at an index greater than 0, but have no internal holes (e.g. [2, 3])
2116+ # This case is legal, but requires flipping the axes to negative indexing, so that the largest axis is
2117+ # -1, followed by -2, etc. This introduces a maximum number of axes.
2118+ if pos_gaps and positive_axes [0 ] != 0 :
2119+ raise ValueError (
2120+ "Too many holes in axes list. There can be at most one hole in the axes list, "
2121+ "including implict holes resulting from omitting the 0 or -1 axis. In this case, "
2122+ "there is an explicit internal hole as well as an implicit left hole."
2123+ )
2124+
2125+ elif positive_axes [0 ] == 0 and not pos_gaps :
2126+ # Case 1: Only right implicit hole. No ambiguities.
2127+ n_before = positive_axes [- 1 ] + 1
2128+ n_after = 0
2129+ min_axes = n_before + n_after
2130+ max_axes = None
2131+
2132+ elif pos_gaps :
2133+ # Case 2: Explicit hole in the positives, plus right implicit hole.
2134+ split = pos_gaps [0 ] + 1
2135+ n_before = split
2136+ n_after = len (positive_axes ) - split
2137+ min_axes = n_before + n_after
2138+
2139+ # Close the right implicit hole
2140+ max_axes = positive_axes [- 1 ] + 1
2141+
2142+ else :
2143+ # Case 3: Left and right implicit holes, but the right can be closed by flipping to negative axes and
2144+ # adding a maximum number of axes.
2145+ # Compute min_axes and max_axes under Case 1 of the negative_only scenario, with a max_axes constraint.
2146+ max_axes = positive_axes [- 1 ] + 1
2147+ n_before = 0
2148+ n_after = len (positive_axes )
2149+ min_axes = n_before + n_after
2150+
2151+ if negative_only :
2152+ # The same four cases are considered when all axes are negative, but ordering is reversed.
2153+ # 0. There are two implicit holes (e.g. [-4, -2])
2154+ # This case is always illegal, as there is no interpretation that would result in having only one hole
2155+ # in the axis list.
2156+ # 1. There is only an implicit left hole (e.g. [-2, -1])
2157+ # This case is legal, and requires no special interpretation. It corresponds to '* i j' in einops
2158+ # 2. There is an explicit internal hole (e.g. [-3, -1])
2159+ # This case is legal, but requires interpreting the smallest axis as 0, which introduces a maximum number
2160+ # of axes. It corresponds to '* i j' in einops, and requires at least one input to have 3 dimensions, and
2161+ # no input to have more than 3 dimensions.
2162+ # 3. The axes end at an index less than -1, but have no internal holes (e.g. [-4, -3]). Flip to positive
2163+ # axes, adding a maximum number of axes. Interpret the smallest axis as 0 to resolve ambiguity.
2164+ if neg_gaps and negative_axes [- 1 ] != - 1 :
2165+ raise ValueError (
2166+ "Too many holes in axes list. There can be at most one hole in the axes list, "
2167+ "including implict holes resulting from omitting the 0 or -1 axis. In this case, "
2168+ "there is an explicit internal hole as well as an implicit right hole."
2169+ )
2170+ elif negative_axes [- 1 ] == - 1 and not neg_gaps :
2171+ # Case 1: No ambiguities, only left implicit hole.
2172+ n_before = 0
2173+ n_after = len (negative_axes )
2174+ min_axes = n_before + n_after
2175+ max_axes = None
2176+ elif neg_gaps :
2177+ # Case 2: Explicit hole in the negatives, plus left implicit hole.
2178+ split = neg_gaps [0 ] + 1
2179+ n_before = split
2180+ n_after = len (negative_axes ) - split
2181+ min_axes = n_before + n_after
2182+
2183+ # Close the left implicit hole
2184+ max_axes = abs (min (negative_axes ))
2185+ else :
2186+ # Case 3: Left and right implicit holes, but the left can be closed by flipping to positive axes and
2187+ # adding a maximum number of axes.
2188+ max_axes = abs (negative_axes [0 ])
2189+ n_before = negative_axes [- 1 ] + max_axes + 1
2190+ n_after = 0
2191+ min_axes = n_before + n_after
2192+
2193+ return n_before , n_after , min_axes , max_axes
2194+
2195+ def make_node (self , * tensors : TensorVariable ):
2196+ tensors = [ptb .as_tensor_variable (t ) for t in tensors ]
2197+ n_axes_before , n_axes_after , min_axes , max_axes = self ._analyze_axes_list ()
2198+
2199+ if min ([t .ndim for t in tensors ]) < min_axes :
2200+ raise ValueError (
2201+ f"All input tensors to { self !s} must have at least { min_axes } dimensions, but the minimum "
2202+ f"number of dimensions found was { min ([t .ndim for t in tensors ])} ."
2203+ )
2204+
2205+ max_ndim = max ([t .ndim for t in tensors ])
2206+ if max_axes is not None and max_ndim > max_axes :
2207+ raise ValueError (
2208+ f"All input tensors to { self !s} must have at most { max_axes } dimensions, but the maximum "
2209+ f"number of dimensions found was { max_ndim } ."
2210+ )
2211+
2212+ def _coalesce_dim (shapes : list [int | None ], axis : int ) -> int | None :
2213+ unique_shapes = {s for s in shapes if s is not None }
2214+ if not unique_shapes :
2215+ return None
2216+ if len (unique_shapes ) > 1 :
2217+ raise ValueError (
2218+ f"Input tensors to Pack op have incompatible sizes on dimension { axis } : { shapes } "
2219+ )
2220+ return unique_shapes .pop ()
2221+
2222+ shapes_to_pack = [
2223+ t .type .shape [n_axes_before : t .ndim - n_axes_after ] for t in tensors
2224+ ]
2225+ packed_shape = (
2226+ None
2227+ if any (
2228+ shape is None
2229+ for packed_shape in shapes_to_pack
2230+ for shape in packed_shape
2231+ )
2232+ else int (sum (np .prod (shapes ) for shapes in shapes_to_pack ))
2233+ )
2234+ prefix_shapes = [
2235+ _coalesce_dim ([t .type .shape [i ] for t in tensors ], i )
2236+ for i in range (n_axes_before )
2237+ ]
2238+ suffix_shapes = [
2239+ _coalesce_dim (
2240+ [t .type .shape [t .ndim - n_axes_after + i ] for t in tensors ],
2241+ n_axes_before + i ,
2242+ )
2243+ for i in range (n_axes_after )
2244+ ]
2245+ out_shape = (* prefix_shapes , packed_shape , * suffix_shapes )
2246+
2247+ packed_output = ptb .tensor (dtype = tensors [0 ].dtype , shape = out_shape )
2248+ packed_shapes = [
2249+ ptb .tensor (dtype = "int64" , shape = (len (shapes ),)) for shapes in shapes_to_pack
2250+ ]
2251+
2252+ return Apply (self , tensors , [packed_output , * packed_shapes ])
2253+
2254+ def perform (self , node , inputs , outputs ):
2255+ tensors = inputs
2256+ packed_output , * packed_shapes = outputs
2257+
2258+ reshaped_tensors = []
2259+ tmp_shapes = []
2260+
2261+ n_axes_before , n_axes_after , min_axes , max_axes = self ._analyze_axes_list ()
2262+
2263+ if (
2264+ max_axes is not None
2265+ and any (t .ndim > max_axes for t in tensors )
2266+ and not any (t .ndim == max_axes for t in tensors )
2267+ ):
2268+ raise ValueError (
2269+ f"All input tensors must have at most { max_axes } axes, and at least one input tensor must have exactly "
2270+ f"{ max_axes } axes to resolve ambiguities in the interpretation of the axes list { self .axes } . A less"
2271+ f"ambiguous axes list can be used to avoid this restriction, usually by including 0 or -1 in the axes "
2272+ f"list."
2273+ )
2274+
2275+ for i , tensor in enumerate (tensors ):
2276+ shape = tensor .shape
2277+ ndim = tensor .ndim
2278+ if tensor .ndim < min_axes :
2279+ raise ValueError (
2280+ f"packed tensor #{ i } (enumeration starts with 0) has shape { shape } , "
2281+ f"while pattern { self .axes } assumes at least { min_axes } axes"
2282+ )
2283+ axis_after_packed_axes = ndim - n_axes_after
2284+ tmp_shapes .append (shape [n_axes_before :axis_after_packed_axes ])
2285+ reshaped_tensors .append (
2286+ tensor .reshape (
2287+ (* shape [:n_axes_before ], - 1 , * shape [axis_after_packed_axes :])
2288+ )
2289+ )
2290+
2291+ packed_output [0 ] = np .concatenate (reshaped_tensors , axis = n_axes_before )
2292+ for i , packed_shape in enumerate (tmp_shapes ):
2293+ packed_shapes [i ][0 ] = np .array (packed_shape ).astype ("int64" )
2294+
2295+
20142296def pack (
20152297 * tensors : TensorVariable , axes : int | Sequence [int ] | None = None
20162298) -> tuple [TensorVariable , list [tuple [TensorVariable ]]]:
@@ -2035,14 +2317,10 @@ def pack(
20352317 if not tensors :
20362318 raise ValueError ("Cannot pack an empty list of tensors." )
20372319
2038- packed_shapes = [
2039- t .type .shape if not any (s is None for s in t .type .shape ) else t .shape
2040- for t in tensors
2041- ]
2042-
2043- flat_tensor = join (0 , * [t .ravel () for t in tensors ])
2320+ pack_op = Pack (axes = axes )
2321+ packed_tensor , * packed_shapes = pack_op (* tensors )
20442322
2045- return flat_tensor , packed_shapes
2323+ return packed_tensor , packed_shapes
20462324
20472325
20482326def unpack (
@@ -2091,12 +2369,12 @@ def unpack(
20912369 "geomspace" ,
20922370 "linspace" ,
20932371 "logspace" ,
2372+ "pack" ,
20942373 "ravel_multi_index" ,
20952374 "repeat" ,
20962375 "searchsorted" ,
20972376 "squeeze" ,
20982377 "unique" ,
2099- "unravel_index" ,
2100- "pack" ,
21012378 "unpack" ,
2379+ "unravel_index" ,
21022380]
0 commit comments