diff --git a/src/aero_particle.F90 b/src/aero_particle.F90 index e87c1041..ea00442b 100644 --- a/src/aero_particle.F90 +++ b/src/aero_particle.F90 @@ -570,4 +570,75 @@ subroutine f_aero_particle_refract_core( & end subroutine + subroutine f_aero_particle_set_weight_class( & + aero_particle_ptr_c, & + weight_class & + ) bind(C) + + type(aero_particle_t), pointer :: aero_particle_ptr_f => null() + type(c_ptr), intent(in) :: aero_particle_ptr_c + integer(c_int), intent(in) :: weight_class + + call c_f_pointer(aero_particle_ptr_c, aero_particle_ptr_f) + + call aero_particle_set_weight(aero_particle_ptr_f, i_class=weight_class) + + end subroutine + + subroutine f_aero_particle_get_weight_class( & + aero_particle_ptr_c, & + weight_class & + ) bind(C) + + type(aero_particle_t), pointer :: aero_particle_ptr_f => null() + type(c_ptr), intent(in) :: aero_particle_ptr_c + integer(c_int), intent(out) :: weight_class + + call c_f_pointer(aero_particle_ptr_c, aero_particle_ptr_f) + + weight_class = aero_particle_ptr_f%weight_class + + end subroutine + + subroutine f_aero_particle_set_weight_group( & + aero_particle_ptr_c, & + weight_group & + ) bind(C) + + type(aero_particle_t), pointer :: aero_particle_ptr_f => null() + type(c_ptr), intent(in) :: aero_particle_ptr_c + integer(c_int), intent(in) :: weight_group + + call c_f_pointer(aero_particle_ptr_c, aero_particle_ptr_f) + + call aero_particle_set_weight(aero_particle_ptr_f, i_group=weight_group) + + end subroutine + + subroutine f_aero_particle_get_weight_group( & + aero_particle_ptr_c, & + weight_group & + ) bind(C) + + type(aero_particle_t), pointer :: aero_particle_ptr_f => null() + type(c_ptr), intent(in) :: aero_particle_ptr_c + integer(c_int), intent(out) :: weight_group + + call c_f_pointer(aero_particle_ptr_c, aero_particle_ptr_f) + + weight_group = aero_particle_ptr_f%weight_group + + end subroutine + + subroutine f_aero_particle_new_id(aero_particle_ptr_c) bind(C) + + type(aero_particle_t), pointer :: aero_particle_ptr_f => null() + type(c_ptr), intent(in) :: aero_particle_ptr_c + + call c_f_pointer(aero_particle_ptr_c, aero_particle_ptr_f) + + call aero_particle_new_id(aero_particle_ptr_f) + + end subroutine + end module diff --git a/src/aero_particle.hpp b/src/aero_particle.hpp index 7593dd58..3f7d176f 100644 --- a/src/aero_particle.hpp +++ b/src/aero_particle.hpp @@ -45,6 +45,11 @@ extern "C" void f_aero_particle_id(const void *aero_particle_ptr, int64_t *val) extern "C" void f_aero_particle_frozen(const void *aero_particle_ptr, int *val) noexcept; extern "C" void f_aero_particle_refract_shell(const void *aero_particle_ptr, std::complex *val, const int *arr_size) noexcept; extern "C" void f_aero_particle_refract_core(const void *aero_particle_ptr, std::complex *val, const int *arr_size) noexcept; +extern "C" void f_aero_particle_set_weight_class(void *ptr, const int *weight_class) noexcept; +extern "C" void f_aero_particle_get_weight_class(const void *ptr, int *weight_class) noexcept; +extern "C" void f_aero_particle_set_weight_group(void *ptr, const int *weight_group) noexcept; +extern "C" void f_aero_particle_get_weight_group(const void *ptr, int *weight_group) noexcept; +extern "C" void f_aero_particle_new_id(void *ptr) noexcept; struct AeroParticle { PMCResource ptr; @@ -408,4 +413,43 @@ struct AeroParticle { ); return refract_core; } + + static void set_weight_class(AeroParticle &self, const int weight_class) { + f_aero_particle_set_weight_class( + self.ptr.f_arg_non_const(), + &weight_class + ); + } + + static auto get_weight_class(const AeroParticle &self) { + int weight_class; + + f_aero_particle_get_weight_class( + self.ptr.f_arg(), + &weight_class + ); + return weight_class; + } + + static void set_weight_group(AeroParticle &self, const int weight_group) { + f_aero_particle_set_weight_group( + self.ptr.f_arg_non_const(), + &weight_group + ); + } + + static auto get_weight_group(const AeroParticle &self) { + int weight_group; + + f_aero_particle_get_weight_group( + self.ptr.f_arg(), + &weight_group + ); + return weight_group; + } + + static void new_id(AeroParticle &self) { + + f_aero_particle_new_id(self.ptr.f_arg_non_const()); + } }; diff --git a/src/pypartmc.cpp b/src/pypartmc.cpp index c758a2a0..34f2d6dc 100644 --- a/src/pypartmc.cpp +++ b/src/pypartmc.cpp @@ -283,6 +283,11 @@ NB_MODULE(_PyPartMC, m) { "Reset an aero_particle to be zero.") .def("set_vols", AeroParticle::set_vols, "Set the aerosol particle volumes.") + .def_prop_rw("weight_group", AeroParticle::get_weight_group, AeroParticle::set_weight_group, + "Weighting function group number.") + .def_prop_rw("weight_class", AeroParticle::get_weight_class, AeroParticle::set_weight_class, + "Weighting function class number.") + .def("new_id", AeroParticle::new_id, "Assigns a new unique particle ID") ; nb::class_(m, "AeroState", diff --git a/tests/test_aero_particle.py b/tests/test_aero_particle.py index 02ea55b3..2a10799f 100644 --- a/tests/test_aero_particle.py +++ b/tests/test_aero_particle.py @@ -415,6 +415,26 @@ def test_set_vols(): # assert assert sut.volumes == [3, 2, 1] + @staticmethod + def test_set_weighting(): + # arrange + aero_data = ppmc.AeroData(aero_data_arg) + volumes = [1, 2, 3] + sut = ppmc.AeroParticle(aero_data, volumes) + aero_data = None + volumes = None + gc.collect() + + # act + weight_class = 1 + weight_group = 2 + sut.weight_class = weight_class + sut.weight_group = weight_group + + # assert + assert sut.weight_class == weight_class + assert sut.weight_group == weight_group + @staticmethod def test_absorb_cross_sect(): # arrange @@ -490,6 +510,22 @@ def test_sources(): assert len(sources) == aero_dist.n_mode assert isinstance(sources[0], int) + @staticmethod + def test_get_weighting(): + # arrange + aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_MINIMAL) + aero_dist = ppmc.AeroDist(aero_data, AERO_DIST_CTOR_ARG_MINIMAL) + aero_state = ppmc.AeroState(aero_data, *AERO_STATE_CTOR_ARG_MINIMAL) + _ = aero_state.dist_sample(aero_dist, 1.0, 0.0) + sut = aero_state.particle(0) + # act + weight_class = sut.weight_class + weight_group = sut.weight_group + + # assert + assert weight_class > 0 + assert weight_group > 0 + @staticmethod def test_least_create_time(): # arrange @@ -544,6 +580,19 @@ def test_id(): assert min(ids) > 0 assert len(np.unique(ids)) == len(aero_state) + @staticmethod + def test_new_id(): + # arrange + aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_MINIMAL) + sut = ppmc.AeroParticle(aero_data, [123]) + + # act + id_orig = sut.id + sut.new_id() + + # assert + assert sut.id != id_orig + @staticmethod def test_is_frozen(): # arrange