Skip to content

Commit 94ca681

Browse files
authored
Merge pull request #7 from mrocklin/nogil
Add nogil declaration to _run_paradiso function
2 parents 8a92a13 + 35d4d5d commit 94ca681

File tree

2 files changed

+44
-4
lines changed

2 files changed

+44
-4
lines changed

pydiso/mkl_solver.pyx

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22
#cython: linetrace=True
33
cimport numpy as np
44
from cython cimport numeric
5+
from cpython.pythread cimport (
6+
PyThread_type_lock,
7+
PyThread_allocate_lock,
8+
PyThread_acquire_lock,
9+
PyThread_release_lock,
10+
PyThread_free_lock
11+
)
512

613
import warnings
714
import numpy as np
@@ -42,12 +49,12 @@ cdef extern from 'mkl.h':
4249
void pardiso(_MKL_DSS_HANDLE_t, const int*, const int*, const int*,
4350
const int *, const int *, const void *, const int *,
4451
const int *, int *, const int *, int *,
45-
const int *, void *, void *, int *)
52+
const int *, void *, void *, int *) nogil
4653

4754
void pardiso_64(_MKL_DSS_HANDLE_t, const long_t *, const long_t *, const long_t *,
4855
const long_t *, const long_t *, const void *, const long_t *,
4956
const long_t *, long_t *, const long_t *, long_t *,
50-
const long_t *, void *, void *, long_t *)
57+
const long_t *, void *, void *, long_t *) nogil
5158

5259

5360
#call pardiso (pt, maxfct, mnum, mtype, phase, n, a, ia, ja, perm, nrhs, iparm, msglvl, b, x, error)
@@ -184,7 +191,7 @@ cdef class MKLPardisoSolver:
184191
cdef int_t _factored
185192
cdef size_t shape[2]
186193
cdef int_t _initialized
187-
194+
cdef PyThread_type_lock lock
188195
cdef void * a
189196

190197
cdef object _data_type
@@ -253,6 +260,9 @@ cdef class MKLPardisoSolver:
253260
raise ValueError("Matrix is not square")
254261
self.shape = n_row, n_col
255262

263+
# allocate the lock
264+
self.lock = PyThread_allocate_lock()
265+
256266
self._data_type = A.dtype
257267
if matrix_type is None:
258268
if np.issubdtype(self._data_type, np.complexfloating):
@@ -496,6 +506,7 @@ cdef class MKLPardisoSolver:
496506
cdef long_t phase64=-1, nrhs64=0, error64=0
497507

498508
if self._initialized:
509+
PyThread_acquire_lock(self.lock, 1)
499510
if self._is_32:
500511
pardiso(
501512
self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
@@ -508,9 +519,12 @@ cdef class MKLPardisoSolver:
508519
&phase64, &self._par64.n, self.a, NULL, NULL, NULL, &nrhs64,
509520
self._par64.iparm, &self._par64.msglvl, NULL, NULL, &error64
510521
)
522+
PyThread_release_lock(self.lock)
511523
err = error or error64
512524
if err!=0:
513525
raise PardisoError("Memmory release error "+_err_messages[err])
526+
#dealloc lock
527+
PyThread_free_lock(self.lock)
514528

515529
cdef _analyze(self):
516530
#phase = 11
@@ -536,17 +550,20 @@ cdef class MKLPardisoSolver:
536550
if err!=0:
537551
raise PardisoError("Solve step error, "+_err_messages[err])
538552

539-
cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0):
553+
cdef int _run_pardiso(self, int_t phase, void* b=NULL, void* x=NULL, int_t nrhs=0) nogil:
540554
cdef int_t error=0
541555
cdef long_t error64=0, phase64=phase, nrhs64=nrhs
542556

557+
PyThread_acquire_lock(self.lock, 1)
543558
if self._is_32:
544559
pardiso(self.handle, &self._par.maxfct, &self._par.mnum, &self._par.mtype,
545560
&phase, &self._par.n, self.a, &self._par.ia[0], &self._par.ja[0],
546561
&self._par.perm[0], &nrhs, self._par.iparm, &self._par.msglvl, b, x, &error)
562+
PyThread_release_lock(self.lock)
547563
return error
548564
else:
549565
pardiso_64(self.handle, &self._par64.maxfct, &self._par64.mnum, &self._par64.mtype,
550566
&phase64, &self._par64.n, self.a, &self._par64.ia[0], &self._par64.ja[0],
551567
&self._par64.perm[0], &nrhs64, self._par64.iparm, &self._par64.msglvl, b, x, &error64)
568+
PyThread_release_lock(self.lock)
552569
return error64

tests/test_pydiso.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
set_mkl_threads,
99
set_mkl_pardiso_threads,
1010
)
11+
from concurrent.futures import ThreadPoolExecutor
1112
import pytest
1213
import sys
1314

@@ -147,3 +148,25 @@ def test_rhs_size_error():
147148
solver.solve(b_bad)
148149
with pytest.raises(ValueError):
149150
solver.solve(b, x_bad)
151+
152+
def test_threading():
153+
"""
154+
Here we test that calling the solver is safe from multiple threads.
155+
There isn't actually any speedup because it acquires a lock on each call
156+
to pardiso internally (because those calls are not thread safe).
157+
"""
158+
n = 200
159+
n_rhs = 75
160+
A = sp.diags([-1, 2, -1], (-1, 0, 1), shape=(n, n), format='csr')
161+
Ainv = Solver(A)
162+
163+
x_true = np.random.rand(n, n_rhs)
164+
rhs = A @ x_true
165+
166+
with ThreadPoolExecutor() as pool:
167+
x_sol = np.stack(
168+
list(pool.map(lambda i: Ainv.solve(rhs[:, i]), range(n_rhs))),
169+
axis=1
170+
)
171+
172+
np.testing.assert_allclose(x_true, x_sol)

0 commit comments

Comments
 (0)