-
Notifications
You must be signed in to change notification settings - Fork 1.6k
add alternative termination criteria #591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,6 +12,8 @@ | |||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||
from typing import TYPE_CHECKING, Any | ||||||||||||||||||||||||||||
from warnings import warn | ||||||||||||||||||||||||||||
from datetime import timedelta, datetime, timezone | ||||||||||||||||||||||||||||
from itertools import accumulate | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||
from scipy.optimize import NonlinearConstraint | ||||||||||||||||||||||||||||
|
@@ -92,6 +94,7 @@ def __init__( | |||||||||||||||||||||||||||
verbose: int = 2, | ||||||||||||||||||||||||||||
bounds_transformer: DomainTransformer | None = None, | ||||||||||||||||||||||||||||
allow_duplicate_points: bool = False, | ||||||||||||||||||||||||||||
termination_criteria: Mapping[str, float | Mapping[str, float]] | None = None, | ||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||
self._random_state = ensure_rng(random_state) | ||||||||||||||||||||||||||||
self._allow_duplicate_points = allow_duplicate_points | ||||||||||||||||||||||||||||
|
@@ -139,6 +142,18 @@ def __init__( | |||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
self._sorting_warning_already_shown = False # TODO: remove in future version | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
self._termination_criteria = termination_criteria if termination_criteria is not None else {} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
self._initial_iterations = 0 | ||||||||||||||||||||||||||||
self._optimizing_iterations = 0 | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
self._start_time: datetime | None = None | ||||||||||||||||||||||||||||
self._timedelta: timedelta | None = None | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Directly instantiate timedelta if provided | ||||||||||||||||||||||||||||
if termination_criteria and "time" in termination_criteria: | ||||||||||||||||||||||||||||
self._timedelta = timedelta(**termination_criteria["time"]) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Initialize logger | ||||||||||||||||||||||||||||
self.logger = ScreenLogger(verbose=self._verbose, is_constrained=self.is_constrained) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
@@ -295,7 +310,7 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None: | |||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
n_iter: int, optional(default=25) | ||||||||||||||||||||||||||||
Number of iterations where the method attempts to find the maximum | ||||||||||||||||||||||||||||
value. | ||||||||||||||||||||||||||||
value. Used when other termination criteria are not provided. | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
Warning | ||||||||||||||||||||||||||||
------- | ||||||||||||||||||||||||||||
|
@@ -309,19 +324,27 @@ def maximize(self, init_points: int = 5, n_iter: int = 25) -> None: | |||||||||||||||||||||||||||
# Log optimization start | ||||||||||||||||||||||||||||
self.logger.log_optimization_start(self._space.keys) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if self._start_time is None and "time" in self._termination_criteria: | ||||||||||||||||||||||||||||
self._start_time = datetime.now(timezone.utc) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Set iterations as termination criteria if others not supplied, increment existing if it already exists. | ||||||||||||||||||||||||||||
self._termination_criteria["iterations"] = max( | ||||||||||||||||||||||||||||
self._termination_criteria.get("iterations", 0) + n_iter + init_points, 1 | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
Comment on lines
+331
to
+333
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion | 🟠 Major Clarify iteration accumulation behavior in docstring. The automatic accumulation of iterations when Add to the method docstring: """
...
Note
----
When termination_criteria includes 'iterations', calling maximize multiple
times will accumulate the total iterations. Each call adds init_points + n_iter
to the existing iteration count. This allows for incremental optimization while
respecting the overall iteration budget.
""" 🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Prime the queue with random points | ||||||||||||||||||||||||||||
self._prime_queue(init_points) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
iteration = 0 | ||||||||||||||||||||||||||||
while self._queue or iteration < n_iter: | ||||||||||||||||||||||||||||
while self._queue or not self.termination_criteria_met(): | ||||||||||||||||||||||||||||
try: | ||||||||||||||||||||||||||||
x_probe = self._queue.popleft() | ||||||||||||||||||||||||||||
self._initial_iterations += 1 | ||||||||||||||||||||||||||||
except IndexError: | ||||||||||||||||||||||||||||
x_probe = self.suggest() | ||||||||||||||||||||||||||||
iteration += 1 | ||||||||||||||||||||||||||||
self._optimizing_iterations += 1 | ||||||||||||||||||||||||||||
self.probe(x_probe, lazy=False) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if self._bounds_transformer and iteration > 0: | ||||||||||||||||||||||||||||
if self._bounds_transformer and not self._queue: | ||||||||||||||||||||||||||||
# The bounds transformer should only modify the bounds after | ||||||||||||||||||||||||||||
# the init_points points (only for the true iterations) | ||||||||||||||||||||||||||||
self.set_bounds(self._bounds_transformer.transform(self._space)) | ||||||||||||||||||||||||||||
|
@@ -345,6 +368,51 @@ def set_gp_params(self, **params: Any) -> None: | |||||||||||||||||||||||||||
params["kernel"] = wrap_kernel(kernel=params["kernel"], transform=self._space.kernel_transform) | ||||||||||||||||||||||||||||
self._gp.set_params(**params) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def termination_criteria_met(self) -> bool: | ||||||||||||||||||||||||||||
"""Determine if the termination criteria have been met.""" | ||||||||||||||||||||||||||||
if "iterations" in self._termination_criteria: | ||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||
self._optimizing_iterations + self._initial_iterations | ||||||||||||||||||||||||||||
>= self._termination_criteria["iterations"] | ||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if "value" in self._termination_criteria: | ||||||||||||||||||||||||||||
if self.max is not None and self.max["target"] >= self._termination_criteria["value"]: | ||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if "time" in self._termination_criteria: | ||||||||||||||||||||||||||||
time_taken = datetime.now(timezone.utc) - self._start_time | ||||||||||||||||||||||||||||
if time_taken >= self._timedelta: | ||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||
Comment on lines
+384
to
+387
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle None start_time edge case. If if "time" in self._termination_criteria:
+ if self._start_time is None:
+ # Initialize start time if not already set
+ self._start_time = datetime.now(timezone.utc)
+ if self._timedelta is None and "time" in self._termination_criteria:
+ self._timedelta = timedelta(**self._termination_criteria["time"])
time_taken = datetime.now(timezone.utc) - self._start_time
if time_taken >= self._timedelta:
return True 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
if "convergence_tol" in self._termination_criteria and len(self._space.target) > 2: | ||||||||||||||||||||||||||||
# Find the maximum value of the target function at each iteration | ||||||||||||||||||||||||||||
running_max = list(accumulate(self._space.target, max)) | ||||||||||||||||||||||||||||
# Determine improvements that have occurred each iteration | ||||||||||||||||||||||||||||
improvements = np.diff(running_max) | ||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||
self._initial_iterations + self._optimizing_iterations | ||||||||||||||||||||||||||||
>= self._termination_criteria["convergence_tol"]["n_iters"] | ||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||
# Check if there are improvements in the specified number of iterations | ||||||||||||||||||||||||||||
relevant_improvements = ( | ||||||||||||||||||||||||||||
improvements | ||||||||||||||||||||||||||||
if len(self._space.target) == self._termination_criteria["convergence_tol"]["n_iters"] | ||||||||||||||||||||||||||||
else improvements[-self._termination_criteria["convergence_tol"]["n_iters"] :] | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
# There has been no improvement within the iterations specified | ||||||||||||||||||||||||||||
if len(set(relevant_improvements)) == 1: | ||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||
# The improvement(s) are lower than specified | ||||||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||||||
max(relevant_improvements) - min(relevant_improvements) | ||||||||||||||||||||||||||||
< self._termination_criteria["convergence_tol"]["abs_tol"] | ||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||
return True | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
return False | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def save_state(self, path: str | PathLike[str]) -> None: | ||||||||||||||||||||||||||||
"""Save complete state for reconstruction of the optimizer. | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
@@ -385,6 +453,13 @@ def save_state(self, path: str | PathLike[str]) -> None: | |||||||||||||||||||||||||||
"verbose": self._verbose, | ||||||||||||||||||||||||||||
"random_state": random_state, | ||||||||||||||||||||||||||||
"acquisition_params": acquisition_params, | ||||||||||||||||||||||||||||
"termination_criteria": self._termination_criteria, | ||||||||||||||||||||||||||||
"initial_iterations": self._initial_iterations, | ||||||||||||||||||||||||||||
"optimizing_iterations": self._optimizing_iterations, | ||||||||||||||||||||||||||||
"start_time": datetime.strftime(self._start_time, "%Y-%m-%dT%H:%M:%SZ") | ||||||||||||||||||||||||||||
if self._start_time | ||||||||||||||||||||||||||||
else "", | ||||||||||||||||||||||||||||
"timedelta": self._timedelta.total_seconds() if self._timedelta else "", | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
with Path(path).open("w") as f: | ||||||||||||||||||||||||||||
|
@@ -443,3 +518,14 @@ def load_state(self, path: str | PathLike[str]) -> None: | |||||||||||||||||||||||||||
state["random_state"]["cached_gaussian"], | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
self._random_state.set_state(random_state_tuple) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
self._termination_criteria = state["termination_criteria"] | ||||||||||||||||||||||||||||
self._initial_iterations = state["initial_iterations"] | ||||||||||||||||||||||||||||
self._optimizing_iterations = state["optimizing_iterations"] | ||||||||||||||||||||||||||||
# Previously saved as UTC, so explicitly parse as UTC time. | ||||||||||||||||||||||||||||
self._start_time = ( | ||||||||||||||||||||||||||||
datetime.strptime(state["start_time"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc) | ||||||||||||||||||||||||||||
if state["start_time"] != "" | ||||||||||||||||||||||||||||
else None | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
self._timedelta = timedelta(seconds=state["timedelta"]) if state["timedelta"] else None |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,8 +1,10 @@ | ||||||||||||||||||||||||||||
from __future__ import annotations | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
from datetime import datetime | ||||||||||||||||||||||||||||
import pickle | ||||||||||||||||||||||||||||
from pathlib import Path | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
from _pytest.tmpdir import tmp_path | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove unused import causing static analysis warning. The -from _pytest.tmpdir import tmp_path
import numpy as np 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||
import pytest | ||||||||||||||||||||||||||||
from scipy.optimize import NonlinearConstraint | ||||||||||||||||||||||||||||
|
@@ -585,3 +587,71 @@ def area_of_triangle(sides): | |||||||||||||||||||||||||||
suggestion1 = optimizer.suggest() | ||||||||||||||||||||||||||||
suggestion2 = new_optimizer.suggest() | ||||||||||||||||||||||||||||
np.testing.assert_array_almost_equal(suggestion1["sides"], suggestion2["sides"], decimal=7) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def test_termination_criteria(tmp_path): | ||||||||||||||||||||||||||||
"""Test each termination criteria individually.""" | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
def target_func_trivial(): | ||||||||||||||||||||||||||||
# Max at 0, 1 | ||||||||||||||||||||||||||||
return lambda x, y: -(x**2) - ((y - 1) ** 2) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
termination_criteria = {"iterations": 10} | ||||||||||||||||||||||||||||
pbounds = {"x": [-10.0, 10.0], "y": [-10.0, 10.0]} | ||||||||||||||||||||||||||||
opt = BayesianOptimization( | ||||||||||||||||||||||||||||
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Ensure no initial points are specified. | ||||||||||||||||||||||||||||
opt.maximize(init_points=0, n_iter=10) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
assert len(opt.res) == termination_criteria["iterations"] | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Provide reasonable target value for objective fn | ||||||||||||||||||||||||||||
termination_criteria = {"value": -0.05} | ||||||||||||||||||||||||||||
opt = BayesianOptimization( | ||||||||||||||||||||||||||||
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Call with large number of iterations, so that this is not the termination criteria | ||||||||||||||||||||||||||||
opt.maximize(init_points=5, n_iter=1_000) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
assert opt.max["target"] > termination_criteria["value"] | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# 3 seconds of maximizing before termination | ||||||||||||||||||||||||||||
termination_criteria = {"time": {"seconds": 3}} | ||||||||||||||||||||||||||||
opt = BayesianOptimization( | ||||||||||||||||||||||||||||
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
start = datetime.now() | ||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use timezone-aware datetime for consistency. The test uses + from datetime import timezone
start = datetime.now() Should be: + from datetime import timezone
- start = datetime.now()
+ start = datetime.now(timezone.utc) 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||
# Call with large number of iterations, so that this is not the termination criteria | ||||||||||||||||||||||||||||
opt.maximize(n_iter=1_000, init_points=1) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Allow ~200ms tolerance on timing | ||||||||||||||||||||||||||||
assert abs((datetime.now() - start).total_seconds() - termination_criteria["time"]["seconds"]) < 0.2 | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Terminate if no improvement in last 3 iterations | ||||||||||||||||||||||||||||
termination_criteria = {"convergence_tol": {"n_iters": 3, "abs_tol": 0}} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
opt = BayesianOptimization( | ||||||||||||||||||||||||||||
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
# Call with number of iterations which will not lead to termination criteria on iterations | ||||||||||||||||||||||||||||
opt.maximize(n_iter=1_000, init_points=5) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Check that none of the last 3 values are the maximum | ||||||||||||||||||||||||||||
no_improvement_in_3 = all([value < opt._space.max()["target"] for value in opt._space.target[-3:]]) | ||||||||||||||||||||||||||||
assert no_improvement_in_3 | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
# Converged if minimum improvement below 1 in last 10 iterations | ||||||||||||||||||||||||||||
termination_criteria = {"convergence_tol": {"n_iters": 10, "abs_tol": 1}} | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
opt = BayesianOptimization( | ||||||||||||||||||||||||||||
f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria | ||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||
opt.maximize(n_iter=1_000, init_points=5) | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
improvement_below_tol = np.max(opt._space.target[-10:] - opt._space.max()["target"]) < 1 | ||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||
assert improvement_below_tol | ||||||||||||||||||||||||||||
Comment on lines
+655
to
+657
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix incorrect improvement calculation. The improvement calculation appears to be inverted - it's calculating how much worse the recent values are compared to the maximum, not the improvement. - improvement_below_tol = np.max(opt._space.target[-10:] - opt._space.max()["target"]) < 1
+ # Calculate the actual improvement in the last 10 iterations
+ max_target = opt.max["target"]
+ last_10_values = opt._space.target[-10:]
+ # The improvement is the difference between max in last 10 and the previous max
+ previous_max = max(opt._space.target[:-10]) if len(opt._space.target) > 10 else 0
+ current_max = max(last_10_values)
+ improvement = current_max - previous_max
+ improvement_below_tol = improvement < 1 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion | 🟠 Major
Add validation for termination criteria.
The termination criteria should be validated to ensure proper types and reasonable values.
Add validation after line 145:
🤖 Prompt for AI Agents