Skip to content

Commit c4fce7c

Browse files
committed
Refactor profile_plot xlabel handling. Add clarifiying comment for matplotlib margin properties.
1 parent 22c6acf commit c4fce7c

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

src/optimagic/visualization/profile_plot.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,35 +12,19 @@
1212
from optimagic.visualization.backends import line_plot
1313
from optimagic.visualization.plotting_utilities import LineData, get_palette_cycle
1414

15-
PROFILE_PLOT_XLABELS: dict[str, dict[bool, str]] = {
16-
"walltime": {
17-
True: "Multiple of Minimal Wall Time{linebreak}Needed to Solve the Problem",
18-
False: "Wall Time Needed to Solve the Problem",
19-
},
20-
"n_evaluations": {
21-
True: "Multiple of Minimal Number of Function Evaluations{linebreak}"
22-
"Needed to Solve the Problem",
23-
False: "Number of Function Evaluations",
24-
},
25-
"n_batches": {
26-
True: "Multiple of Minimal Number of Batches{linebreak}"
27-
"Needed to Solve the Problem",
28-
False: "Number of Batches",
29-
},
30-
}
31-
3215
BACKEND_TO_PROFILE_PLOT_LEGEND_PROPERTIES: dict[str, dict[str, Any]] = {
16+
"plotly": {"title": {"text": "algorithm"}},
3317
"matplotlib": {
3418
"bbox_to_anchor": (1.02, 1),
3519
"loc": "upper left",
3620
"fontsize": "x-small",
3721
"title": "algorithm",
3822
},
39-
"plotly": {"title": {"text": "algorithm"}},
4023
}
4124

4225
BACKEND_TO_PROFILE_PLOT_MARGIN_PROPERTIES: dict[str, dict[str, Any]] = {
4326
"plotly": {"l": 10, "r": 10, "t": 30, "b": 30},
27+
# "matplotlib": handles margins automatically via tight_layout()
4428
}
4529

4630

@@ -151,7 +135,7 @@ def profile_plot(
151135
fig = line_plot(
152136
lines,
153137
backend=backend,
154-
xlabel=PROFILE_PLOT_XLABELS[runtime_measure][normalize_runtime],
138+
xlabel=_get_profile_plot_xlabel(runtime_measure, normalize_runtime),
155139
ylabel="Share of Problems Solved",
156140
template=template,
157141
height=300,
@@ -284,3 +268,23 @@ def _find_switch_points(solution_times: pd.DataFrame) -> NDArray[np.float64]:
284268
switch_points += 1e-10
285269
switch_points = switch_points[np.isfinite(switch_points)]
286270
return switch_points
271+
272+
273+
def _get_profile_plot_xlabel(runtime_measure: str, normalize_runtime: bool) -> str:
274+
if normalize_runtime:
275+
runtime_measure_to_xlabel = {
276+
"walltime": "Multiple of Minimal Wall Time"
277+
"{linebreak}Needed to Solve the Problem",
278+
"n_evaluations": "Multiple of Minimal Number of Function Evaluations"
279+
"{linebreak}Needed to Solve the Problem",
280+
"n_batches": "Multiple of Minimal Number of Batches"
281+
"{linebreak}Needed to Solve the Problem",
282+
}
283+
else:
284+
runtime_measure_to_xlabel = {
285+
"walltime": "Wall Time Needed to Solve the Problem",
286+
"n_evaluations": "Number of Function Evaluations",
287+
"n_batches": "Number of Batches",
288+
}
289+
290+
return runtime_measure_to_xlabel[runtime_measure]

0 commit comments

Comments
 (0)