Skip to content

Commit 8a92a13

Browse files
authored
Merge pull request #6 from jcapriot/transpose_solve
add transpose option to solve call
2 parents: bbd5658 + fee2501; commit: 8a92a13

File tree

3 files changed

29 additions (+29) and 7 deletions (-7) in lines changed

3 files changed

29 additions (+29) and 7 deletions (-7) in lines changed

pydiso/mkl_solver.pyx

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ cdef class MKLPardisoSolver:
336336
def __call__(self, b):
337337
return self.solve(b)
338338

339-
def solve(self, b, x=None):
339+
def solve(self, b, x=None, transpose=False):
340340
"""solve(self, b, x=None, transpose=False)
341341
Solves the equation AX=B using the factored A matrix
342342
@@ -354,6 +354,8 @@ cdef class MKLPardisoSolver:
354354
x : numpy.ndarray, optional
355355
A pre-allocated output array (of the same data type as A).
356356
If None, a new array is constructed.
357+
transpose : bool, optional
358+
If True, it will solve A^TX=B using the factored A matrix.
357359
358360
Returns
359361
-------
@@ -388,6 +390,10 @@ cdef class MKLPardisoSolver:
388390

389391
cdef int_t nrhs = b.shape[1] if b.ndim == 2 else 1
390392

393+
if transpose:
394+
self.set_iparm(11, 2)
395+
else:
396+
self.set_iparm(11, 0)
391397
self._solve(bp, xp, nrhs)
392398
return x
393399

@@ -420,7 +426,7 @@ cdef class MKLPardisoSolver:
420426
if self._is_32:
421427
self._par.iparm[i] = val
422428
else:
423-
self._par.iparm[i] = val
429+
self._par64.iparm[i] = val
424430

425431
@property
426432
def nnz(self):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def configuration(parent_package="", top_path=None):
2424
python_requires=">=3.8",
2525
setup_requires=[
2626
"numpy>=1.8",
27-
"cython>=3.0",
27+
"cython>=0.29.31",
2828
],
2929
install_requires=[
3030
'numpy>=1.8',

tests/test_pydiso.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
set_mkl_pardiso_threads,
1010
)
1111
import pytest
12+
import sys
1213

1314
np.random.seed(12345)
1415
n = 40
@@ -39,6 +40,7 @@
3940
}
4041

4142

43+
@pytest.mark.xfail(sys.platform == "darwin", reason="Unexpected Thread bug in third party library")
4244
def test_thread_setting():
4345
n1 = get_mkl_max_threads()
4446
n2 = get_mkl_pardiso_max_threads()
@@ -93,8 +95,22 @@ def test_solver(A, matrix_type):
9395
x2 = solver.solve(b)
9496

9597
eps = np.finfo(dtype).eps
96-
rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x)
97-
assert rel_err < 1E3*eps
98+
np.testing.assert_allclose(x, x2, atol=1E3*eps)
99+
100+
@pytest.mark.parametrize("A, matrix_type", inputs)
101+
def test_transpose_solver(A, matrix_type):
102+
dtype = A.dtype
103+
if np.issubdtype(dtype, np.complexfloating):
104+
x = xc.astype(dtype)
105+
else:
106+
x = xr.astype(dtype)
107+
b = A.T @ x
108+
109+
solver = Solver(A, matrix_type=matrix_type)
110+
x2 = solver.solve(b, transpose=True)
111+
112+
eps = np.finfo(dtype).eps
113+
np.testing.assert_allclose(x, x2, atol=1E3*eps)
98114

99115
def test_multiple_RHS():
100116
A = A_real_dict["real_symmetric_positive_definite"]
@@ -105,8 +121,7 @@ def test_multiple_RHS():
105121
x2 = solver.solve(b)
106122

107123
eps = np.finfo(np.float64).eps
108-
rel_err = np.linalg.norm(x-x2)/np.linalg.norm(x)
109-
assert rel_err < 1E3*eps
124+
np.testing.assert_allclose(x, x2, atol=1E3*eps)
110125

111126

112127
def test_matrix_type_errors():
@@ -119,6 +134,7 @@ def test_matrix_type_errors():
119134
solver = Solver(A, matrix_type="real_symmetric_positive_definite")
120135

121136

137+
122138
def test_rhs_size_error():
123139
A = A_real_dict["real_symmetric_positive_definite"]
124140
solver = Solver(A, "real_symmetric_positive_definite")

0 commit comments

Comments
 (0)