add alternative termination criteria #591
base: master
Conversation
* Refactors the bayesian_optimization class to take an optional termination_criteria argument and uses it instead of the iteration number (usage sketch below)
* Keeps termination on number of iterations; adds objective function value, runtime, and objective function improvement termination criteria
* Tests each implementation
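A minimal usage sketch, assuming the criteria key names that appear in the review diffs below ("iterations", "value", "time", "convergence_tol") and that "time" entries are keyword arguments for datetime.timedelta; the toy black_box function and its bounds are invented for illustration, and the exact signature is this PR's work in progress rather than a settled API:

# Sketch only: termination_criteria key names are assumptions drawn from the
# review diffs below; black_box and pbounds are invented toy examples.
from bayes_opt import BayesianOptimization

def black_box(x, y):
    # Known maximum of 1.0 at (x, y) = (0, 1)
    return -x ** 2 - (y - 1) ** 2 + 1

optimizer = BayesianOptimization(
    f=black_box,
    pbounds={"x": (-2.0, 2.0), "y": (-3.0, 3.0)},
    termination_criteria={
        "value": 0.99,                                       # stop once max target >= 0.99
        "time": {"seconds": 30},                             # ...or after ~30 s (timedelta kwargs)
        "convergence_tol": {"n_iters": 5, "abs_tol": 1e-3},  # ...or when improvement stalls
    },
)
optimizer.maximize(init_points=2, n_iter=50)
print(optimizer.max)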
Walkthrough
Adds termination criteria to BayesianOptimization, updates the maximize loop to stop based on criteria, introduces a termination_criteria_met method, and extends state persistence (save/load) to include termination-related data. Tests are added to validate iterations, value threshold, time limit, and convergence tolerance behaviors.
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant BO as BayesianOptimization
participant AF as AcquisitionFunction
participant F as Target Function
User->>BO: maximize(init_points, n_iter, termination_criteria)
note over BO: If time criterion, record start_time and timedelta
loop Until termination_criteria_met()
alt Initial phase (queue/probes)
BO->>F: probe(initial point)
BO->>BO: _initial_iterations++
else Optimization phase
BO->>AF: suggest(next point)
AF-->>BO: candidate x
BO->>F: probe(x)
BO->>BO: _optimizing_iterations++
end
BO->>BO: update best value/history
end
BO-->>User: results (best params/target)
note over BO: save_state includes termination counters and timing
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 6
🧹 Nitpick comments (8)
tests/test_bayesian_optimization.py (3)
3-3: Remove unnecessary import.
The datetime import is unused in this test file - it's only used within the test_termination_criteria function, where it's imported locally.

-from datetime import datetime
 import pickle
632-632: Consider making the timing tolerance configurable or more robust.
The 200ms tolerance may be fragile across different systems and CI environments. Consider either:
- Making this configurable
- Using a percentage-based tolerance
- Increasing the tolerance for CI environments

-    assert abs((datetime.now() - start).total_seconds() - termination_criteria["time"]["seconds"]) < 0.2
+    # Allow 10% tolerance or 500ms, whichever is larger, for timing accuracy
+    expected_time = termination_criteria["time"]["seconds"]
+    actual_time = (datetime.now(timezone.utc) - start).total_seconds()
+    tolerance = max(expected_time * 0.1, 0.5)
+    assert abs(actual_time - expected_time) < tolerance, f"Expected ~{expected_time}s, got {actual_time}s"
644-645: Simplify convergence check logic.
The current check is more complex than needed and could be clearer.

-    no_improvement_in_3 = all([value < opt._space.max()["target"] for value in opt._space.target[-3:]])
-    assert no_improvement_in_3
+    # Verify that the maximum is not in the last 3 iterations (indicating no improvement)
+    max_target = opt.max["target"]
+    assert all(value < max_target for value in opt._space.target[-3:])

bayes_opt/bayesian_optimization.py (5)
389-413: Optimize convergence tolerance computation.
The current implementation recalculates the running maximum and improvements on every call. Consider caching these values.

+    # Cache running max to avoid recomputation
+    if not hasattr(self, '_cached_running_max') or len(self._cached_running_max) != len(self._space.target):
+        self._cached_running_max = list(accumulate(self._space.target, max))
+
     if "convergence_tol" in self._termination_criteria and len(self._space.target) > 2:
         # Find the maximum value of the target function at each iteration
-        running_max = list(accumulate(self._space.target, max))
+        running_max = self._cached_running_max
         # Determine improvements that have occurred each iteration
         improvements = np.diff(running_max)
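For readers unfamiliar with this idiom, a small self-contained illustration of how the running maximum and per-iteration improvements are derived (toy numbers, not taken from the PR):

from itertools import accumulate
import numpy as np

targets = [0.2, 0.5, 0.4, 0.9, 0.85]
running_max = list(accumulate(targets, max))  # [0.2, 0.5, 0.5, 0.9, 0.9]
improvements = np.diff(running_max)           # approximately [0.3, 0.0, 0.4, 0.0]
# Convergence is declared when the last n_iters improvements all fall below abs_tol.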
399-403: Simplify and clarify slice selection logic.
The current ternary expression for selecting relevant improvements is confusing and could be clearer.

-    relevant_improvements = (
-        improvements
-        if len(self._space.target) == self._termination_criteria["convergence_tol"]["n_iters"]
-        else improvements[-self._termination_criteria["convergence_tol"]["n_iters"] :]
-    )
+    # Get the last n_iters improvements, or all if we have exactly n_iters points
+    n_iters = self._termination_criteria["convergence_tol"]["n_iters"]
+    relevant_improvements = improvements[-n_iters:] if len(improvements) > n_iters else improvements
459-462: Use ISO format for datetime serialization.
Using strftime with a custom format is less portable than using ISO format. Also, storing empty strings for None values is not idiomatic JSON.

-    "start_time": datetime.strftime(self._start_time, "%Y-%m-%dT%H:%M:%SZ")
-    if self._start_time
-    else "",
-    "timedelta": self._timedelta.total_seconds() if self._timedelta else "",
+    "start_time": self._start_time.isoformat() if self._start_time else None,
+    "timedelta": self._timedelta.total_seconds() if self._timedelta else None,
526-531: Simplify datetime parsing with ISO format.
If using ISO format for serialization, the deserialization becomes simpler and more robust.

-    self._start_time = (
-        datetime.strptime(state["start_time"], "%Y-%m-%dT%H:%M:%SZ").replace(tzinfo=timezone.utc)
-        if state["start_time"] != ""
-        else None
-    )
-    self._timedelta = timedelta(seconds=state["timedelta"]) if state["timedelta"] else None
+    self._start_time = (
+        datetime.fromisoformat(state["start_time"]) if state.get("start_time") else None
+    )
+    self._timedelta = timedelta(seconds=state["timedelta"]) if state.get("timedelta") else None
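A quick round-trip demonstration of why the isoformat/fromisoformat pair is the more robust choice (this is standard-library behavior, independent of this PR):

from datetime import datetime, timezone

start = datetime.now(timezone.utc)
serialized = start.isoformat()                # e.g. '2025-01-01T12:34:56.789012+00:00'
restored = datetime.fromisoformat(serialized)
assert restored == start                      # tzinfo and microseconds survive the round trip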
371-415: Consider extracting termination checks to separate methods.
The termination_criteria_met method is getting complex. Consider breaking it down into smaller, focused methods for better maintainability.

def termination_criteria_met(self) -> bool:
    """Determine if the termination criteria have been met."""
    return (
        self._check_iteration_termination()
        or self._check_value_termination()
        or self._check_time_termination()
        or self._check_convergence_termination()
    )

def _check_iteration_termination(self) -> bool:
    """Check if iteration limit has been reached."""
    if "iterations" not in self._termination_criteria:
        return False
    total_iterations = self._optimizing_iterations + self._initial_iterations
    return total_iterations >= self._termination_criteria["iterations"]

def _check_value_termination(self) -> bool:
    """Check if target value threshold has been reached."""
    if "value" not in self._termination_criteria:
        return False
    return self.max is not None and self.max["target"] >= self._termination_criteria["value"]

def _check_time_termination(self) -> bool:
    """Check if time limit has been exceeded."""
    if "time" not in self._termination_criteria:
        return False
    # Initialize if needed
    if self._start_time is None:
        return False
    time_taken = datetime.now(timezone.utc) - self._start_time
    return time_taken >= self._timedelta

def _check_convergence_termination(self) -> bool:
    """Check if convergence criteria have been met."""
    if "convergence_tol" not in self._termination_criteria:
        return False
    if len(self._space.target) <= 2:
        return False
    # Implementation of convergence check...
    # (move existing logic here)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
bayes_opt/bayesian_optimization.py (8 hunks)
tests/test_bayesian_optimization.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
bayes_opt/bayesian_optimization.py (1)
  bayes_opt/target_space.py (3)
    max (624-652)
    probe (520-563)
    target (159-166)
tests/test_bayesian_optimization.py (2)
  bayes_opt/bayesian_optimization.py (4)
    BayesianOptimization (44-531)
    maximize (302-353)
    res (186-191)
    max (178-183)
  bayes_opt/target_space.py (3)
    res (654-685)
    max (624-652)
    target (159-166)
🪛 Ruff (0.13.1)
tests/test_bayesian_optimization.py
592-592: Redefinition of unused tmp_path from line 7 (F811)
592-592: Unused function argument: tmp_path (ARG001)
🔇 Additional comments (1)
bayes_opt/bayesian_optimization.py (1)
327-329: Document or enforce class-level thread-safety.
The BayesianOptimization class maintains multiple mutable fields, so locking only _start_time is insufficient. Either wrap the entire maximize method (and any other state-mutating calls) in a lock or explicitly document that concurrent usage is unsupported.

Likely an incorrect or invalid review comment.
    verbose: int = 2,
    bounds_transformer: DomainTransformer | None = None,
    allow_duplicate_points: bool = False,
    termination_criteria: Mapping[str, float | Mapping[str, float]] | None = None,
🛠️ Refactor suggestion | 🟠 Major
Add validation for termination criteria.
The termination criteria should be validated to ensure proper types and reasonable values.
Add validation after line 145:
# Validate termination criteria
if self._termination_criteria:
valid_keys = {"iterations", "value", "time", "convergence_tol"}
invalid_keys = set(self._termination_criteria.keys()) - valid_keys
if invalid_keys:
raise ValueError(f"Invalid termination criteria keys: {invalid_keys}")
if "iterations" in self._termination_criteria:
if not isinstance(self._termination_criteria["iterations"], (int, float)):
raise TypeError("iterations must be a number")
if self._termination_criteria["iterations"] <= 0:
raise ValueError("iterations must be positive")
if "convergence_tol" in self._termination_criteria:
conv_tol = self._termination_criteria["convergence_tol"]
if not isinstance(conv_tol, dict):
raise TypeError("convergence_tol must be a dictionary")
if "n_iters" not in conv_tol or "abs_tol" not in conv_tol:
raise ValueError("convergence_tol must contain 'n_iters' and 'abs_tol'")
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py around lines 97 to 145, add validation
immediately after line 145 to verify self._termination_criteria: ensure keys are
limited to {"iterations","value","time","convergence_tol"} and raise ValueError
on any invalid keys; if "iterations" is present ensure it's an int/float and > 0
(raise TypeError/ValueError as appropriate); if "convergence_tol" is present
ensure it's a dict containing both "n_iters" and "abs_tol" and raise
TypeError/ValueError when the type or required keys are missing; keep all error
messages clear and specific to the failing condition.
self._termination_criteria["iterations"] = max( | ||
self._termination_criteria.get("iterations", 0) + n_iter + init_points, 1 | ||
) |
🛠️ Refactor suggestion | 🟠 Major
Clarify iteration accumulation behavior in docstring.
The automatic accumulation of iterations when maximize is called multiple times should be documented more clearly, as this could be surprising behavior for users who expect each call to be independent.
Add to the method docstring:
"""
...
Note
----
When termination_criteria includes 'iterations', calling maximize multiple
times will accumulate the total iterations. Each call adds init_points + n_iter
to the existing iteration count. This allows for incremental optimization while
respecting the overall iteration budget.
"""
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py around lines 331 to 333, the code
accumulates iterations when maximize is called multiple times but the method
docstring doesn't explain this; update the method docstring to add a Note
explaining that when termination_criteria includes 'iterations', repeated calls
to maximize will accumulate iterations by adding init_points + n_iter to the
existing count so users understand the incremental optimization/overall
iteration budget behavior.
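To make the accumulation concrete, a hypothetical two-call sequence; the max(existing + n_iter + init_points, 1) arithmetic is read off the diff above and not verified against the merged code:

from bayes_opt import BayesianOptimization

opt = BayesianOptimization(
    f=lambda x: -x ** 2,                  # toy one-dimensional target
    pbounds={"x": (-2.0, 2.0)},
    termination_criteria={"iterations": 0},
)
opt.maximize(init_points=2, n_iter=3)  # budget becomes max(0 + 3 + 2, 1) = 5 evaluations
opt.maximize(init_points=0, n_iter=5)  # budget accumulates to max(5 + 5 + 0, 1) = 10 total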
if "time" in self._termination_criteria: | ||
time_taken = datetime.now(timezone.utc) - self._start_time | ||
if time_taken >= self._timedelta: | ||
return True |
Handle None start_time edge case.
If _start_time is None (when the time criterion is not set initially but added later), this will raise an exception.
if "time" in self._termination_criteria:
+ if self._start_time is None:
+ # Initialize start time if not already set
+ self._start_time = datetime.now(timezone.utc)
+ if self._timedelta is None and "time" in self._termination_criteria:
+ self._timedelta = timedelta(**self._termination_criteria["time"])
time_taken = datetime.now(timezone.utc) - self._start_time
if time_taken >= self._timedelta:
return True
🤖 Prompt for AI Agents
In bayes_opt/bayesian_optimization.py around lines 384-387, the code assumes
self._start_time is always set, which raises if it's None; when "time" is added
later we should initialize the timer: if self._start_time is None set
self._start_time = datetime.now(timezone.utc) before computing time_taken, then
proceed to compare with self._timedelta and return True if exceeded.
import pickle
from pathlib import Path

from _pytest.tmpdir import tmp_path
Remove unused import causing static analysis warning.
The tmp_path import from _pytest.tmpdir is unnecessary and conflicts with the pytest fixture parameter. Pytest automatically provides tmp_path as a fixture.

-from _pytest.tmpdir import tmp_path
 import numpy as np
🤖 Prompt for AI Agents
In tests/test_bayesian_optimization.py around line 7, remove the unused import
"from _pytest.tmpdir import tmp_path" because pytest already provides tmp_path
as a fixture and the import causes a static analysis warning; simply delete that
import line so the test uses the pytest fixture parameter without conflict.
    f=target_func_trivial(), pbounds=pbounds, termination_criteria=termination_criteria
)

start = datetime.now()
Use timezone-aware datetime for consistency.
The test uses datetime.now() without timezone, while the main code uses datetime.now(timezone.utc). This could lead to timing discrepancies.

+ from datetime import timezone
- start = datetime.now()
+ start = datetime.now(timezone.utc)
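As background (standard-library behavior, not specific to this PR): mixing naive and aware datetimes doesn't silently drift, it fails outright when the two are subtracted, which is one more reason to keep the test and the library consistent:

from datetime import datetime, timezone

aware = datetime.now(timezone.utc)
naive = datetime.now()
try:
    _ = aware - naive
except TypeError as exc:
    # "can't subtract offset-naive and offset-aware datetimes"
    print(exc)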
🤖 Prompt for AI Agents
In tests/test_bayesian_optimization.py around line 627, the test uses naive
datetime.now() which is inconsistent with production code using timezone-aware
datetimes; change the test to use datetime.now(timezone.utc) (import timezone
from datetime if not already) so the created timestamp is UTC-aware and matches
the main code's expectations, and update any comparisons or fixtures that assume
naive datetimes to handle the timezone-aware value.
improvement_below_tol = np.max(opt._space.target[-10:] - opt._space.max()["target"]) < 1

assert improvement_below_tol
Fix incorrect improvement calculation.
The improvement calculation appears to be inverted - it's calculating how much worse the recent values are compared to the maximum, not the improvement.
- improvement_below_tol = np.max(opt._space.target[-10:] - opt._space.max()["target"]) < 1
+ # Calculate the actual improvement in the last 10 iterations
+ max_target = opt.max["target"]
+ last_10_values = opt._space.target[-10:]
+ # The improvement is the difference between max in last 10 and the previous max
+ previous_max = max(opt._space.target[:-10]) if len(opt._space.target) > 10 else 0
+ current_max = max(last_10_values)
+ improvement = current_max - previous_max
+ improvement_below_tol = improvement < 1
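A toy calculation (invented numbers) showing why the original expression can never exceed zero, making the "< 1" check vacuous:

import numpy as np

target = np.array([0.2, 0.5, 0.9, 0.85, 0.8])
diffs = target[-3:] - target.max()  # approximately [0.0, -0.05, -0.1]
print(np.max(diffs) < 1)            # True -- but it would be True for any run,
                                    # since recent values can never exceed the global max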
🤖 Prompt for AI Agents
In tests/test_bayesian_optimization.py around lines 655-657, the improvement
calculation is inverted: it does recent_values - global_max which measures
decline rather than improvement. Replace the subtraction order so you compute
global_max - recent_max (e.g., opt._space.max()["target"] -
np.max(opt._space.target[-10:])) and compare that to the tolerance (1) so the
assertion checks that the improvement is below the threshold.
Hey @cballam, apologies for the delay, I was on vacation. To proceed with this PR, could you have a look at the test failures and fix them? It looks like your changes break the …
This is a WIP attempt to satisfy some of the goals in this issue thread about adding alternative (not just iteration-based) termination criteria for the optimizer: #381
New termination criteria are:
- objective function value reaching a threshold
- runtime (wall-clock time) limit
- objective function improvement falling below a convergence tolerance
I've added tests to make sure these are functional. The "time taken" one is challenging, as it depends on the user's machine (the old "it works on my machine" excuse), but it is empirically accurate to within about 200ms for low-complexity target/acquisition functions.
Open to any criticism/suggestions here, I am sure the code could be improved but am not sure on the best way to go about it.