Skip to content

Commit 2a3e846

Browse files
phmbressanMateusStanoGui-FernandesBR
authored
ENH: Discretized and No-Pickle Encoding Options (#827)
* ENH: add an option to discretize callable sources encoding. * TST: implement testing for discretized encoding. * ENH: allow for disallowing pickle on encoding. * MNT: Update CHANGELOG. * MNT: change pickle encoding name to allow_pickle and test it. * MNT: Tweak the discretization bounds. --------- Co-authored-by: Mateus Stano Junqueira <69485049+MateusStano@users.noreply.github.com> Co-authored-by: Gui-FernandesBR <63590233+Gui-FernandesBR@users.noreply.github.com>
1 parent 75b8e5f commit 2a3e846

24 files changed

+388
-145
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ Attention: The newest changes should be on top -->
3232

3333
### Added
3434

35-
- ENH: Add the Coriolis Force to the Flight class [#799](https://github.com/RocketPy-Team/RocketPy/pull/799)
35+
- ENH: Discretized and No-Pickle Encoding Options [#827] (https://github.com/RocketPy-Team/RocketPy/pull/827)
36+
- ENH: Add the Coriolis Force to the Flight class [#799](https://github.com/RocketPy-Team/RocketPy/pull/799)
3637

3738
### Changed
3839

rocketpy/_encoders.py

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,33 +13,46 @@
1313

1414
class RocketPyEncoder(json.JSONEncoder):
1515
"""Custom JSON encoder for RocketPy objects. It defines how to encode
16-
different types of objects to a JSON supported format."""
16+
different types of objects to a JSON supported format.
17+
"""
1718

1819
def __init__(self, *args, **kwargs):
20+
"""Initializes the encoder with parameter options.
21+
22+
Parameters
23+
----------
24+
*args : tuple
25+
Positional arguments to pass to the parent class.
26+
**kwargs : dict
27+
Keyword arguments to configure the encoder. The following
28+
options are available:
29+
- include_outputs: bool, whether to include simulation outputs.
30+
Default is False.
31+
- include_function_data: bool, whether to include Function
32+
data in the encoding. If False, Functions will be encoded by their
33+
``__repr__``. This is useful for reducing the size of the outputs,
34+
but it prevents full restoration of the object upon decoding.
35+
Default is True.
36+
- discretize: bool, whether to discretize Functions whose source
37+
are callables. If True, the accuracy of the decoding may be reduced.
38+
Default is False.
39+
- allow_pickle: bool, whether to pickle callable objects. If
40+
False, callable sources (such as user-defined functions, parachute
41+
triggers or simulation callable outputs) will have their name
42+
stored instead of the function itself. This is useful for
43+
reducing the size of the outputs, but it prevents full restoration
44+
of the object upon decoding.
45+
Default is True.
46+
"""
1947
self.include_outputs = kwargs.pop("include_outputs", False)
2048
self.include_function_data = kwargs.pop("include_function_data", True)
49+
self.discretize = kwargs.pop("discretize", False)
50+
self.allow_pickle = kwargs.pop("allow_pickle", True)
2151
super().__init__(*args, **kwargs)
2252

2353
def default(self, o):
24-
if isinstance(
25-
o,
26-
(
27-
np.int_,
28-
np.intc,
29-
np.intp,
30-
np.int8,
31-
np.int16,
32-
np.int32,
33-
np.int64,
34-
np.uint8,
35-
np.uint16,
36-
np.uint32,
37-
np.uint64,
38-
),
39-
):
40-
return int(o)
41-
elif isinstance(o, (np.float16, np.float32, np.float64)):
42-
return float(o)
54+
if isinstance(o, np.generic):
55+
return o.item()
4356
elif isinstance(o, np.ndarray):
4457
return o.tolist()
4558
elif isinstance(o, datetime):
@@ -50,11 +63,19 @@ def default(self, o):
5063
if not self.include_function_data:
5164
return str(o)
5265
else:
53-
encoding = o.to_dict(self.include_outputs)
66+
encoding = o.to_dict(
67+
include_outputs=self.include_outputs,
68+
discretize=self.discretize,
69+
allow_pickle=self.allow_pickle,
70+
)
5471
encoding["signature"] = get_class_signature(o)
5572
return encoding
5673
elif hasattr(o, "to_dict"):
57-
encoding = o.to_dict(self.include_outputs)
74+
encoding = o.to_dict(
75+
include_outputs=self.include_outputs,
76+
discretize=self.discretize,
77+
allow_pickle=self.allow_pickle,
78+
)
5879
encoding = remove_circular_references(encoding)
5980

6081
encoding["signature"] = get_class_signature(o)

rocketpy/environment/environment.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2630,7 +2630,21 @@ def decimal_degrees_to_arc_seconds(angle):
26302630
arc_seconds = (remainder * 60 - arc_minutes) * 60
26312631
return degrees, arc_minutes, arc_seconds
26322632

2633-
def to_dict(self, include_outputs=False):
2633+
def to_dict(self, **kwargs):
2634+
wind_velocity_x = self.wind_velocity_x
2635+
wind_velocity_y = self.wind_velocity_y
2636+
wind_heading = self.wind_heading
2637+
wind_direction = self.wind_direction
2638+
wind_speed = self.wind_speed
2639+
density = self.density
2640+
if kwargs.get("discretize", False):
2641+
wind_velocity_x = wind_velocity_x.set_discrete(0, self.max_expected_height)
2642+
wind_velocity_y = wind_velocity_y.set_discrete(0, self.max_expected_height)
2643+
wind_heading = wind_heading.set_discrete(0, self.max_expected_height)
2644+
wind_direction = wind_direction.set_discrete(0, self.max_expected_height)
2645+
wind_speed = wind_speed.set_discrete(0, self.max_expected_height)
2646+
density = density.set_discrete(0, self.max_expected_height)
2647+
26342648
env_dict = {
26352649
"gravity": self.gravity,
26362650
"date": self.date,
@@ -2643,15 +2657,15 @@ def to_dict(self, include_outputs=False):
26432657
"atmospheric_model_type": self.atmospheric_model_type,
26442658
"pressure": self.pressure,
26452659
"temperature": self.temperature,
2646-
"wind_velocity_x": self.wind_velocity_x,
2647-
"wind_velocity_y": self.wind_velocity_y,
2648-
"wind_heading": self.wind_heading,
2649-
"wind_direction": self.wind_direction,
2650-
"wind_speed": self.wind_speed,
2660+
"wind_velocity_x": wind_velocity_x,
2661+
"wind_velocity_y": wind_velocity_y,
2662+
"wind_heading": wind_heading,
2663+
"wind_direction": wind_direction,
2664+
"wind_speed": wind_speed,
26512665
}
26522666

2653-
if include_outputs:
2654-
env_dict["density"] = self.density
2667+
if kwargs.get("include_outputs", False):
2668+
env_dict["density"] = density
26552669
env_dict["barometric_height"] = self.barometric_height
26562670
env_dict["speed_of_sound"] = self.speed_of_sound
26572671
env_dict["dynamic_viscosity"] = self.dynamic_viscosity

rocketpy/mathutils/function.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3565,7 +3565,7 @@ def __validate_extrapolation(self, extrapolation):
35653565
extrapolation = "natural"
35663566
return extrapolation
35673567

3568-
def to_dict(self, include_outputs=False): # pylint: disable=unused-argument
3568+
def to_dict(self, **kwargs): # pylint: disable=unused-argument
35693569
"""Serializes the Function instance to a dictionary.
35703570
35713571
Returns
@@ -3576,7 +3576,10 @@ def to_dict(self, include_outputs=False): # pylint: disable=unused-argument
35763576
source = self.source
35773577

35783578
if callable(source):
3579-
source = to_hex_encode(source)
3579+
if kwargs.get("allow_pickle", True):
3580+
source = to_hex_encode(source)
3581+
else:
3582+
source = source.__name__
35803583

35813584
return {
35823585
"source": source,

rocketpy/mathutils/vector_matrix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def zeros():
403403
"""Returns the zero vector."""
404404
return Vector([0, 0, 0])
405405

406-
def to_dict(self, include_outputs=False): # pylint: disable=unused-argument
406+
def to_dict(self, **kwargs): # pylint: disable=unused-argument
407407
"""Returns the vector as a JSON compatible element."""
408408
return list(self.components)
409409

@@ -1007,7 +1007,7 @@ def __repr__(self):
10071007
+ f" [{self.zx}, {self.zy}, {self.zz}])"
10081008
)
10091009

1010-
def to_dict(self, include_outputs=False): # pylint: disable=unused-argument
1010+
def to_dict(self, **kwargs): # pylint: disable=unused-argument
10111011
"""Returns the matrix as a JSON compatible element."""
10121012
return [list(row) for row in self.components]
10131013

rocketpy/motors/fluid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def __str__(self):
6161

6262
return f"Fluid: {self.name}"
6363

64-
def to_dict(self, include_outputs=False): # pylint: disable=unused-argument
64+
def to_dict(self, **kwargs): # pylint: disable=unused-argument
6565
return {"name": self.name, "density": self.density}
6666

6767
@classmethod

rocketpy/motors/hybrid_motor.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,8 @@ def draw(self, *, filename=None):
641641
"""
642642
self.plots.draw(filename=filename)
643643

644-
def to_dict(self, include_outputs=False):
645-
data = super().to_dict(include_outputs)
644+
def to_dict(self, **kwargs):
645+
data = super().to_dict(**kwargs)
646646
data.update(
647647
{
648648
"grain_number": self.grain_number,
@@ -660,13 +660,18 @@ def to_dict(self, include_outputs=False):
660660
}
661661
)
662662

663-
if include_outputs:
663+
if kwargs.get("include_outputs", False):
664+
burn_rate = self.solid.burn_rate
665+
if kwargs.get("discretize", False):
666+
burn_rate = burn_rate.set_discrete_based_on_model(
667+
self.thrust, mutate_self=False
668+
)
664669
data.update(
665670
{
666671
"grain_inner_radius": self.solid.grain_inner_radius,
667672
"grain_height": self.solid.grain_height,
668673
"burn_area": self.solid.burn_area,
669-
"burn_rate": self.solid.burn_rate,
674+
"burn_rate": burn_rate,
670675
}
671676
)
672677

rocketpy/motors/liquid_motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,8 +497,8 @@ def draw(self, *, filename=None):
497497
"""
498498
self.plots.draw(filename=filename)
499499

500-
def to_dict(self, include_outputs=False):
501-
data = super().to_dict(include_outputs)
500+
def to_dict(self, **kwargs):
501+
data = super().to_dict(**kwargs)
502502
data.update(
503503
{
504504
"positioned_tanks": [

0 commit comments

Comments
 (0)