88
99import pytensor
1010import pytensor .scalar .basic as ps
11+ from pytensor .compile .builders import OpFromGraph
1112from pytensor .gradient import (
1213 DisconnectedType ,
1314 _float_zeros_like ,
4445)
4546from pytensor .tensor .math import max as pt_max
4647from pytensor .tensor .math import sum as pt_sum
47- from pytensor .tensor .shape import Shape_i
48+ from pytensor .tensor .shape import Shape_i , specify_shape
4849from pytensor .tensor .subtensor import advanced_inc_subtensor1 , set_subtensor
4950from pytensor .tensor .type import TensorType , dvector , int_dtypes , integer_dtypes
5051from pytensor .tensor .utils import normalize_reduce_axis
@@ -2012,11 +2013,10 @@ def concat_with_broadcast(tensor_list, axis=0):
20122013 return join (axis , * bcast_tensor_inputs )
20132014
20142015
2015- class Pack (Op ):
2016- __props__ = ("axes" ,)
2017-
2016+ class PackHelper :
20182017 def __init__ (self , axes : int | Sequence [int ] | None ):
20192018 self .axes = tuple (axes ) if isinstance (axes , list ) else axes
2019+ self .op_name = "Pack{axes=" + str (self .axes ) + "}"
20202020
20212021 def _analyze_axes_list (self ) -> tuple [int , int , int , int | None ]:
20222022 """
@@ -2192,23 +2192,31 @@ def find_gaps(s):
21922192
21932193 return n_before , n_after , min_axes , max_axes
21942194
2195- def make_node (self , * tensors : TensorVariable ):
2195+ def validate_inputs (self , tensors : list [ TensorLike ] ):
21962196 tensors = [ptb .as_tensor_variable (t ) for t in tensors ]
2197- n_axes_before , n_axes_after , min_axes , max_axes = self ._analyze_axes_list ()
2197+ _ , _ , min_axes , max_axes = self ._analyze_axes_list ()
21982198
21992199 if min ([t .ndim for t in tensors ]) < min_axes :
22002200 raise ValueError (
2201- f"All input tensors to { self !s } must have at least { min_axes } dimensions, but the minimum "
2201+ f"All input tensors to { self . op_name } must have at least { min_axes } dimensions, but the minimum "
22022202 f"number of dimensions found was { min ([t .ndim for t in tensors ])} ."
22032203 )
22042204
22052205 max_ndim = max ([t .ndim for t in tensors ])
2206- if max_axes is not None and max_ndim > max_axes :
2206+ if (
2207+ max_axes is not None
2208+ and max_ndim > max_axes
2209+ and not any (t .ndim == max_axes for t in tensors )
2210+ ):
22072211 raise ValueError (
2208- f"All input tensors to { self !s } must have at most { max_axes } dimensions, but the maximum "
2212+ f"All input tensors to { self . op_name } must have at most { max_axes } dimensions, but the maximum "
22092213 f"number of dimensions found was { max_ndim } ."
22102214 )
22112215
2216+ def infer_shape (self , tensors : list [TensorLike ]) -> tuple [int | None , ...]:
2217+ tensors = [ptb .as_tensor_variable (t ) for t in tensors ]
2218+ n_axes_before , n_axes_after , _ , _ = self ._analyze_axes_list ()
2219+
22122220 def _coalesce_dim (shapes : list [int | None ], axis : int ) -> int | None :
22132221 unique_shapes = {s for s in shapes if s is not None }
22142222 if not unique_shapes :
@@ -2242,55 +2250,12 @@ def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None:
22422250 )
22432251 for i in range (n_axes_after )
22442252 ]
2245- out_shape = (* prefix_shapes , packed_shape , * suffix_shapes )
2246-
2247- packed_output = ptb .tensor (dtype = tensors [0 ].dtype , shape = out_shape )
2248- packed_shapes = [
2249- ptb .tensor (dtype = "int64" , shape = (len (shapes ),)) for shapes in shapes_to_pack
2250- ]
2251-
2252- return Apply (self , tensors , [packed_output , * packed_shapes ])
2253-
2254- def perform (self , node , inputs , outputs ):
2255- tensors = inputs
2256- packed_output , * packed_shapes = outputs
2257-
2258- reshaped_tensors = []
2259- tmp_shapes = []
22602253
2261- n_axes_before , n_axes_after , min_axes , max_axes = self ._analyze_axes_list ()
2262-
2263- if (
2264- max_axes is not None
2265- and any (t .ndim > max_axes for t in tensors )
2266- and not any (t .ndim == max_axes for t in tensors )
2267- ):
2268- raise ValueError (
2269- f"All input tensors must have at most { max_axes } axes, and at least one input tensor must have exactly "
2270- f"{ max_axes } axes to resolve ambiguities in the interpretation of the axes list { self .axes } . A less"
2271- f"ambiguous axes list can be used to avoid this restriction, usually by including 0 or -1 in the axes "
2272- f"list."
2273- )
2254+ return (* prefix_shapes , packed_shape , * suffix_shapes )
22742255
2275- for i , tensor in enumerate (tensors ):
2276- shape = tensor .shape
2277- ndim = tensor .ndim
2278- if tensor .ndim < min_axes :
2279- raise ValueError (
2280- f"packed tensor #{ i } (enumeration starts with 0) has shape { shape } , "
2281- f"while pattern { self .axes } assumes at least { min_axes } axes"
2282- )
2283- axis_after_packed_axes = ndim - n_axes_after
2284- tmp_shapes .append (shape [n_axes_before :axis_after_packed_axes ])
2285- reshaped_tensors .append (
2286- tensor .reshape (
2287- (* shape [:n_axes_before ], - 1 , * shape [axis_after_packed_axes :])
2288- )
2289- )
22902256
2291- packed_output [0 ] = np .concatenate (reshaped_tensors , axis = n_axes_before )
2292- for i , packed_shape in enumerate (tmp_shapes ):
2293- packed_shapes [i ][0 ] = np .array (packed_shape ).astype ("int64" )
class Pack(OpFromGraph):
    """`OpFromGraph` subclass wrapping the graph built by `pack`.

    Adds no behavior of its own; the dedicated subclass presumably exists so
    that rewrites and debug-printing can identify packed graphs by type —
    NOTE(review): confirm against the rewrite passes that dispatch on it.
    """
22942259
22952260
22962261def pack (
@@ -2317,10 +2282,44 @@ def pack(
23172282 if not tensors :
23182283 raise ValueError ("Cannot pack an empty list of tensors." )
23192284
2320- pack_op = Pack (axes = axes )
2321- packed_tensor , * packed_shapes = pack_op (* tensors )
2285+ tensors = [ptb .as_tensor (tensor ) for tensor in tensors ]
2286+
2287+ pack_helper = PackHelper (axes = axes )
2288+
2289+ reshaped_tensors = []
2290+ tmp_shapes = []
2291+
2292+ n_axes_before , n_axes_after , _ , _ = pack_helper ._analyze_axes_list ()
2293+ pack_helper .validate_inputs (tensors )
2294+ output_shape = pack_helper .infer_shape (tensors )
2295+
2296+ for i , tensor in enumerate (tensors ):
2297+ shape = tensor .shape
2298+ ndim = tensor .ndim
2299+ axis_after_packed_axes = ndim - n_axes_after
2300+ tmp_shapes .append (shape [n_axes_before :axis_after_packed_axes ])
2301+ reshaped_tensors .append (
2302+ tensor .reshape (
2303+ (* shape [:n_axes_before ], - 1 , * shape [axis_after_packed_axes :])
2304+ )
2305+ )
2306+
2307+ packed_output_tensor = specify_shape (
2308+ ptb .join (n_axes_before , * reshaped_tensors ), output_shape
2309+ )
2310+ packed_output_shapes = [
2311+ ptb .as_tensor_variable (packed_shape ).astype ("int64" )
2312+ for i , packed_shape in enumerate (tmp_shapes )
2313+ ]
2314+
2315+ pack_op = Pack (
2316+ inputs = tensors ,
2317+ outputs = [packed_output_tensor , * packed_output_shapes ],
2318+ name = "Pack{axes=" + str (axes ) + "}" ,
2319+ )
23222320
2323- return packed_tensor , packed_shapes
2321+ outputs = pack_op (* tensors )
2322+ return outputs [0 ], outputs [1 :]
23242323
23252324
23262325def unpack (
0 commit comments