Skip to content

Commit 0c4b241

Browse files
committed
API: Make numpy.h compatible with both NumPy 1.x and 2.x
1 parent 8b48ff8 commit 0c4b241

File tree

1 file changed

+82
-16
lines changed

1 file changed

+82
-16
lines changed

include/pybind11/numpy.h

Lines changed: 82 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ template <typename type, typename SFINAE = void>
5454
struct npy_format_descriptor;
5555

5656
struct PyArrayDescr_Proxy {
57+
PyObject_HEAD
58+
PyObject *typeobj;
59+
char kind;
60+
char type;
61+
char byteorder;
62+
char _former_flags;
63+
int type_num;
64+
/* Additional fields are NumPy version specific. */
65+
};
66+
67+
/* NumPy 1 proxy (always includes legacy fields) */
68+
struct PyArrayDescr1_Proxy {
5769
PyObject_HEAD
5870
PyObject *typeobj;
5971
char kind;
@@ -68,6 +80,28 @@ struct PyArrayDescr_Proxy {
6880
PyObject *names;
6981
};
7082

83+
/* NumPy 2 proxy, including legacy fields */
84+
struct PyArrayDescr2_Proxy {
85+
PyObject_HEAD
86+
PyObject *typeobj;
87+
char kind;
88+
char type;
89+
char byteorder;
90+
char _former_flags;
91+
int type_num;
92+
std::uint64_t flags;
93+
ssize_t elsize;
94+
ssize_t alignment;
95+
PyObject *metadata;
96+
Py_hash_t hash;
97+
void *reserved_null;
98+
/* The following fields only exist if 0 < type_num < 2000 */
99+
struct _arr_descr *subarray;
100+
PyObject *fields;
101+
PyObject *names;
102+
};
103+
104+
71105
struct PyArray_Proxy {
72106
PyObject_HEAD
73107
char *data;
@@ -203,6 +237,8 @@ struct npy_api {
203237
NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_),
204238
};
205239

240+
int PyArray_RUNTIME_VERSION_;
241+
206242
struct PyArray_Dims {
207243
Py_intptr_t *ptr;
208244
int len;
@@ -241,14 +277,6 @@ struct npy_api {
241277
PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *);
242278
int (*PyArray_DescrConverter_)(PyObject *, PyObject **);
243279
bool (*PyArray_EquivTypes_)(PyObject *, PyObject *);
244-
int (*PyArray_GetArrayParamsFromObject_)(PyObject *,
245-
PyObject *,
246-
unsigned char,
247-
PyObject **,
248-
int *,
249-
Py_intptr_t *,
250-
PyObject **,
251-
PyObject *);
252280
PyObject *(*PyArray_Squeeze_)(PyObject *);
253281
// Unused. Not removed because that affects ABI of the class.
254282
int (*PyArray_SetBaseObject_)(PyObject *, PyObject *);
@@ -275,7 +303,6 @@ struct npy_api {
275303
API_PyArray_View = 137,
276304
API_PyArray_DescrConverter = 174,
277305
API_PyArray_EquivTypes = 182,
278-
API_PyArray_GetArrayParamsFromObject = 278,
279306
API_PyArray_SetBaseObject = 282
280307
};
281308

@@ -290,7 +317,8 @@ struct npy_api {
290317
npy_api api;
291318
#define DECL_NPY_API(Func) api.Func##_ = (decltype(api.Func##_)) api_ptr[API_##Func];
292319
DECL_NPY_API(PyArray_GetNDArrayCFeatureVersion);
293-
if (api.PyArray_GetNDArrayCFeatureVersion_() < 0x7) {
320+
api.PyArray_RUNTIME_VERSION_ = api.PyArray_GetNDArrayCFeatureVersion_();
321+
if (api.PyArray_RUNTIME_VERSION_ < 0x7) {
294322
pybind11_fail("pybind11 numpy support requires numpy >= 1.7.0");
295323
}
296324
DECL_NPY_API(PyArray_Type);
@@ -309,7 +337,6 @@ struct npy_api {
309337
DECL_NPY_API(PyArray_View);
310338
DECL_NPY_API(PyArray_DescrConverter);
311339
DECL_NPY_API(PyArray_EquivTypes);
312-
DECL_NPY_API(PyArray_GetArrayParamsFromObject);
313340
DECL_NPY_API(PyArray_SetBaseObject);
314341

315342
#undef DECL_NPY_API
@@ -331,6 +358,14 @@ inline const PyArrayDescr_Proxy *array_descriptor_proxy(const PyObject *ptr) {
331358
return reinterpret_cast<const PyArrayDescr_Proxy *>(ptr);
332359
}
333360

361+
inline const PyArrayDescr1_Proxy *array_descriptor1_proxy(const PyObject *ptr) {
362+
return reinterpret_cast<const PyArrayDescr1_Proxy *>(ptr);
363+
}
364+
365+
inline const PyArrayDescr2_Proxy *array_descriptor2_proxy(const PyObject *ptr) {
366+
return reinterpret_cast<const PyArrayDescr2_Proxy *>(ptr);
367+
}
368+
334369
inline bool check_flags(const void *ptr, int flag) {
335370
return (flag == (array_proxy(ptr)->flags & flag));
336371
}
@@ -610,10 +645,27 @@ class dtype : public object {
610645
}
611646

612647
/// Size of the data type in bytes.
613-
ssize_t itemsize() const { return detail::array_descriptor_proxy(m_ptr)->elsize; }
648+
ssize_t itemsize() const {
649+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
650+
return detail::array_descriptor1_proxy(m_ptr)->elsize;
651+
}
652+
else {
653+
return detail::array_descriptor2_proxy(m_ptr)->elsize;
654+
}
655+
}
614656

615657
/// Returns true for structured data types.
616-
bool has_fields() const { return detail::array_descriptor_proxy(m_ptr)->names != nullptr; }
658+
bool has_fields() const {
659+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
660+
return detail::array_descriptor1_proxy(m_ptr)->names != nullptr;
661+
}
662+
else if (num() < 0 || num() > 2000) {
663+
return false;
664+
}
665+
else {
666+
return detail::array_descriptor2_proxy(m_ptr)->names != nullptr;
667+
}
668+
}
617669

618670
/// Single-character code for dtype's kind.
619671
/// For example, floating point types are 'f' and integral types are 'i'.
@@ -640,10 +692,24 @@ class dtype : public object {
640692
char byteorder() const { return detail::array_descriptor_proxy(m_ptr)->byteorder; }
641693

642694
/// Alignment of the data type
643-
int alignment() const { return detail::array_descriptor_proxy(m_ptr)->alignment; }
695+
ssize_t alignment() const {
696+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
697+
return detail::array_descriptor1_proxy(m_ptr)->alignment;
698+
}
699+
else {
700+
return detail::array_descriptor2_proxy(m_ptr)->alignment;
701+
}
702+
}
644703

645704
/// Flags for the array descriptor
646-
char flags() const { return detail::array_descriptor_proxy(m_ptr)->flags; }
705+
std::uint64_t flags() const {
706+
if (detail::npy_api::get().PyArray_RUNTIME_VERSION_ < 0x12) {
707+
return (unsigned char)detail::array_descriptor1_proxy(m_ptr)->flags;
708+
}
709+
else {
710+
return detail::array_descriptor2_proxy(m_ptr)->flags;
711+
}
712+
}
647713

648714
private:
649715
static object &_dtype_from_pep3118() {
@@ -811,7 +877,7 @@ class array : public buffer {
811877

812878
/// Byte size of a single element
813879
ssize_t itemsize() const {
814-
return detail::array_descriptor_proxy(detail::array_proxy(m_ptr)->descr)->elsize;
880+
return dtype().itemsize();
815881
}
816882

817883
/// Total number of bytes

0 commit comments

Comments
 (0)