Skip to content

Commit 9f9b906

Browse files
authored
Add implementation of dpnp.scipy.special.erfinv and dpnp.scipy.special.erfcinv (#2624)
The PR adds the implementation f `dpnp.scipy.special.erfinv` and `dpnp.scipy.special.erfcinv` functions, including tests coverage.
1 parent 286e6e5 commit 9f9b906

File tree

13 files changed

+491
-42
lines changed

13 files changed

+491
-42
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ The release drops support for Python 3.9, making Python 3.10 the minimum require
1515
* Added implementation of `dpnp.linalg.lu_solve` for batch inputs (SciPy-compatible) [#2619](https://github.com/IntelPython/dpnp/pull/2619)
1616
* Added `dpnp.exceptions` submodule to aggregate the generic exceptions used by dpnp [#2616](https://github.com/IntelPython/dpnp/pull/2616)
1717
* Added implementation of `dpnp.scipy.special.erfcx` [#2596](https://github.com/IntelPython/dpnp/pull/2596)
18+
* Added implementation of `dpnp.scipy.special.erfinv` and `dpnp.scipy.special.erfcinv` [#2624](https://github.com/IntelPython/dpnp/pull/2624)
1819

1920
### Changed
2021

dpnp/backend/extensions/ufunc/elementwise_functions/erf_funcs.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,8 @@ static void populate(py::module_ m,
213213
MACRO_DEFINE_IMPL(erf, Erf);
214214
MACRO_DEFINE_IMPL(erfc, Erfc);
215215
MACRO_DEFINE_IMPL(erfcx, Erfcx);
216+
MACRO_DEFINE_IMPL(erfinv, Erfinv);
217+
MACRO_DEFINE_IMPL(erfcinv, Erfcinv);
216218
} // namespace impl
217219

218220
void init_erf_funcs(py::module_ m)
@@ -236,5 +238,13 @@ void init_erf_funcs(py::module_ m)
236238
impl::populate<impl::ErfcxContigFactory, impl::ErfcxStridedFactory>(
237239
m, "_erfcx", "", impl::erfcx_contig_dispatch_vector,
238240
impl::erfcx_strided_dispatch_vector);
241+
242+
impl::populate<impl::ErfinvContigFactory, impl::ErfinvStridedFactory>(
243+
m, "_erfinv", "", impl::erfinv_contig_dispatch_vector,
244+
impl::erfinv_strided_dispatch_vector);
245+
246+
impl::populate<impl::ErfcinvContigFactory, impl::ErfcinvStridedFactory>(
247+
m, "_erfcinv", "", impl::erfcinv_contig_dispatch_vector,
248+
impl::erfcinv_strided_dispatch_vector);
239249
}
240250
} // namespace dpnp::extensions::ufunc

dpnp/backend/extensions/vm/erf_funcs.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
134134
MACRO_DEFINE_IMPL(erf, Erf);
135135
MACRO_DEFINE_IMPL(erfc, Erfc);
136136
MACRO_DEFINE_IMPL(erfcx, Erfcx);
137+
MACRO_DEFINE_IMPL(erfinv, Erfinv);
138+
MACRO_DEFINE_IMPL(erfcinv, Erfcinv);
137139

138140
template <template <typename fnT, typename T> typename factoryT>
139141
static void populate(py::module_ m,
@@ -194,5 +196,17 @@ void init_erf_funcs(py::module_ m)
194196
"Call `erfcx` function from OneMKL VM library to compute the scaled "
195197
"complementary error function value of vector elements",
196198
impl::erfcx_contig_dispatch_vector);
199+
200+
impl::populate<impl::ErfinvContigFactory>(
201+
m, "_erfinv",
202+
"Call `erfinv` function from OneMKL VM library to compute the inverse "
203+
"of the error function value of vector elements",
204+
impl::erfinv_contig_dispatch_vector);
205+
206+
impl::populate<impl::ErfcinvContigFactory>(
207+
m, "_erfcinv",
208+
"Call `erfcinv` function from OneMKL VM library to compute the inverse "
209+
"of the complementary error function value of vector elements",
210+
impl::erfcinv_contig_dispatch_vector);
197211
}
198212
} // namespace dpnp::extensions::vm

dpnp/backend/kernels/elementwise_functions/erf.hpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
#include <sycl/ext/intel/math.hpp>
4444
#else
4545
#include "erfcx.hpp"
46+
#include "erfinv.hpp"
4647
#endif
4748

4849
namespace dpnp::kernels::erfs
@@ -85,13 +86,15 @@ struct BaseFunctor
8586
template <typename ArgT, typename ResT> \
8687
using __f_name__##Functor = BaseFunctor<__f_name__##Op, ArgT, ResT>;
8788

88-
MACRO_DEFINE_FUNCTOR(sycl::erf, Erf);
89-
MACRO_DEFINE_FUNCTOR(sycl::erfc, Erfc);
90-
MACRO_DEFINE_FUNCTOR(
9189
#if defined(__SYCL_EXT_INTEL_MATH_SUPPORT)
92-
sycl::ext::intel::math::erfcx,
90+
using namespace sycl::ext::intel::math;
9391
#else
94-
impl::erfcx,
92+
using namespace impl;
9593
#endif
96-
Erfcx);
94+
95+
MACRO_DEFINE_FUNCTOR(sycl::erf, Erf);
96+
MACRO_DEFINE_FUNCTOR(sycl::erfc, Erfc);
97+
MACRO_DEFINE_FUNCTOR(erfcx, Erfcx);
98+
MACRO_DEFINE_FUNCTOR(erfinv, Erfinv);
99+
MACRO_DEFINE_FUNCTOR(erfcinv, Erfcinv);
97100
} // namespace dpnp::kernels::erfs

dpnp/backend/kernels/elementwise_functions/erfcx.hpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1622,19 +1622,17 @@ For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x), with the
16221622
usual checks for overflow etcetera.
16231623
*/
16241624
template <typename Tp>
1625-
Tp erfcx(Tp x)
1625+
inline Tp erfcx(Tp x)
16261626
{
16271627
static_assert(std::is_floating_point_v<Tp>,
16281628
"erfcx requires a floating-point type");
16291629

16301630
if (x >= 0) {
1631-
if (x > 50) // continued-fraction expansion is faster
1632-
{
1631+
if (x > 50) { // continued-fraction expansion is faster
16331632
// 1/sqrt(pi)
16341633
constexpr Tp inv_sqrtpi = 0.564189583547756286948079451560772586L;
16351634

1636-
if (x > 5e7) // 1-term expansion, important to avoid overflow
1637-
{
1635+
if (x > 5e7) { // 1-term expansion, important to avoid overflow
16381636
return inv_sqrtpi / x;
16391637
}
16401638

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <limits>
32+
#include <sycl/sycl.hpp>
33+
34+
namespace dpnp::kernels::erfs::impl
35+
{
36+
template <typename Tp>
37+
inline Tp polevl(Tp x, const Tp *coeff, int i)
38+
{
39+
Tp p = *coeff++;
40+
41+
do {
42+
p = p * x + *coeff++;
43+
} while (--i);
44+
return p;
45+
}
46+
47+
template <typename Tp>
48+
inline Tp p1evl(Tp x, const Tp *coeff, int i)
49+
{
50+
Tp p = x + *coeff++;
51+
52+
while (--i) {
53+
p = p * x + *coeff++;
54+
}
55+
return p;
56+
}
57+
58+
template <typename Tp>
59+
inline Tp ndtri(Tp y0)
60+
{
61+
Tp y;
62+
int code = 1;
63+
64+
if (y0 == 0.0) {
65+
return -std::numeric_limits<Tp>::infinity();
66+
}
67+
else if (y0 == 1.0) {
68+
return std::numeric_limits<Tp>::infinity();
69+
}
70+
else if (y0 < 0.0 || y0 > 1.0) {
71+
return std::numeric_limits<Tp>::quiet_NaN();
72+
}
73+
74+
// exp(-2)
75+
constexpr Tp exp_minus2 = 0.13533528323661269189399949497248L;
76+
if (y0 > (1.0 - exp_minus2)) {
77+
y = 1.0 - y0;
78+
code = 0;
79+
}
80+
else {
81+
y = y0;
82+
}
83+
84+
if (y > exp_minus2) {
85+
// sqrt(2*pi)
86+
constexpr Tp root_2_pi = 2.50662827463100050241576528481105L;
87+
88+
// approximation for 0 <= |y - 0.5| <= 3/8
89+
constexpr Tp p[] = {
90+
-5.99633501014107895267E1, 9.80010754185999661536E1,
91+
-5.66762857469070293439E1, 1.39312609387279679503E1,
92+
-1.23916583867381258016E0,
93+
};
94+
constexpr Tp q[] = {
95+
1.95448858338141759834E0, 4.67627912898881538453E0,
96+
8.63602421390890590575E1, -2.25462687854119370527E2,
97+
2.00260212380060660359E2, -8.20372256168333339912E1,
98+
1.59056225126211695515E1, -1.18331621121330003142E0,
99+
};
100+
101+
y -= 0.5;
102+
Tp y2 = y * y;
103+
Tp x = y + y * (y2 * polevl(y2, p, 4) / p1evl(y2, q, 8));
104+
return x * root_2_pi;
105+
}
106+
107+
Tp x = sycl::sqrt(-2.0 * sycl::log(y));
108+
Tp x0 = x - sycl::log(x) / x;
109+
Tp z = 1.0 / x;
110+
111+
Tp x1;
112+
if (x < 8.0) {
113+
// approximation for 2 <= sqrt(-2*log(y)) < 8
114+
constexpr Tp p[] = {
115+
4.05544892305962419923E0, 3.15251094599893866154E1,
116+
5.71628192246421288162E1, 4.40805073893200834700E1,
117+
1.46849561928858024014E1, 2.18663306850790267539E0,
118+
-1.40256079171354495875E-1, -3.50424626827848203418E-2,
119+
-8.57456785154685413611E-4,
120+
};
121+
122+
constexpr Tp q[] = {
123+
1.57799883256466749731E1, 4.53907635128879210584E1,
124+
4.13172038254672030440E1, 1.50425385692907503408E1,
125+
2.50464946208309415979E0, -1.42182922854787788574E-1,
126+
-3.80806407691578277194E-2, -9.33259480895457427372E-4,
127+
};
128+
129+
x1 = z * polevl(z, p, 8) / p1evl(z, q, 8);
130+
}
131+
else {
132+
// approximation for 8 <= sqrt(-2*log(y)) < 64
133+
constexpr Tp p[] = {
134+
3.23774891776946035970E0, 6.91522889068984211695E0,
135+
3.93881025292474443415E0, 1.33303460815807542389E0,
136+
2.01485389549179081538E-1, 1.23716634817820021358E-2,
137+
3.01581553508235416007E-4, 2.65806974686737550832E-6,
138+
6.23974539184983293730E-9,
139+
};
140+
141+
constexpr Tp q[] = {
142+
6.02427039364742014255E0, 3.67983563856160859403E0,
143+
1.37702099489081330271E0, 2.16236993594496635890E-1,
144+
1.34204006088543189037E-2, 3.28014464682127739104E-4,
145+
2.89247864745380683936E-6, 6.79019408009981274425E-9,
146+
};
147+
148+
x1 = z * polevl(z, p, 8) / p1evl(z, q, 8);
149+
}
150+
151+
x = x0 - x1;
152+
if (code != 0) {
153+
x = -x;
154+
}
155+
return x;
156+
}
157+
158+
template <typename Tp>
159+
inline Tp erfinv(Tp y)
160+
{
161+
static_assert(std::is_floating_point_v<Tp>,
162+
"erfinv requires a floating-point type");
163+
164+
constexpr Tp lower = -1;
165+
constexpr Tp upper = 1;
166+
167+
constexpr Tp thresh = 1e-7;
168+
169+
// For small arguments, use the Taylor expansion.
170+
// Otherwise, y + 1 loses precision for |y| << 1.
171+
if ((-thresh < y) && (y < thresh)) {
172+
// 2/sqrt(pi)
173+
constexpr Tp inv_sqrtpi = 1.1283791670955125738961589031215452L;
174+
return y / inv_sqrtpi;
175+
}
176+
177+
if ((lower < y) && (y < upper)) {
178+
// 1/sqrt(2)
179+
constexpr Tp one_div_root_2 = 0.7071067811865475244008443621048490L;
180+
return ndtri(0.5 * (y + 1)) * one_div_root_2;
181+
}
182+
else if (y == lower) {
183+
return -std::numeric_limits<Tp>::infinity();
184+
}
185+
else if (y == upper) {
186+
return std::numeric_limits<Tp>::infinity();
187+
}
188+
else if (sycl::isnan(y)) {
189+
return y;
190+
}
191+
return std::numeric_limits<Tp>::quiet_NaN();
192+
}
193+
194+
template <typename Tp>
195+
inline Tp erfcinv(Tp y)
196+
{
197+
static_assert(std::is_floating_point_v<Tp>,
198+
"erfcinv requires a floating-point type");
199+
200+
constexpr Tp lower = 0;
201+
constexpr Tp upper = 2;
202+
203+
if ((lower < y) && (y < upper)) {
204+
// 1/sqrt(2)
205+
constexpr Tp one_div_root_2 = 0.7071067811865475244008443621048490L;
206+
return -ndtri(0.5 * y) * one_div_root_2;
207+
}
208+
else if (y == lower) {
209+
return std::numeric_limits<Tp>::infinity();
210+
}
211+
else if (y == upper) {
212+
return -std::numeric_limits<Tp>::infinity();
213+
}
214+
else if (sycl::isnan(y)) {
215+
return y;
216+
}
217+
return std::numeric_limits<Tp>::quiet_NaN();
218+
}
219+
} // namespace dpnp::kernels::erfs::impl

dpnp/scipy/special/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,15 @@
4545
from ._erf import (
4646
erf,
4747
erfc,
48+
erfcinv,
4849
erfcx,
50+
erfinv,
4951
)
5052

5153
__all__ = [
5254
"erf",
5355
"erfc",
56+
"erfcinv",
5457
"erfcx",
58+
"erfinv",
5559
]

0 commit comments

Comments
 (0)