diff --git a/CHANGELOG.md b/CHANGELOG.md index 7e5101da0..46a82d167 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,8 @@ - Raised an error when an expression is used when a variable is required - Fixed some compile warnings ### Changed -- MatrixExpr.sum() now supports axis arguments and can return either a scalar or MatrixExpr depending on the result dimensions +- MatrixExpr.sum() now supports axis arguments and can return either a scalar or MatrixExpr, depending on the result dimensions. +- AddMatrixCons() also accepts ExprCons. ### Removed ## 5.5.0 - 2025.05.06 diff --git a/src/pyscipopt/scip.pxi b/src/pyscipopt/scip.pxi index 49f822991..fa33be017 100644 --- a/src/pyscipopt/scip.pxi +++ b/src/pyscipopt/scip.pxi @@ -5775,7 +5775,7 @@ cdef class Model: return constraints def addMatrixCons(self, - cons: MatrixExprCons, + cons: Union[ExprCons, MatrixExprCons], name: Union[str, np.ndarray] ='', initial: Union[bool, np.ndarray] = True, separate: Union[bool, np.ndarray] = True, @@ -5792,8 +5792,8 @@ cdef class Model: Parameters ---------- - cons : MatrixExprCons - The matrix expression constraint that is not yet an actual constraint + cons : ExprCons or MatrixExprCons + The matrix expression constraint or expression constraint, that are not yet an actual constraint name : str or np.ndarray, optional the name of the matrix constraint, generic name if empty (Default value = "") initial : bool or np.ndarray, optional @@ -5820,12 +5820,17 @@ cdef class Model: Returns ------- - MatrixConstraint - The created and added MatrixConstraint object. - + Constraint or MatrixConstraint + The created and added Constraint or MatrixConstraint. """ - assert isinstance(cons, MatrixExprCons), ( - "given constraint is not MatrixExprCons but %s" % cons.__class__.__name__) + if not isinstance(cons, (ExprCons, MatrixExprCons)): + raise TypeError("given constraint is not MatrixExprCons nor ExprCons but %s" % cons.__class__.__name__) + + if isinstance(cons, ExprCons): + return self.addCons(cons, name=name, initial=initial, separate=separate, + enforce=enforce, check=check, propagate=propagate, + local=local, modifiable=modifiable, dynamic=dynamic, + removable=removable, stickingatnode=stickingatnode) shape = cons.shape diff --git a/tests/test_matrix_variable.py b/tests/test_matrix_variable.py index 6242e97e7..0308bb694 100644 --- a/tests/test_matrix_variable.py +++ b/tests/test_matrix_variable.py @@ -15,12 +15,15 @@ def test_catching_errors(): y = m.addMatrixVar(shape=(3, 3)) rhs = np.ones((2, 1)) + # require ExprCons with pytest.raises(Exception): - m.addMatrixCons(x <= 1) + m.addCons(y <= 3) + # require MatrixExprCons or ExprCons with pytest.raises(Exception): - m.addCons(y <= 3) + m.addMatrixCons(x) + # test shape mismatch with pytest.raises(Exception): m.addMatrixCons(y <= rhs) @@ -169,7 +172,7 @@ def test_matrix_sum_argument(): # compare the result of summing 2d array to a scalar with a scalar x = m.addMatrixVar((2, 3), "x", "I", ub=4) - m.addCons(x.sum() == 24) + m.addMatrixCons(x.sum() == 24) # compare the result of summing 2d array to 1d array y = m.addMatrixVar((2, 4), "y", "I", ub=4)