Skip to content

Commit 4148e83

Browse files
committed
added a few missing methods to container wrappers
1 parent 369e79a commit 4148e83

File tree

6 files changed

+101
-2
lines changed

6 files changed

+101
-2
lines changed

docs/api_core.rst

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,14 @@ Wrapper classes
726726
negative). When `T` does not already represent a wrapped Python object,
727727
the function performs a cast.
728728

729+
.. cpp:function:: void clear()
730+
731+
Clear the list entries.
732+
733+
.. cpp:function:: void extend(handle h)
734+
735+
Analogous to the ``.extend(h)`` method of ``list`` in Python.
736+
729737
.. cpp:function:: template <typename T, enable_if_t<std::is_arithmetic_v<T>> = 1> detail::accessor<num_item_list> operator[](T key) const
730738

731739
Analogous to ``self[key]`` in Python, where ``key`` is an arithmetic
@@ -794,6 +802,10 @@ Wrapper classes
794802

795803
Clear the contents of the dictionary.
796804

805+
.. cpp:function:: void update(handle h)
806+
807+
Analogous to the ``.update(h)`` method of ``dict`` in Python.
808+
797809
.. cpp:class:: set: public object
798810

799811
Wrapper class representing Python ``set`` instances.
@@ -818,7 +830,14 @@ Wrapper classes
818830

819831
.. cpp:function:: void clear()
820832

821-
Clear the contents of the set
833+
Clear the contents of the set.
834+
835+
.. cpp:function:: template <typename T> bool discard(T&& key)
836+
837+
Analogous to the ``.discard(h)`` method of the ``set`` type in Python.
838+
Returns ``true`` if the item was deleted successfully, and ``false`` if
839+
the value was not present. When `T` does not already represent a wrapped
840+
Python object, the function performs a cast.
822841

823842
.. cpp:class:: module_: public object
824843

include/nanobind/nb_cast.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,14 +565,20 @@ template <typename T> bool set::contains(T&& key) const {
565565
return rv == 1;
566566
}
567567

568-
569568
template <typename T> void set::add(T&& key) {
570569
object o = nanobind::cast((detail::forward_t<T>) key);
571570
int rv = PySet_Add(m_ptr, o.ptr());
572571
if (rv == -1)
573572
raise_python_error();
574573
}
575574

575+
template <typename T> bool set::discard(T &&value) {
576+
object o = nanobind::cast((detail::forward_t<T>) value);
577+
int rv = PySet_Discard(m_ptr, o.ptr());
578+
if (rv < 0)
579+
raise_python_error();
580+
return rv == 1;
581+
}
576582

577583
template <typename T> bool mapping::contains(T&& key) const {
578584
object o = nanobind::cast((detail::forward_t<T>) key);

include/nanobind/nb_types.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,16 @@ class list : public object {
471471
template <typename T, detail::enable_if_t<std::is_arithmetic_v<T>> = 1>
472472
detail::accessor<detail::num_item_list> operator[](T key) const;
473473

474+
void clear() {
475+
if (PyList_SetSlice(m_ptr, 0, PY_SSIZE_T_MAX, nullptr))
476+
raise_python_error();
477+
}
478+
479+
void extend(handle h) {
480+
if (PyList_SetSlice(m_ptr, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, h.ptr()))
481+
raise_python_error();
482+
}
483+
474484
#if !defined(Py_LIMITED_API) && !defined(PYPY_VERSION)
475485
detail::fast_iterator begin() const;
476486
detail::fast_iterator end() const;
@@ -488,6 +498,10 @@ class dict : public object {
488498
list items() const { return steal<list>(detail::obj_op_1(m_ptr, PyDict_Items)); }
489499
template <typename T> bool contains(T&& key) const;
490500
void clear() { PyDict_Clear(m_ptr); }
501+
void update(handle h) {
502+
if (PyDict_Update(m_ptr, h.ptr()))
503+
raise_python_error();
504+
}
491505
};
492506

493507

@@ -501,6 +515,7 @@ class set : public object {
501515
if (PySet_Clear(m_ptr))
502516
raise_python_error();
503517
}
518+
template <typename T> bool discard(T &&value);
504519
};
505520

506521
class sequence : public object {

tests/test_functions.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,4 +314,48 @@ NB_MODULE(test_functions_ext, m) {
314314
"i"_a = 1, nb::kw_only(), "j"_a = 2);
315315

316316
m.def("test_any", [](nb::any a) { return a; } );
317+
318+
m.def("test_wrappers_list", []{
319+
nb::list l1, l2;
320+
l1.append(1);
321+
l2.append(2);
322+
l1.extend(l2);
323+
324+
bool b = nb::len(l1) == 2 && nb::len(l2) == 1 &&
325+
l1[0].equal(nb::int_(1)) && l1[1].equal(nb::int_(2));
326+
327+
l1.clear();
328+
return b && nb::len(l1) == 0;
329+
});
330+
331+
m.def("test_wrappers_dict", []{
332+
nb::dict d1, d2;
333+
d1["a"] = 1;
334+
d2["b"] = 2;
335+
d1.update(d2);
336+
337+
bool b = nb::len(d1) == 2 && nb::len(d2) == 1 &&
338+
d1["a"].equal(nb::int_(1)) &&
339+
d1["b"].equal(nb::int_(2));
340+
341+
d1.clear();
342+
return b && nb::len(d1) == 0;
343+
});
344+
345+
m.def("test_wrappers_set", []{
346+
nb::set s;
347+
s.add("a");
348+
s.add("b");
349+
350+
bool b = nb::len(s) == 2 && s.contains("a") && s.contains("b");
351+
352+
b &= s.discard("a");
353+
b &= !s.discard("q");
354+
355+
b &= !s.contains("a") && s.contains("b");
356+
s.clear();
357+
b &= s.size() == 0;
358+
359+
return b;
360+
});
317361
}

tests/test_functions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,3 +581,12 @@ def test41_any():
581581
s = "hello"
582582
assert t.test_any(s) is s
583583
assert t.test_any.__doc__ == "test_any(arg: typing.Any, /) -> typing.Any"
584+
585+
def test42_wrappers_list():
586+
assert t.test_wrappers_list()
587+
588+
def test43_wrappers_dict():
589+
assert t.test_wrappers_dict()
590+
591+
def test43_wrappers_set():
592+
assert t.test_wrappers_set()

tests/test_functions_ext.pyi.ref

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,9 @@ def test_tuple() -> tuple: ...
187187

188188
@overload
189189
def test_tuple(arg: tuple, /) -> int: ...
190+
191+
def test_wrappers_dict() -> bool: ...
192+
193+
def test_wrappers_list() -> bool: ...
194+
195+
def test_wrappers_set() -> bool: ...

0 commit comments

Comments
 (0)