-
Notifications
You must be signed in to change notification settings - Fork 96
First edition removing hyperparameter from DAG #411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c1863a0
109fdd0
53945dd
faa4da6
c5bba25
28e6585
bdc429c
c8d8d41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,7 +41,7 @@ | |
| import numpy as np | ||
| from more_itertools import unique_everseen | ||
|
|
||
| from ConfigSpace.conditions import Condition, Conjunction | ||
| from ConfigSpace.conditions import Condition, ConditionLike, Conjunction | ||
| from ConfigSpace.exceptions import ( | ||
| AmbiguousConditionError, | ||
| ChildNotFoundError, | ||
|
|
@@ -62,7 +62,6 @@ | |
| from ConfigSpace.types import f64 | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ConfigSpace.conditions import ConditionLike | ||
| from ConfigSpace.hyperparameters import Hyperparameter | ||
| from ConfigSpace.types import Array | ||
|
|
||
|
|
@@ -658,6 +657,107 @@ def add(self, hp: Hyperparameter) -> None: | |
| self.nodes[hp.name] = node | ||
| self.roots[hp.name] = node | ||
|
|
||
| def remove(self, value: Hyperparameter) -> None: | ||
| """Remove a hyperparameter from the DAG.""" | ||
| if not self._updating: | ||
| raise RuntimeError( | ||
| "Cannot remove hyperparameters outside of transaction." | ||
| "Please use `remove` inside `with dag.transaction():`", | ||
| ) | ||
|
|
||
| existing = self.nodes.get(value.name, None) | ||
| if existing is None: | ||
| raise HyperparameterNotFoundError( | ||
| f"Hyperparameter '{value.name}' does not exist in space.", | ||
| ) | ||
|
|
||
| # Update each condition containing this hyperparameter | ||
| def remove_hyperparameter_from_condition( | ||
| target: Conjunction | Condition | ForbiddenRelation | ForbiddenClause, | ||
| ) -> ( | ||
| Conjunction | ||
| | Condition | ||
| | ForbiddenClause | ||
| | ForbiddenRelation | ||
| | ForbiddenConjunction | ||
| | None | ||
| ): | ||
| if isinstance(target, ForbiddenRelation) and ( | ||
| value in (target.left, target.right) | ||
| ): | ||
| return None | ||
| if isinstance(target, ForbiddenClause) and target.hyperparameter == value: | ||
| return None | ||
| if isinstance(target, Condition) and ( | ||
| value in (target.parent, target.child) | ||
| ): | ||
| return None | ||
| if isinstance(target, (Conjunction, ForbiddenConjunction)): | ||
| new_components = [] | ||
| for component in target.components: | ||
| new_component = remove_hyperparameter_from_condition(component) | ||
| if new_component is not None: | ||
| new_components.append(new_component) | ||
| if len(new_components) >= 2: # Can create a conjunction | ||
| return type(target)(*new_components) | ||
| if len(new_components) == 1: # Only one component remains | ||
| return new_components[0] | ||
| return None # No components remain | ||
| return target # Nothing to change | ||
|
Comment on lines
+685
to
+706
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One nit, it might be safer to compare based on hyperparameter name ( It's not uncommon for people to hack in small changes to hyperparameter properties which would end up breaking equality. I kind of regret relying on equality/hashing as much as I did with the DAG, instead of explicit name identification. It would have made it more robust to people hacking around with the properties. |
||
|
|
||
| # Update each of the forbiddens containing this hyperparameter | ||
| for findex, forbidden in enumerate(self.unconditional_forbiddens): | ||
| self.unconditional_forbiddens[findex] = ( | ||
| remove_hyperparameter_from_condition(forbidden) | ||
| ) | ||
| for findex, forbidden in enumerate(self.conditional_forbiddens): | ||
| self.conditional_forbiddens[findex] = remove_hyperparameter_from_condition( | ||
| forbidden | ||
| ) | ||
| # Filter None values from the forbiddens | ||
| self.unconditional_forbiddens = [ | ||
| f for f in self.unconditional_forbiddens if f is not None | ||
| ] | ||
| self.conditional_forbiddens = [ | ||
| f for f in self.conditional_forbiddens if f is not None | ||
| ] | ||
|
|
||
| for node in self.nodes.values(): | ||
| if node.parent_condition is None: | ||
| continue | ||
| node.parent_condition = remove_hyperparameter_from_condition( | ||
| node.parent_condition | ||
| ) | ||
|
|
||
| self.nodes.pop(value.name) | ||
| for child, _ in existing.children.values(): | ||
| del child.parents[existing.name] | ||
|
|
||
| # Recalculate the depth of the children | ||
| def mark_children_recursively(node: HPNode, marked: set[str]): | ||
| for child, _ in node.children.values(): | ||
| if child.maximum_depth == node.maximum_depth + 1: | ||
| marked.add(child.name) | ||
| mark_children_recursively(child, marked) | ||
|
|
||
| marked_nodes: set[str] = set() | ||
| mark_children_recursively(existing, marked_nodes) | ||
| while marked_nodes: # Update the maximum depth of the marked nodes | ||
| remove = [] | ||
| for node_name in marked_nodes: | ||
| node = self.nodes.get(node_name) | ||
| if not node.parents: | ||
| node.maximum_depth = 0 | ||
| remove.append(node_name) | ||
| elif all(p.name not in marked_nodes for p, _ in node.parents.values()): | ||
| node.maximum_depth = ( | ||
| max(parent.maximum_depth for parent, _ in node.parents.values()) | ||
| + 1 | ||
| ) | ||
| remove.append(node_name) | ||
| for node_name in remove: | ||
| marked_nodes.remove(node_name) | ||
|
Comment on lines
+708
to
+759
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is possible to reconstruct the graph using what already exists, instead of manually doing so here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, I'm pretty sure this will happen once the |
||
|
|
||
| def add_condition(self, condition: ConditionLike) -> None: | ||
| """Add a condition to the DAG.""" | ||
| if not self._updating: | ||
|
|
@@ -782,8 +882,7 @@ def _minimum_conditions(self) -> list[ConditionNode]: | |
| # i.e. two hyperparameters both rely on algorithm == "A" | ||
| base_conditions: dict[int, ConditionNode] = {} | ||
| for node in self.nodes.values(): | ||
| # This node has no parent as is a root | ||
| if node.parent_condition is None: | ||
| if node.parent_condition is None: # This node has no parent as it is a root node | ||
| assert node.name in self.roots | ||
| continue | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -350,6 +350,38 @@ def _put_to_list( | |
| self._len = len(self._dag.nodes) | ||
| self._check_default_configuration() | ||
|
|
||
| def remove( | ||
| self, | ||
| *args: Hyperparameter, | ||
| ) -> None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just a hint on API's, a
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would also make the testing of this feature quite direct! |
||
| """Remove a hyperparameter from the configuration space. | ||
|
|
||
| If the hyperparameter has children, the children are also removed. | ||
| This includes defined conditions and conjunctions! | ||
|
|
||
| !!! note | ||
|
|
||
| If removing multiple hyperparameters, it is better to remove them all | ||
| at once with one call to `remove()`, as we rebuilt a cache after each | ||
| call to `remove()`. | ||
|
|
||
| Args: | ||
| args: Hyperparameter(s) to remove | ||
| """ | ||
| hps = [] | ||
| for arg in args: | ||
| if isinstance(arg, Hyperparameter): | ||
| hps.append(arg) | ||
| else: | ||
| raise TypeError(f"Unknown type {type(arg)}") | ||
|
|
||
| with self._dag.update(): | ||
| for hp in hps: | ||
| self._dag.remove(hp) | ||
|
|
||
| self._len = len(self._dag.nodes) | ||
| self._check_default_configuration() | ||
|
|
||
| def add_configuration_space( | ||
| self, | ||
| prefix: str, | ||
|
|
@@ -878,7 +910,7 @@ def __iter__(self) -> Iterator[str]: | |
| return iter(self._dag.nodes.keys()) | ||
|
|
||
| def items(self) -> ItemsView[str, Hyperparameter]: | ||
| """Return an items view of the hyperparameters, same as `dict.items()`.""" # noqa: D402 | ||
| """Return an items view of the hyperparameters, same as `dict.items()`.""" | ||
| return {name: node.hp for name, node in self._dag.nodes.items()}.items() | ||
|
|
||
| def __len__(self) -> int: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice. happy that the decomposition works out