Commit 8f602e1

ndarray: fast views into existing arrays
CPU loops involving nanobind ndarrays weren't getting properly vectorized. This commit adds *views*, which provide an efficient abstraction that enables better code generation.
1 parent a1ac207 commit 8f602e1

5 files changed (+302, -45 lines)

docs/api_extra.rst

Lines changed: 11 additions & 0 deletions
@@ -565,6 +565,17 @@ section <ndarrays>`.
    Return a mutable pointer to the array data. Only enabled when `Scalar` is
    not itself ``const``.
 
+.. cpp:function:: template <typename... Extra> auto view()
+
+   Returns an nd-array view that is optimized for fast array access on the
+   CPU. You may optionally specify additional ndarray constraints via the
+   `Extra` parameter (though a runtime check should first be performed to
+   ensure that the array possesses these properties).
+
+   The returned view provides the operations ``data()``, ``ndim()``,
+   ``shape()``, ``stride()``, and ``operator()`` following the conventions
+   of the `ndarray` type.
+
 .. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)
 
    Return a mutable reference to the element stored at the provided
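
For a sense of how the ``view()`` method documented above is meant to be used from binding code, a minimal self-contained sketch follows. It is illustrative only and not part of this commit; the module name ``my_ext`` and function name ``brighten`` are made up:

    #include <nanobind/nanobind.h>
    #include <nanobind/ndarray.h>

    namespace nb = nanobind;

    NB_MODULE(my_ext, m) {
        // Accept a 2D float32 array residing in CPU memory with C-style ordering
        m.def("brighten", [](nb::ndarray<float, nb::ndim<2>, nb::c_contig,
                             nb::device::cpu> arg) {
            // Materialize the compile-time constraints into a small view that
            // can be kept in registers; data(), ndim(), shape(), and stride()
            // are also available on the returned object.
            auto v = arg.view();

            for (size_t i = 0; i < v.shape(0); ++i)
                for (size_t j = 0; j < v.shape(1); ++j)
                    v(i, j) *= 2.f;
        });
    }

Because the scalar type and dimension are already part of the ``ndarray`` template parameters here, the call to ``view()`` needs no additional ``Extra`` arguments.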

docs/ndarray.rst

Lines changed: 122 additions & 31 deletions
@@ -83,13 +83,12 @@ should therefore prevent such undefined behavior.
 :cpp:class:`nb::ndarray\<...\> <ndarray>` accepts template arguments to
 specify such constraints. For example the function interface below
 guarantees that the implementation is only invoked when it is provided with
-a ``MxNx3`` array of 8-bit unsigned integers that is furthermore stored
-contiguously in CPU memory using a C-style array ordering convention.
+a ``MxNx3`` array of 8-bit unsigned integers.
 
 .. code-block:: cpp
 
    m.def("process", [](nb::ndarray<uint8_t, nb::shape<nb::any, nb::any, 3>,
-                       nb::c_contig, nb::device::cpu> data) {
+                       nb::device::cpu> data) {
        // Double brightness of the MxNx3 RGB image
        for (size_t y = 0; y < data.shape(0); ++y)
            for (size_t x = 0; x < data.shape(1); ++x)
@@ -100,15 +99,16 @@ contiguously in CPU memory using a C-style array ordering convention.
 
 The above example also demonstrates the use of
 :cpp:func:`nb::ndarray\<...\>::operator() <ndarray::operator()>`, which
-provides direct (i.e., high-performance) read/write access to the array
-data. Note that this function is only available when the underlying data
-type and ndarray rank are specified. It should only be used when the
-array storage is reachable via CPU's virtual memory address space.
+provides direct read/write access to the array contents. Note that this
+function is only available when the underlying data type and ndarray dimension
+are specified via the :cpp:type:`ndarray\<..\> <ndarray>` template parameters.
+It should only be used when the array storage is accessible through the CPU's
+virtual memory address space.
 
 .. _ndarray-constraints-1:
 
 Constraint types
-~~~~~~~~~~~~~~~~
+----------------
 
 The following constraints are available
 
@@ -153,6 +153,94 @@ count until they go out of scope. It is legal call
 when the `GIL <https://wiki.python.org/moin/GlobalInterpreterLock>`__ is not
 held.
 
+.. _ndarray-views:
+
+Fast array views
+----------------
+
+The following advice applies to performance-sensitive CPU code that reads and
+writes arrays using loops that invoke :cpp:func:`nb::ndarray\<...\>::operator()
+<ndarray::operator()>`. It does not apply to GPU arrays because they are
+usually not accessed in this way.
+
+Consider the following snippet, which fills a 2D array with data:
+
+.. code-block:: cpp
+
+   void fill(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
+       for (size_t i = 0; i < arg.shape(0); ++i)
+           for (size_t j = 0; j < arg.shape(1); ++j)
+               arg(i, j) = /* ... */;
+   }
+
+While functional, this code is not perfect. The problem is that to compute the
+address of an entry, ``operator()`` accesses the DLPack array descriptor. This
+indirection can break certain compiler optimizations.
+
+nanobind provides the method :cpp:func:`ndarray\<...\>::view() <ndarray::view>`
+to fix this. It creates a tiny data structure that provides all information
+needed to access the array contents, and which can be held within CPU
+registers. All relevant compile-time information (:cpp:class:`nb::ndim <ndim>`,
+:cpp:class:`nb::shape <shape>`, :cpp:class:`nb::c_contig <c_contig>`,
+:cpp:class:`nb::f_contig <f_contig>`) is materialized in this view, which
+enables constant propagation, auto-vectorization, and loop unrolling.
+
+An improved version of the example using such a view is shown below:
+
+.. code-block:: cpp
+
+   void fill(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg) {
+       auto v = arg.view(); // <-- new!
+
+       for (size_t i = 0; i < v.shape(0); ++i)     // Important: use 'v' instead of 'arg' everywhere in the loop
+           for (size_t j = 0; j < v.shape(1); ++j)
+               v(i, j) = /* ... */;
+   }
+
+Note that the view performs no reference counting. You may not store it in a
+way that exceeds the lifetime of the original array.
+
+When using OpenMP to parallelize expensive array operations, pass the
+``firstprivate(view_1, view_2, ...)`` clause so that each worker thread can
+copy the view into its register file.
+
+.. code-block:: cpp
+
+   auto v = array.view();
+   #pragma omp parallel for schedule(static) firstprivate(v)
+   for (...) { /* parallel loop */ }
+
+.. _ndarray-runtime-specialization:
+
+Specializing views at runtime
+-----------------------------
+
+As mentioned earlier, element access via ``operator()`` only works when both
+the array's scalar type and its dimension are specified within the type (i.e.,
+when they are known at compile time); the same is also true for array views.
+However, it is sometimes useful for a function to be callable with different
+array types.
+
+You may use the :cpp:func:`ndarray\<...\>::view() <ndarray::view>` method to
+create *specialized* views if a run-time check determines that it is safe to
+do so. For example, the function below accepts contiguous CPU arrays and
+performs a loop over a specialized 2D ``float`` view when the array is of
+this type.
+
+.. code-block:: cpp
+
+   void fill(nb::ndarray<nb::c_contig, nb::device::cpu> arg) {
+       if (arg.dtype() == nb::dtype<float>() && arg.ndim() == 2) {
+           auto v = arg.view<float, nb::ndim<2>>(); // <-- new!
+
+           for (size_t i = 0; i < v.shape(0); ++i) {
+               for (size_t j = 0; j < v.shape(1); ++j) {
+                   v(i, j) = /* ... */;
+               }
+           }
+       } else { /* ... */ }
+   }
+
 Constraints in type signatures
 ------------------------------
 
@@ -364,6 +452,30 @@ interpreted as follows:
 - :cpp:enumerator:`rv_policy::move` is unsupported and demoted to
   :cpp:enumerator:`rv_policy::copy`.
 
+.. _ndarray_nonstandard_arithmetic:
+
+Nonstandard arithmetic types
+----------------------------
+
+Low or extended-precision arithmetic types (e.g., ``int128``, ``float16``,
+``bfloat``) are sometimes used but don't have standardized C++ equivalents. If
+you wish to exchange arrays based on such types, you must register a partial
+overload of ``nanobind::ndarray_traits`` to inform nanobind about it.
+
+For example, the following snippet makes ``__fp16`` (half-precision type on
+``aarch64``) available:
+
+.. code-block:: cpp
+
+   namespace nanobind {
+       template <> struct ndarray_traits<__fp16> {
+           static constexpr bool is_float = true;
+           static constexpr bool is_bool = false;
+           static constexpr bool is_int = false;
+           static constexpr bool is_signed = true;
+       };
+   };
+
 Limitations
 -----------
 
@@ -383,30 +495,9 @@ internal representations (*dtypes*), including
 nanobind's :cpp:class:`nb::ndarray\<...\> <ndarray>` is based on the `DLPack
 <https://github.com/dmlc/dlpack>`__ array exchange protocol, which causes it to
 be more restrictive. Presently supported dtypes include signed/unsigned
-integers, floating point values, and boolean values.
+integers, floating point values, and boolean values. Some :ref:`nonstandard
+arithmetic types <ndarray_nonstandard_arithmetic>` can be supported as well.
 
 Nanobind can receive and return read-only arrays via the buffer protocol used
 to exchange data with NumPy. The DLPack interface currently ignores this
 annotation.
-
-Supporting nonstandard arithmetic types
----------------------------------------
-
-Low or extended-precision arithmetic types (e.g., ``int128``, ``float16``,
-``bfloat``) are sometimes used but don't have standardized C++ equivalents. If
-you wish to exchange arrays based on such types, you must register a partial
-overload of ``nanobind::ndarray_traits`` to inform nanobind about it.
-
-For example, the following snippet makes ``__fp16`` (half-precision type on
-``aarch64``) available:
-
-.. code-block:: cpp
-
-   namespace nanobind {
-       template <> struct ndarray_traits<__fp16> {
-           static constexpr bool is_float = true;
-           static constexpr bool is_bool = false;
-           static constexpr bool is_int = false;
-           static constexpr bool is_signed = true;
-       };
-   };
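
The OpenMP guidance in the new *Fast array views* section above is only sketched in the documentation. A fuller illustration, assuming the extension is compiled with OpenMP enabled (e.g. ``-fopenmp``) and using a hypothetical function name ``scale``, might look like this:

    #include <nanobind/nanobind.h>
    #include <nanobind/ndarray.h>
    #include <cstdint>

    namespace nb = nanobind;

    // Scale a 2D float array in place, in parallel. The view 'v' is listed in
    // 'firstprivate' so that every OpenMP worker thread receives its own copy
    // and can keep it in registers, as recommended above.
    void scale(nb::ndarray<float, nb::ndim<2>, nb::c_contig, nb::device::cpu> arg,
               float factor) {
        auto v = arg.view();
        const int64_t rows = (int64_t) v.shape(0);

        #pragma omp parallel for schedule(static) firstprivate(v)
        for (int64_t i = 0; i < rows; ++i)
            for (size_t j = 0; j < v.shape(1); ++j)
                v(i, j) *= factor;
    }

    // Expose as m.def("scale", &scale) inside an NB_MODULE(..) block.

Listing the view in ``firstprivate(v)`` gives each worker thread its own copy of the small view structure, which is exactly the pattern the documentation recommends.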

include/nanobind/ndarray.h

Lines changed: 116 additions & 14 deletions
@@ -247,6 +247,7 @@ template <typename... Ts> struct ndarray_info {
     using shape_type = void;
     constexpr static auto name = const_name("ndarray");
     constexpr static ndarray_framework framework = ndarray_framework::none;
+    constexpr static char order = '\0';
 };
 
 template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
@@ -259,6 +260,14 @@ template <size_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...>
     using shape_type = shape<Is...>;
 };
 
+template <typename... Ts> struct ndarray_info<c_contig, Ts...> : ndarray_info<Ts...> {
+    constexpr static char order = 'C';
+};
+
+template <typename... Ts> struct ndarray_info<f_contig, Ts...> : ndarray_info<Ts...> {
+    constexpr static char order = 'F';
+};
+
 template <typename... Ts> struct ndarray_info<numpy, Ts...> : ndarray_info<Ts...> {
     constexpr static auto name = const_name("numpy.ndarray");
     constexpr static ndarray_framework framework = ndarray_framework::numpy;
@@ -282,6 +291,64 @@ template <typename... Ts> struct ndarray_info<jax, Ts...> : ndarray_info<Ts...>
 
 NAMESPACE_END(detail)
 
+template <typename Scalar, typename Shape, char Order> struct ndarray_view {
+    static constexpr size_t Dim = Shape::size;
+
+    ndarray_view() = default;
+    ndarray_view(const ndarray_view &) = default;
+    ndarray_view(ndarray_view &&) = default;
+    ndarray_view &operator=(const ndarray_view &) = default;
+    ndarray_view &operator=(ndarray_view &&) noexcept = default;
+    ~ndarray_view() noexcept = default;
+
+    template <typename... Ts> NB_INLINE Scalar &operator()(Ts... indices) const {
+        static_assert(
+            sizeof...(Ts) == Dim,
+            "ndarray_view::operator(): invalid number of arguments");
+
+        const int64_t indices_i64[] { (int64_t) indices... };
+        int64_t offset = 0;
+        for (size_t i = 0; i < Dim; ++i)
+            offset += indices_i64[i] * m_strides[i];
+
+        return *(m_data + offset);
+    }
+
+    size_t ndim() const { return Dim; }
+    size_t shape(size_t i) const { return m_shape[i]; }
+    int64_t stride(size_t i) const { return m_strides[i]; }
+    Scalar *data() const { return m_data; }
+
+private:
+    template <typename...> friend class ndarray;
+
+    template <size_t... I1, size_t... I2>
+    ndarray_view(Scalar *data, const int64_t *shape, const int64_t *strides,
+                 std::index_sequence<I1...>, nanobind::shape<I2...>)
+        : m_data(data) {
+
+        /* Initialize shape/strides with compile-time knowledge if
+           available (to permit vectorization, loop unrolling, etc.) */
+        ((m_shape[I1] = (I2 == any) ? shape[I1] : I2), ...);
+        ((m_strides[I1] = strides[I1]), ...);
+
+        if constexpr (Order == 'F') {
+            m_strides[0] = 1;
+            for (size_t i = 1; i < Dim; ++i)
+                m_strides[i] = m_strides[i - 1] * m_shape[i - 1];
+        } else if constexpr (Order == 'C') {
+            m_strides[Dim - 1] = 1;
+            for (Py_ssize_t i = (Py_ssize_t) Dim - 2; i >= 0; --i)
+                m_strides[i] = m_strides[i + 1] * m_shape[i + 1];
+        }
+    }
+
+    Scalar *m_data = nullptr;
+    int64_t m_shape[Dim] { };
+    int64_t m_strides[Dim] { };
+};
+
+
 template <typename... Args> class ndarray {
 public:
     template <typename...> friend class ndarray;
@@ -405,24 +472,59 @@ template <typename... Args> class ndarray {
                                      byte_offset(indices...));
     }
 
+    template <typename... Extra> NB_INLINE auto view() {
+        using Info2 = typename ndarray<Args..., Extra...>::Info;
+        using Scalar2 = typename Info2::scalar_type;
+        using Shape2 = typename Info2::shape_type;
+
+        constexpr bool has_scalar = !std::is_same_v<Scalar2, void>,
+                       has_shape  = !std::is_same_v<Shape2, void>;
+
+        static_assert(has_scalar,
+            "To use the ndarray::view<..>() method, you must add a scalar type "
+            "annotation (e.g. 'float') to the template parameters of the parent "
+            "ndarray, or to the call to .view<..>()");
+
+        static_assert(has_shape,
+            "To use the ndarray::view<..>() method, you must add a shape<..> "
+            "or ndim<..> annotation to the template parameters of the parent "
+            "ndarray, or to the call to .view<..>()");
+
+        if constexpr (has_scalar && has_shape) {
+            return ndarray_view<Scalar2, Shape2, Info2::order>(
+                (Scalar2 *) data(), shape_ptr(), stride_ptr(),
+                std::make_index_sequence<Shape2::size>(), Shape2());
+        } else {
+            return nullptr;
+        }
+    }
+
 private:
     template <typename... Ts>
     NB_INLINE int64_t byte_offset(Ts... indices) const {
-        static_assert(
-            !std::is_same_v<Scalar, void>,
-            "To use nb::ndarray::operator(), you must add a scalar type "
+        constexpr bool has_scalar = !std::is_same_v<Scalar, void>,
+                       has_shape  = !std::is_same_v<typename Info::shape_type, void>;
+
+        static_assert(has_scalar,
+            "To use ndarray::operator(), you must add a scalar type "
             "annotation (e.g. 'float') to the ndarray template parameters.");
-        static_assert(
-            !std::is_same_v<Scalar, void>,
-            "To use nb::ndarray::operator(), you must add a nb::shape<> "
-            "annotation to the ndarray template parameters.");
-        static_assert(sizeof...(Ts) == Info::shape_type::size,
-                      "nb::ndarray::operator(): invalid number of arguments");
-        size_t counter = 0;
-        int64_t index = 0;
-        ((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);
-
-        return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
+
+        static_assert(has_shape,
+            "To use ndarray::operator(), you must add a shape<> or "
+            "ndim<> annotation to the ndarray template parameters.");
+
+        if constexpr (has_scalar && has_shape) {
+            static_assert(sizeof...(Ts) == Info::shape_type::size,
+                          "ndarray::operator(): invalid number of arguments");
+
+            size_t counter = 0;
+            int64_t index = 0;
+            ((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);
+
+            return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
+        } else {
+            return 0;
+        }
     }
 
     detail::ndarray_handle *m_handle = nullptr;