Skip to content

Commit 41b21b7

Browse files
authored
accessor: Overload operators to allow in-place assignments on m_base (#1119)
This enables users to write the following for in-place ops: ```cpp nb::object obj = ...; // increment a (mutable) counter value inline. obj.attr("count") += nb::int(1); ``` This is useful in bindings and wrappers because it avoids the duplication of `obj.attr()` calls. Templating the operator declaration allows their use in all accessor implementations. A new `test_accessor` CMake target was defined along with a Python test file, covering the most important accessor types.
1 parent 99668cd commit 41b21b7

File tree

4 files changed

+115
-0
lines changed

4 files changed

+115
-0
lines changed

include/nanobind/nb_accessor.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@
1010
NAMESPACE_BEGIN(NB_NAMESPACE)
1111
NAMESPACE_BEGIN(detail)
1212

13+
#define NB_DECL_ACCESSOR_OP_I(name) \
14+
template <typename T> accessor& name(const api<T> &o);
15+
16+
#define NB_IMPL_ACCESSOR_OP_I(name, op) \
17+
template <typename Impl> template <typename T> \
18+
accessor<Impl>& accessor<Impl>::name(const api<T> &o) { \
19+
PyObject *res = obj_op_2(ptr(), o.derived().ptr(), op); \
20+
Impl::set(m_base, m_key, res); \
21+
return *this; \
22+
}
23+
1324
template <typename Impl> class accessor : public api<accessor<Impl>> {
1425
template <typename T> friend void nanobind::del(accessor<T> &);
1526
template <typename T> friend void nanobind::del(accessor<T> &&);
@@ -37,6 +48,17 @@ template <typename Impl> class accessor : public api<accessor<Impl>> {
3748
NB_INLINE handle base() const { return m_base; }
3849
NB_INLINE object key() const { return steal(Impl::key(m_key)); }
3950

51+
NB_DECL_ACCESSOR_OP_I(operator+=)
52+
NB_DECL_ACCESSOR_OP_I(operator-=)
53+
NB_DECL_ACCESSOR_OP_I(operator*=)
54+
NB_DECL_ACCESSOR_OP_I(operator/=)
55+
NB_DECL_ACCESSOR_OP_I(operator%=)
56+
NB_DECL_ACCESSOR_OP_I(operator|=)
57+
NB_DECL_ACCESSOR_OP_I(operator&=)
58+
NB_DECL_ACCESSOR_OP_I(operator^=)
59+
NB_DECL_ACCESSOR_OP_I(operator<<=)
60+
NB_DECL_ACCESSOR_OP_I(operator>>=)
61+
4062
private:
4163
NB_INLINE void del () { Impl::del(m_base, m_key); }
4264

@@ -205,6 +227,17 @@ accessor<num_item> api<D>::operator[](T index) const {
205227
return { derived(), (Py_ssize_t) index };
206228
}
207229

230+
NB_IMPL_ACCESSOR_OP_I(operator+=, PyNumber_InPlaceAdd)
231+
NB_IMPL_ACCESSOR_OP_I(operator%=, PyNumber_InPlaceRemainder)
232+
NB_IMPL_ACCESSOR_OP_I(operator-=, PyNumber_InPlaceSubtract)
233+
NB_IMPL_ACCESSOR_OP_I(operator*=, PyNumber_InPlaceMultiply)
234+
NB_IMPL_ACCESSOR_OP_I(operator/=, PyNumber_InPlaceTrueDivide)
235+
NB_IMPL_ACCESSOR_OP_I(operator|=, PyNumber_InPlaceOr)
236+
NB_IMPL_ACCESSOR_OP_I(operator&=, PyNumber_InPlaceAnd)
237+
NB_IMPL_ACCESSOR_OP_I(operator^=, PyNumber_InPlaceXor)
238+
NB_IMPL_ACCESSOR_OP_I(operator<<=,PyNumber_InPlaceLshift)
239+
NB_IMPL_ACCESSOR_OP_I(operator>>=,PyNumber_InPlaceRshift)
240+
208241
NAMESPACE_END(detail)
209242

210243
template <typename T, detail::enable_if_t<std::is_arithmetic_v<T>>>

tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ if (NB_TEST_SANITIZERS)
6161
endif()
6262

6363
set(TEST_NAMES
64+
accessor
6465
functions
6566
callbacks
6667
classes
@@ -149,6 +150,7 @@ target_link_libraries(test_inter_module_2_ext PRIVATE inter_module)
149150
set(TEST_FILES
150151
common.py
151152
conftest.py
153+
test_accessor.py
152154
test_callbacks.py
153155
test_classes.py
154156
test_eigen.py

tests/test_accessor.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include <nanobind/nanobind.h>
2+
3+
namespace nb = nanobind;
4+
5+
struct A { int value; };
6+
7+
NB_MODULE(test_accessor_ext, m) {
8+
nb::class_<A>(m, "A")
9+
.def(nb::init<>())
10+
.def_rw("value", &A::value);
11+
12+
m.def("test_str_attr_accessor_inplace_mutation", []() {
13+
nb::object a_ = nb::module_::import_("test_accessor_ext").attr("A")();
14+
a_.attr("value") += nb::int_(1);
15+
return a_;
16+
});
17+
18+
m.def("test_str_item_accessor_inplace_mutation", []() {
19+
nb::dict d;
20+
d["a"] = nb::int_(0);
21+
d["a"] += nb::int_(1);
22+
return d;
23+
});
24+
25+
m.def("test_num_item_list_accessor_inplace_mutation", []() {
26+
nb::list l;
27+
l.append(nb::int_(0));
28+
l[0] += nb::int_(1);
29+
return l;
30+
});
31+
32+
m.def("test_obj_item_accessor_inplace_mutation", []() {
33+
nb::dict d;
34+
nb::int_ key = nb::int_(0);
35+
d[key] = nb::int_(0);
36+
d[key] += nb::int_(1);
37+
return d;
38+
});
39+
}

tests/test_accessor.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import test_accessor_ext as t
2+
3+
4+
def test_01_str_attr_inplace_mutation():
5+
"""
6+
Tests that a C++ expression like obj.attr("foo") += ...
7+
can actually modify the object in-place.
8+
"""
9+
a = t.test_str_attr_accessor_inplace_mutation()
10+
assert a.value == 1
11+
12+
13+
def test_02_str_item_inplace_mutation():
14+
"""
15+
Similar to test 01, but tests obj["foo"] (keyed attribute access)
16+
on the C++ side.
17+
"""
18+
d = t.test_str_item_accessor_inplace_mutation()
19+
assert d.keys() == {"a"}
20+
assert d["a"] == 1
21+
22+
23+
def test_03_num_item_list_inplace_mutation():
24+
"""
25+
Similar to test 01, but tests l[n] (index access)
26+
on the C++ side, where l is an ``nb::list``.
27+
"""
28+
l = t.test_num_item_list_accessor_inplace_mutation()
29+
assert len(l) == 1
30+
assert l[0] == 1
31+
32+
33+
def test_04_obj_item_inplace_mutation():
34+
"""
35+
Similar to test 01, but tests obj[h] (handle access)
36+
on the C++ side.
37+
"""
38+
d = t.test_obj_item_accessor_inplace_mutation()
39+
assert len(d) == 1
40+
assert d.keys() == {0}
41+
assert d[0] == 1 # dict lookup

0 commit comments

Comments
 (0)