Skip to content

Commit 5788333

Browse files
Feature complete Pack Op
1 parent 58c0286 commit 5788333

File tree

2 files changed

+444
-48
lines changed

2 files changed

+444
-48
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 288 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22
from collections.abc import Collection, Iterable, Sequence
3+
from itertools import pairwise
34
from textwrap import dedent
45

56
import numpy as np
@@ -45,7 +46,7 @@
4546
from pytensor.tensor.math import sum as pt_sum
4647
from pytensor.tensor.shape import Shape_i
4748
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
48-
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
49+
from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes
4950
from pytensor.tensor.utils import normalize_reduce_axis
5051
from pytensor.tensor.variable import TensorVariable
5152
from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
@@ -2011,6 +2012,287 @@ def concat_with_broadcast(tensor_list, axis=0):
20112012
return join(axis, *bcast_tensor_inputs)
20122013

20132014

2015+
class Pack(Op):
    """Ravel one contiguous interval of axes of every input and join the results.

    Each input keeps its first ``n_before`` and last ``n_after`` axes (as derived from ``axes`` by
    :meth:`_analyze_axes_list`); everything in between is flattened into a single axis. The
    flattened inputs are then concatenated along that axis.

    The node has ``1 + len(inputs)`` outputs: the packed tensor, followed by one ``int64`` vector per
    input holding the shape of the axes of that input that were ravelled.

    Parameters
    ----------
    axes : int, sequence of int, or None
        Axes that are *preserved* (not ravelled). ``None`` ravels every axis of every input.
    """

    __props__ = ("axes",)

    def __init__(self, axes: int | Sequence[int] | None):
        # ``axes`` is an Op prop and therefore has to be hashable: store scalars as plain ints and any
        # other sequence (list, range, array, ...) as a tuple.
        if axes is None:
            self.axes = None
        elif isinstance(axes, (int, np.integer)):
            self.axes = int(axes)
        else:
            self.axes = tuple(axes)

    def _analyze_axes_list(self) -> tuple[int, int, int, int | None]:
        """
        Analyze the provided axes list to determine how many axes are before and after the interval to be raveled, as
        well as the minimum and maximum number of axes that the inputs can have.

        The rules are:
            - Axes must be strictly increasing in both the positive and negative parts of the list.
            - Negative axes must come after positive axes.
            - There can be at most one "hole" in the axes list, which can be either an implicit hole on an endpoint
              (e.g. [0, 1]) or an explicit hole in the middle (e.g. [0, 2] or [1, -1]).

        Returns
        -------
        n_axes_before: int
            The number of axes before the interval to be raveled.
        n_axes_after: int
            The number of axes after the interval to be raveled.
        min_axes: int
            The minimum number of axes that the inputs must have.
        max_axes: int or None
            The maximum number of axes that the inputs can have, or None if there is no strict maximum. A maximum is
            only introduced when it would resolve ambiguities in the interpretation of the axes list. For example,
            [2, 3] can be either interpreted as having two ravel intervals [:2] and [4:], which is illegal,
            unless 3 is interpreted as -1, which is only possible if all inputs have exactly 4 axes. Likewise,
            [-3, -1] can be interpreted as having two ravel intervals [:-3], [-3:], unless -3 is interpreted as 0,
            which is only possible if all inputs have exactly 3 axes.

        Raises
        ------
        ValueError
            If ``axes`` is empty, contains duplicates, is not ordered as described above, or implies more than
            one hole.
        """
        axes = self.axes
        if axes is None:
            # Nothing is preserved: every axis of every input is raveled.
            return 0, 0, 0, None

        axes = [axes] if isinstance(axes, (int, np.integer)) else list(axes)

        if len(axes) == 0:
            raise ValueError("axes=[] is ambiguous; use None to ravel all")
        if len(set(axes)) != len(axes):
            raise ValueError("axes must have no duplicates")

        # Split at the first negative entry; everything from there on has to be negative as well.
        first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes))
        positive_axes = axes[:first_negative_idx]
        negative_axes = axes[first_negative_idx:]

        if any(a >= 0 for a in negative_axes):
            raise ValueError("Negative axes must come after positive")

        def strictly_increasing(s):
            return all(b > a for a, b in pairwise(s))

        if not (strictly_increasing(positive_axes) and strictly_increasing(negative_axes)):
            raise ValueError("Axes must be strictly increasing")

        def find_gaps(s):
            # Positions ``i`` such that ``s[i]`` and ``s[i + 1]`` are not consecutive integers.
            return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1]

        pos_gaps = find_gaps(positive_axes)
        neg_gaps = find_gaps(negative_axes)

        # Shared opening of every "too many holes" error (kept verbatim, callers may match on it).
        too_many_holes = (
            "Too many holes in axes list. There can be at most one hole in the axes list, "
            "including implict holes resulting from omitting the 0 or -1 axis."
        )

        if len(pos_gaps) + len(neg_gaps) > 1:
            raise ValueError(too_many_holes)

        # ``axes`` is non-empty here, so exactly one of the three branches below applies.
        if positive_axes and negative_axes:
            # Mixed case: the hole sits between the positive and the negative part, so neither part may
            # contain a further gap. No maximum number of axes is required.
            if pos_gaps or neg_gaps:
                raise ValueError(
                    too_many_holes + " Because both "
                    "positive and negative axes are present, there is always assume to be an explit hole "
                    "between them."
                )
            n_before = len(positive_axes)
            n_after = len(negative_axes)
            max_axes = None

        elif positive_axes:
            # There are four cases to consider when all axes are positive:
            # 0. 0 is missing (implicit left hole) and there is also an explicit gap (e.g. [2, 4]).
            #    This case is always illegal: no interpretation leaves only one hole in the axes list.
            # 1. There is only an implicit right hole (e.g. [0, 1]).
            #    This case is legal, and requires no special interpretation. It corresponds to 'i j *' in einops.
            # 2. There is an explicit internal hole (e.g. [0, 2]).
            #    This case is legal, but requires interpreting the last axis as -1, which introduces a maximum
            #    number of axes. It corresponds to 'i * j' in einops.
            # 3. The axes start at an index greater than 0, but have no internal holes (e.g. [2, 3]).
            #    This case is legal, but requires flipping the axes to negative indexing, so that the largest
            #    axis is -1, followed by -2, etc. This introduces a maximum number of axes.
            if pos_gaps and positive_axes[0] != 0:
                raise ValueError(
                    too_many_holes + " In this case, "
                    "there is an explicit internal hole as well as an implicit left hole."
                )

            if pos_gaps:
                # Case 2: the axes start at 0 and contain one gap; close the right implicit hole.
                n_before = pos_gaps[0] + 1
                n_after = len(positive_axes) - n_before
                max_axes = positive_axes[-1] + 1
            elif positive_axes[0] == 0:
                # Case 1: only the right implicit hole. No ambiguities.
                n_before = positive_axes[-1] + 1
                n_after = 0
                max_axes = None
            else:
                # Case 3: treat the axes as trailing axes and cap the number of input axes.
                n_before = 0
                n_after = len(positive_axes)
                max_axes = positive_axes[-1] + 1

        else:
            # The same four cases are considered when all axes are negative, but ordering is reversed.
            # 0. -1 is missing (implicit right hole) and there is also an explicit gap (e.g. [-4, -2]).
            #    This case is always illegal: no interpretation leaves only one hole in the axes list.
            # 1. There is only an implicit left hole (e.g. [-2, -1]).
            #    This case is legal, and requires no special interpretation. It corresponds to '* i j' in einops.
            # 2. There is an explicit internal hole (e.g. [-3, -1]).
            #    This case is legal, but requires interpreting the smallest axis as 0, which introduces a maximum
            #    number of axes. It corresponds to 'i * j' in einops.
            # 3. The axes end at an index less than -1, but have no internal holes (e.g. [-4, -3]). Flip to
            #    positive axes, adding a maximum number of axes. Interpret the smallest axis as 0.
            if neg_gaps and negative_axes[-1] != -1:
                raise ValueError(
                    too_many_holes + " In this case, "
                    "there is an explicit internal hole as well as an implicit right hole."
                )

            if neg_gaps:
                # Case 2: the axes end at -1 and contain one gap; close the left implicit hole.
                n_before = neg_gaps[0] + 1
                n_after = len(negative_axes) - n_before
                max_axes = abs(min(negative_axes))
            elif negative_axes[-1] == -1:
                # Case 1: only the left implicit hole. No ambiguities.
                n_before = 0
                n_after = len(negative_axes)
                max_axes = None
            else:
                # Case 3: treat the axes as leading axes and cap the number of input axes.
                max_axes = abs(negative_axes[0])
                n_before = negative_axes[-1] + max_axes + 1
                n_after = 0

        return n_before, n_after, n_before + n_after, max_axes

    def make_node(self, *tensors: TensorVariable):
        """Build the Apply node and infer as much of the static output shape as possible.

        Parameters
        ----------
        *tensors
            Inputs to pack; each is converted with ``as_tensor_variable``.

        Returns
        -------
        Apply
            Node whose outputs are the packed tensor followed by one ``int64`` vector per input.

        Raises
        ------
        ValueError
            If no inputs are given, if an input has fewer / more dimensions than ``axes`` allows, or if the
            static sizes of a preserved dimension disagree between inputs.
        """
        tensors = [ptb.as_tensor_variable(t) for t in tensors]
        n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list()

        if not tensors:
            # Without this guard the min()/max() calls below fail with an unrelated message.
            raise ValueError("Cannot pack an empty list of tensors.")

        ndims = [t.ndim for t in tensors]
        min_ndim = min(ndims)
        max_ndim = max(ndims)

        if min_ndim < min_axes:
            raise ValueError(
                f"All input tensors to {self!s} must have at least {min_axes} dimensions, but the minimum "
                f"number of dimensions found was {min_ndim}."
            )

        if max_axes is not None and max_ndim > max_axes:
            raise ValueError(
                f"All input tensors to {self!s} must have at most {max_axes} dimensions, but the maximum "
                f"number of dimensions found was {max_ndim}."
            )

        def _coalesce_dim(shapes: list[int | None], axis: int) -> int | None:
            # Merge the static sizes of one preserved dimension: unknown (None) entries are ignored, all
            # known entries have to agree.
            unique_shapes = {s for s in shapes if s is not None}
            if not unique_shapes:
                return None
            if len(unique_shapes) > 1:
                raise ValueError(
                    f"Input tensors to Pack op have incompatible sizes on dimension {axis} : {shapes}"
                )
            return unique_shapes.pop()

        # Static shape of the axes of every input that are going to be ravelled.
        shapes_to_pack = [
            t.type.shape[n_axes_before : t.ndim - n_axes_after] for t in tensors
        ]

        # The packed length is only known statically when every ravelled dimension is known.
        if any(dim is None for shape in shapes_to_pack for dim in shape):
            packed_dim = None
        else:
            packed_dim = int(
                sum(np.prod(shape, dtype="int64") for shape in shapes_to_pack)
            )

        prefix_shapes = [
            _coalesce_dim([t.type.shape[i] for t in tensors], i)
            for i in range(n_axes_before)
        ]
        suffix_shapes = [
            _coalesce_dim(
                [t.type.shape[t.ndim - n_axes_after + i] for t in tensors],
                n_axes_before + i,
            )
            for i in range(n_axes_after)
        ]
        out_shape = (*prefix_shapes, packed_dim, *suffix_shapes)

        # ``perform`` joins the inputs with ``np.concatenate``, which promotes mixed dtypes; declare the
        # promoted dtype so the output type matches what is actually computed.
        out_dtype = np.result_type(*[t.dtype for t in tensors]).name

        packed_output = ptb.tensor(dtype=out_dtype, shape=out_shape)
        packed_shapes = [
            ptb.tensor(dtype="int64", shape=(len(shape),)) for shape in shapes_to_pack
        ]

        return Apply(self, tensors, [packed_output, *packed_shapes])

    def perform(self, node, inputs, outputs):
        """Ravel the packed axes of every input, concatenate, and store the original packed shapes.

        ``outputs[0]`` receives the packed array; ``outputs[1:]`` receive, per input, an ``int64`` array
        with the shape of the axes that were ravelled.

        Raises
        ------
        ValueError
            If an input has fewer than the minimum number of axes implied by ``axes``, or if the
            maximum-axes guard below triggers.
        """
        packed_output, *packed_shapes = outputs

        n_axes_before, n_axes_after, min_axes, max_axes = self._analyze_axes_list()

        # NOTE(review): the message below describes two requirements (no input above ``max_axes`` and at
        # least one input with exactly ``max_axes``), but the condition only fires when both are violated
        # at once. ``make_node`` already rejects inputs with more than ``max_axes`` dimensions, so this
        # guard looks unreachable as written -- confirm the intended rule before changing the logic.
        if (
            max_axes is not None
            and any(t.ndim > max_axes for t in inputs)
            and not any(t.ndim == max_axes for t in inputs)
        ):
            raise ValueError(
                f"All input tensors must have at most {max_axes} axes, and at least one input tensor must have exactly "
                f"{max_axes} axes to resolve ambiguities in the interpretation of the axes list {self.axes}. A less "
                f"ambiguous axes list can be used to avoid this restriction, usually by including 0 or -1 in the axes "
                f"list."
            )

        reshaped_tensors = []
        raveled_shapes = []

        for i, tensor in enumerate(inputs):
            shape = tensor.shape
            if tensor.ndim < min_axes:
                raise ValueError(
                    f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
                    f"while pattern {self.axes} assumes at least {min_axes} axes"
                )

            axis_after_packed_axes = tensor.ndim - n_axes_after
            raveled_shape = shape[n_axes_before:axis_after_packed_axes]
            raveled_shapes.append(raveled_shape)

            # Spell out the ravelled length instead of using -1: numpy cannot infer -1 when another
            # dimension has size 0.
            raveled_size = int(np.prod(raveled_shape, dtype="int64"))
            reshaped_tensors.append(
                tensor.reshape(
                    (
                        *shape[:n_axes_before],
                        raveled_size,
                        *shape[axis_after_packed_axes:],
                    )
                )
            )

        packed_output[0] = np.concatenate(reshaped_tensors, axis=n_axes_before)
        for storage, raveled_shape in zip(packed_shapes, raveled_shapes):
            storage[0] = np.asarray(raveled_shape, dtype="int64")
2294+
2295+
20142296
def pack(
20152297
*tensors: TensorVariable, axes: int | Sequence[int] | None = None
20162298
) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
@@ -2035,14 +2317,10 @@ def pack(
20352317
if not tensors:
20362318
raise ValueError("Cannot pack an empty list of tensors.")
20372319

2038-
packed_shapes = [
2039-
t.type.shape if not any(s is None for s in t.type.shape) else t.shape
2040-
for t in tensors
2041-
]
2042-
2043-
flat_tensor = join(0, *[t.ravel() for t in tensors])
2320+
pack_op = Pack(axes=axes)
2321+
packed_tensor, *packed_shapes = pack_op(*tensors)
20442322

2045-
return flat_tensor, packed_shapes
2323+
return packed_tensor, packed_shapes
20462324

20472325

20482326
def unpack(
@@ -2091,12 +2369,12 @@ def unpack(
20912369
"geomspace",
20922370
"linspace",
20932371
"logspace",
2372+
"pack",
20942373
"ravel_multi_index",
20952374
"repeat",
20962375
"searchsorted",
20972376
"squeeze",
20982377
"unique",
2099-
"unravel_index",
2100-
"pack",
21012378
"unpack",
2379+
"unravel_index",
21022380
]

0 commit comments

Comments
 (0)