Rewrite inverse for triangular matrix #1612
First changed file:

```diff
@@ -8,18 +8,22 @@
 from pytensor import tensor as pt
 from pytensor.compile import optdb
 from pytensor.graph import Apply, FunctionGraph
 from pytensor.graph.basic import Constant
 from pytensor.graph.rewriting.basic import (
     copy_stack_trace,
     dfs_rewriter,
     node_rewriter,
 )
 from pytensor.graph.rewriting.unify import OpPattern
 from pytensor.scalar.basic import Abs, Log, Mul, Sign
 from pytensor.scalar.basic import Mul as ScalarMul
 from pytensor.scalar.basic import Sub as ScalarSub
 from pytensor.tensor.basic import (
     AllocDiag,
     ExtractDiag,
     Eye,
     TensorVariable,
     Tri,
     concatenate,
     diag,
     diagonal,
@@ -46,12 +50,16 @@
 )
 from pytensor.tensor.rewriting.blockwise import blockwise_of
 from pytensor.tensor.slinalg import (
     LU,
     QR,
     BlockDiagonal,
     Cholesky,
     CholeskySolve,
     LUFactor,
     Solve,
     SolveBase,
     SolveTriangular,
     TriangularInv,
     _bilinear_solve_discrete_lyapunov,
     block_diag,
     cholesky,
```
```diff
@@ -1017,3 +1025,95 @@ def scalar_solve_to_division(fgraph, node):
     copy_stack_trace(old_out, new_out)

     return [new_out]


 def _find_triangular_op(var):
     """
     Inspects a variable to see if it's triangular.

     Returns `True` if lower-triangular, `False` if upper-triangular, otherwise `None`.
     """
     # Case 1: Check for an explicit tag
     is_lower = getattr(var.tag, "lower_triangular", False)
     is_upper = getattr(var.tag, "upper_triangular", False)
     if is_lower or is_upper:
         return is_lower

     if not var.owner:
         return None

     op = var.owner.op
     core_op = op.core_op if isinstance(op, Blockwise) else op

     # Case 2: Check for direct creator Ops
     if isinstance(core_op, Cholesky):
         return core_op.lower

     if isinstance(core_op, LU | LUFactor):
         if var.owner.outputs[1] == var:
             return True
         if var.owner.outputs[2] == var:
             return False

     if isinstance(core_op, QR):
         if var.owner.outputs[1] == var:
             return False

     if isinstance(core_op, Tri):
         k_node = var.owner.inputs[2]
         if isinstance(k_node, Constant) and k_node.data == 0:
             return True

     # Case 3: tril/triu patterns which are implemented as Mul
     if isinstance(core_op, Elemwise) and isinstance(core_op.scalar_op, ScalarMul):
         other_inp = next(
             (i for i in var.owner.inputs if i != var.owner.inputs[0]), None
         )

         if other_inp is not None and other_inp.owner:
             # Check for tril pattern: Mul(x, Tri(...))
             if isinstance(other_inp.owner.op, Tri):
                 k_node = other_inp.owner.inputs[2]
                 if isinstance(k_node, Constant) and k_node.data == 0:
                     return True  # It's tril

             # Check for triu pattern: Mul(x, Sub(1, Tri(k=-1)))
             sub_op = other_inp.owner.op
             if isinstance(sub_op, Elemwise) and isinstance(sub_op.scalar_op, ScalarSub):
                 sub_inputs = other_inp.owner.inputs
                 const_one = next(
                     (i for i in sub_inputs if isinstance(i, Constant) and i.data == 1),
                     None,
                 )
                 tri_inp = next(
                     (i for i in sub_inputs if i.owner and isinstance(i.owner.op, Tri)),
                     None,
                 )

                 if const_one is not None and tri_inp is not None:
                     k_node = tri_inp.owner.inputs[2]
                     if isinstance(k_node, Constant) and k_node.data == -1:
                         return False  # It's triu

     return None
```
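For context, here is a minimal sketch (not part of the diff) of graphs each detection case would recognize. It uses only PyTensor's public API; the helper names assumed here are `pt.linalg.cholesky` and `pt.tril`.

```python
import pytensor.tensor as pt

# Case 1: the user asserts triangularity via a tag on the variable
A = pt.matrix("A")
A.tag.lower_triangular = True

# Case 2: the variable is produced by an Op known to emit triangular outputs
L = pt.linalg.cholesky(pt.matrix("B"))  # lower-triangular by construction

# Case 3: tril/triu, which PyTensor lowers to Mul(x, Tri(...)) patterns
M = pt.tril(pt.matrix("C"))
```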
The hunk continues with the rewrite itself:

```diff
 @register_canonicalize
 @register_stabilize
 @node_rewriter([blockwise_of(MATRIX_INVERSE_OPS)])
 def rewrite_inv_to_triangular_solve(fgraph, node):
     """
     This rewrite takes advantage of the fact that the inverse of a triangular
     matrix can be computed more efficiently than the inverse of a general
     matrix by using a triangular inv instead of a general matrix inverse.
     """
     A = node.inputs[0]
     is_lower = _find_triangular_op(A)
     if is_lower is None:
         return None

     new_op = TriangularInv(lower=is_lower)
     new_inv = new_op(A)
     copy_stack_trace(node.outputs[0], new_inv)
     return [new_inv]
```

Review discussion on this rewrite:

> **Reviewer:** We should make sure any rewrites targeting MatrixInverse, such as […]
>
> **Author:** Yes, this was mentioned in the PR but suggested that it be handled separately. Would you prefer to handle it in the same PR and also investigate other cases?
>
> **Author:** Sorry, didn't quite understand how this would be worse? Because instead of the […]
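An end-to-end sketch of what the rewrite is meant to achieve, assuming it fires under the default compile mode as described in the PR (the exact contents of the compiled graph are not guaranteed here):

```python
import pytensor
import pytensor.tensor as pt

A = pt.matrix("A")
L = pt.linalg.cholesky(A)   # known lower-triangular
L_inv = pt.linalg.inv(L)    # naively a general MatrixInverse

f = pytensor.function([A], L_inv)
# After canonicalize/stabilize, the general inverse of L should have been
# replaced by TriangularInv(lower=True), which dispatches to LAPACK trtri.
print(f.maker.fgraph.toposort())
```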
Second changed file:

```diff
@@ -20,7 +20,7 @@
 from pytensor.tensor import math as ptm
 from pytensor.tensor.basic import as_tensor_variable, diagonal
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.nlinalg import kron, matrix_dot
+from pytensor.tensor.nlinalg import MatrixInverse, kron, matrix_dot
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import matrix, tensor, vector
 from pytensor.tensor.variable import TensorVariable
```
And the new `TriangularInv` Op:

```diff
@@ -1016,6 +1016,71 @@ def solve_triangular(
     return cast(TensorVariable, ret)


 class TriangularInv(MatrixInverse):
     """
     Computes the inverse of a triangular matrix.
     """

     __props__ = ("lower", "on_error", "overwrite_a")

     def __init__(self, lower=True, on_error="raise", overwrite_a=False):
         self.lower = lower
         if on_error not in ("raise", "nan"):
             raise ValueError('on_error must be one of "raise" or "nan"')
         self.on_error = on_error
         self.overwrite_a = overwrite_a

         if self.overwrite_a:
             self.destroy_map = {0: [0]}

     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (z,) = outputs
         (trtri,) = get_lapack_funcs(("trtri",), (x,))

         # Check if we want to overwrite and if the input is C-contiguous
         c_contiguous_input = self.overwrite_a and x.flags["C_CONTIGUOUS"]
         if c_contiguous_input:
             # Transpose C-contiguous to F-contiguous
             x_in = x.T
             lower_flag = not self.lower
             overwrite_flag = True
         else:
             # Use original matrix and flags
             x_in = x
             lower_flag = self.lower
             overwrite_flag = self.overwrite_a

         # Call trtri with the potentially transposed input and correct flags
         # Use overwrite_c (LAPACK flag for trtri) based on our logic
         inv_maybe_transposed, info = trtri(
             x_in, lower=lower_flag, overwrite_c=overwrite_flag
         )

         if info != 0:
             if self.on_error == "nan":
                 z[0] = np.full_like(x, np.nan)
                 return
             elif info > 0:
                 raise np.linalg.LinAlgError("Singular matrix")
             elif info < 0:
                 raise ValueError(
                     f"illegal value in {-info}-th argument of internal trtri"
                 )
         z[0] = inv_maybe_transposed.T if c_contiguous_input else inv_maybe_transposed
```
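To see why `perform` transposes C-contiguous inputs instead of copying them: LAPACK wants Fortran order, and for a triangular matrix `inv(x.T).T == inv(x)` provided the `lower` flag is flipped along with the transpose. A standalone NumPy/SciPy sketch of the same trick (the matrix values are illustrative):

```python
import numpy as np
from scipy.linalg import get_lapack_funcs

# A well-conditioned upper-triangular matrix in C (row-major) order
rng = np.random.default_rng(0)
x = np.triu(rng.random((4, 4))) + 4 * np.eye(4)
x_orig = x.copy()

(trtri,) = get_lapack_funcs(("trtri",), (x,))

# x.T is an F-contiguous view of the same buffer and is *lower*-triangular,
# so LAPACK can invert it in place if the `lower` flag is flipped too
inv_t, info = trtri(x.T, lower=True, overwrite_c=True)
assert info == 0

# Transposing back recovers inv(x), since inv(x.T).T == inv(x).
# (x itself may now hold the result in place, which is why we kept x_orig.)
np.testing.assert_allclose(inv_t.T @ x_orig, np.eye(4), atol=1e-10)
```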
```diff
     def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
         """
         Allows this Op to overwrite its input buffer with its output.
         """
         if not allowed_inplace_inputs:
             return self

         new_props = self._props_dict()  # type: ignore
         new_props["overwrite_a"] = True
         return type(self)(**new_props)
```

Review discussion on the class definition:

> **Reviewer (ricardoV94):** We shouldn't subclass from Ops we also instantiate. That may cause confusion between `TriangularInv` and `MatrixInverse`. Sometimes that's fine, others it isn't. Better to have a `BaseMatrixInverse` that both inherit from. Then code can look for […] For instance, I'm surprised your current rewrite is not applying recursively, since the returned graph should fit the bill for the pattern you're matching (an inverse of an A that is found to be triangular).
>
> **Author:** Thank you, @ricardoV94! This was very insightful. I'm not sure if this is true, but perhaps in the rewrite we do a […] If I make the […]
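A minimal sketch of the hierarchy the reviewer proposes (all names here, including `BaseMatrixInverse`, are illustrative and not in the PR). The point is that `isinstance(op, MatrixInverse)` would then be false for `TriangularInv`, which also bears on the recursion concern: as written, a `TriangularInv` *is* a `MatrixInverse`, so rewrites matching general inverses can match the rewriter's own output.

```python
from pytensor.graph.op import Op


# Hypothetical hierarchy, sketch only
class BaseMatrixInverse(Op):
    """Shared behavior for inverse-like Ops."""


class MatrixInverse(BaseMatrixInverse):
    """General matrix inverse; rewrites targeting it match only this class."""


class TriangularInv(BaseMatrixInverse):
    """Triangular inverse; no longer an instance of MatrixInverse, so the
    inv-to-triangular rewrite cannot re-match its own output."""
```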
The hunk's trailing context:

```diff
 class Solve(SolveBase):
     """
     Solve a system of linear equations.
```
A final review exchange on the imports:

> **Reviewer:** We also have `Mul` and `Sub` imported above. Just use those?
>
> **Author:** Sorry, will fix. Thank you!