Skip to content

Commit ed60651

Browse files
Implement Pack as OpFromGraph
1 parent 5788333 commit ed60651

File tree

2 files changed

+76
-77
lines changed

2 files changed

+76
-77
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 57 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pytensor
1010
import pytensor.scalar.basic as ps
11+
from pytensor.compile.builders import OpFromGraph
1112
from pytensor.gradient import (
1213
DisconnectedType,
1314
_float_zeros_like,
@@ -44,7 +45,7 @@
4445
)
4546
from pytensor.tensor.math import max as pt_max
4647
from pytensor.tensor.math import sum as pt_sum
47-
from pytensor.tensor.shape import Shape_i
48+
from pytensor.tensor.shape import Shape_i, specify_shape
4849
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
4950
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
5051
from pytensor.tensor.utils import normalize_reduce_axis
@@ -2012,11 +2013,10 @@ def concat_with_broadcast(tensor_list, axis=0):
20122013
return join(axis, *bcast_tensor_inputs)
20132014

20142015

2015-
class Pack(Op):
2016-
__props__ = ("axes",)
2017-
2016+
class PackHelper:
20182017
def __init__(self, axes: int | Sequence[int] | None):
20192018
self.axes = tuple(axes) if isinstance(axes, list) else axes
2019+
self.op_name = "Pack{axes=" + str(self.axes) + "}"
20202020

20212021
def _analyze_axes_list(self) -> tuple[int, int, int, int | None]:
20222022
"""
@@ -2192,23 +2192,31 @@ def find_gaps(s):
21922192

21932193
return n_before, n_after, min_axes, max_axes
21942194

2195-
def make_node(self, *tensors: TensorVariable):
2195+
def validate_inputs(self, tensors: list[TensorLike]):
21962196
tensors = [ptb.as_tensor_variable(t) for t in tensors]
2197-
n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list()
2197+
_, _, min_axes, max_axes = self._analyze_axes_list()
21982198

21992199
if min([t.ndim for t in tensors]) < min_axes:
22002200
raise ValueError(
2201-
f"All input tensors to {self!s} must have at least {min_axes} dimensions, but the minimum "
2201+
f"All input tensors to {self.op_name} must have at least {min_axes} dimensions, but the minimum "
22022202
f"number of dimensions found was {min([t.ndim for t in tensors])}."
22032203
)
22042204

22052205
max_ndim = max([t.ndim for t in tensors])
2206-
if max_axes is not None and max_ndim > max_axes:
2206+
if (
2207+
max_axes is not None
2208+
and max_ndim > max_axes
2209+
and not any(t.ndim == max_axes for t in tensors)
2210+
):
22072211
raise ValueError(
2208-
f"All input tensors to {self!s} must have at most {max_axes} dimensions, but the maximum "
2212+
f"All input tensors to {self.op_name} must have at most {max_axes} dimensions, but the maximum "
22092213
f"number of dimensions found was {max_ndim}."
22102214
)
22112215

2216+
def infer_shape(self, tensors: list[TensorLike]) -> tuple[int | None, ...]:
2217+
tensors = [ptb.as_tensor_variable(t) for t in tensors]
2218+
n_axes_before, n_axes_after, _, _ = self._analyze_axes_list()
2219+
22122220
def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None:
22132221
unique_shapes = {s for s in shapes if s is not None}
22142222
if not unique_shapes:
@@ -2242,55 +2250,12 @@ def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None:
22422250
)
22432251
for i in range(n_axes_after)
22442252
]
2245-
out_shape = (*prefix_shapes, packed_shape, *suffix_shapes)
2246-
2247-
packed_output = ptb.tensor(dtype=tensors[0].dtype, shape=out_shape)
2248-
packed_shapes = [
2249-
ptb.tensor(dtype="int64", shape=(len(shapes),)) for shapes in shapes_to_pack
2250-
]
2251-
2252-
return Apply(self, tensors, [packed_output, *packed_shapes])
2253-
2254-
def perform(self, node, inputs, outputs):
2255-
tensors = inputs
2256-
packed_output, *packed_shapes = outputs
2257-
2258-
reshaped_tensors = []
2259-
tmp_shapes = []
22602253

2261-
n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list()
2262-
2263-
if (
2264-
max_axes is not None
2265-
and any(t.ndim > max_axes for t in tensors)
2266-
and not any(t.ndim == max_axes for t in tensors)
2267-
):
2268-
raise ValueError(
2269-
f"All input tensors must have at most {max_axes} axes, and at least one input tensor must have exactly "
2270-
f"{max_axes} axes to resolve ambiguities in the interpretation of the axes list {self.axes}. A less"
2271-
f"ambiguous axes list can be used to avoid this restriction, usually by including 0 or -1 in the axes "
2272-
f"list."
2273-
)
2254+
return (*prefix_shapes, packed_shape, *suffix_shapes)
22742255

2275-
for i, tensor in enumerate(tensors):
2276-
shape = tensor.shape
2277-
ndim = tensor.ndim
2278-
if tensor.ndim < min_axes:
2279-
raise ValueError(
2280-
f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
2281-
f"while pattern {self.axes} assumes at least {min_axes} axes"
2282-
)
2283-
axis_after_packed_axes = ndim - n_axes_after
2284-
tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes])
2285-
reshaped_tensors.append(
2286-
tensor.reshape(
2287-
(*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])
2288-
)
2289-
)
22902256

2291-
packed_output[0] = np.concatenate(reshaped_tensors, axis=n_axes_before)
2292-
for i, packed_shape in enumerate(tmp_shapes):
2293-
packed_shapes[i][0] = np.array(packed_shape).astype("int64")
2257+
class Pack(OpFromGraph):
2258+
"Wrapper for the Pack Op"
22942259

22952260

22962261
def pack(
@@ -2317,10 +2282,44 @@ def pack(
23172282
if not tensors:
23182283
raise ValueError("Cannot pack an empty list of tensors.")
23192284

2320-
pack_op = Pack(axes=axes)
2321-
packed_tensor, *packed_shapes = pack_op(*tensors)
2285+
tensors = [ptb.as_tensor(tensor) for tensor in tensors]
2286+
2287+
pack_helper = PackHelper(axes=axes)
2288+
2289+
reshaped_tensors = []
2290+
tmp_shapes = []
2291+
2292+
n_axes_before, n_axes_after, _, _ = pack_helper._analyze_axes_list()
2293+
pack_helper.validate_inputs(tensors)
2294+
output_shape = pack_helper.infer_shape(tensors)
2295+
2296+
for i, tensor in enumerate(tensors):
2297+
shape = tensor.shape
2298+
ndim = tensor.ndim
2299+
axis_after_packed_axes = ndim - n_axes_after
2300+
tmp_shapes.append(shape[n_axes_before:axis_after_packed_axes])
2301+
reshaped_tensors.append(
2302+
tensor.reshape(
2303+
(*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])
2304+
)
2305+
)
2306+
2307+
packed_output_tensor = specify_shape(
2308+
ptb.join(n_axes_before, *reshaped_tensors), output_shape
2309+
)
2310+
packed_output_shapes = [
2311+
ptb.as_tensor_variable(packed_shape).astype("int64")
2312+
for i, packed_shape in enumerate(tmp_shapes)
2313+
]
2314+
2315+
pack_op = Pack(
2316+
inputs=tensors,
2317+
outputs=[packed_output_tensor, *packed_output_shapes],
2318+
name="Pack{axes=" + str(axes) + "}",
2319+
)
23222320

2323-
return packed_tensor, packed_shapes
2321+
outputs = pack_op(*tensors)
2322+
return outputs[0], outputs[1:]
23242323

23252324

23262325
def unpack(

tests/tensor/test_extra_ops.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
CumOp,
2222
FillDiagonal,
2323
FillDiagonalOffset,
24-
Pack,
24+
PackHelper,
2525
RavelMultiIndex,
2626
Repeat,
2727
SearchsortedOp,
@@ -1421,52 +1421,52 @@ class TestPack:
14211421
],
14221422
)
14231423
def test_analyze_axes_list_valid(self, axes, expected):
1424-
op = Pack(axes)
1425-
outputs = op._analyze_axes_list()
1424+
helper = PackHelper(axes)
1425+
outputs = helper._analyze_axes_list()
14261426
names = ["n_before", "n_after", "min_axes", "max_axes"]
14271427
for out, exp, name in zip(outputs, expected, names, strict=True):
14281428
assert out == exp, f"Expected {exp}, got {out} for {name}"
14291429

14301430
def test_analyze_axes_list_invalid(self):
14311431
# Two explicit holes
1432-
op = Pack([0, 2, -1])
1432+
helper = PackHelper([0, 2, -1])
14331433
with pytest.raises(ValueError, match="Too many holes"):
1434-
op._analyze_axes_list()
1434+
helper._analyze_axes_list()
14351435

14361436
# Explict hole + two implicit holes
1437-
op = Pack([1, 3])
1437+
helper = PackHelper([1, 3])
14381438
with pytest.raises(ValueError, match="Too many holes"):
1439-
op._analyze_axes_list()
1439+
helper._analyze_axes_list()
14401440

14411441
# Two explicit holes, all positive
1442-
op = Pack([0, 2, 4])
1442+
helper = PackHelper([0, 2, 4])
14431443
with pytest.raises(ValueError, match="Too many holes"):
1444-
op._analyze_axes_list()
1444+
helper._analyze_axes_list()
14451445

14461446
# Explicit hole + two implicit hole, all negative
1447-
op = Pack([-4, -2])
1447+
helper = PackHelper([-4, -2])
14481448
with pytest.raises(ValueError, match="Too many holes"):
1449-
op._analyze_axes_list()
1449+
helper._analyze_axes_list()
14501450

14511451
# Two explicit holes + implicit hole, all negative
1452-
op = Pack([-5, -3, -1])
1452+
helper = PackHelper([-5, -3, -1])
14531453
with pytest.raises(ValueError, match="Too many holes"):
1454-
op._analyze_axes_list()
1454+
helper._analyze_axes_list()
14551455

14561456
# Duplicate axes
1457-
op = Pack([0, 0])
1457+
helper = PackHelper([0, 0])
14581458
with pytest.raises(ValueError, match="axes must have no duplicates"):
1459-
op._analyze_axes_list()
1459+
helper._analyze_axes_list()
14601460

14611461
# Not monotonic
1462-
op = Pack([0, 2, 1])
1462+
helper = PackHelper([0, 2, 1])
14631463
with pytest.raises(ValueError, match="Axes must be strictly increasing"):
1464-
op._analyze_axes_list()
1464+
helper._analyze_axes_list()
14651465

14661466
# Negative before positive
1467-
op = Pack([-1, 0])
1467+
helper = PackHelper([-1, 0])
14681468
with pytest.raises(ValueError, match="Negative axes must come after positive"):
1469-
op._analyze_axes_list()
1469+
helper._analyze_axes_list()
14701470

14711471
def test_pack_basic(self):
14721472
# rng = np.random.default_rng()

0 commit comments

Comments
 (0)