Skip to content

Commit 82e14ea

Browse files
authored
fix(Constant): Use object dtype with array creation (#375)
1 parent 5cc8a07 commit 82e14ea

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

src/ConfigSpace/hyperparameters/hp_components.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
normalize,
1313
scale,
1414
)
15-
from ConfigSpace.types import DType, f64, i64
15+
from ConfigSpace.types import DType, ObjectArray, f64, i64
1616

1717
if TYPE_CHECKING:
1818
from ConfigSpace.types import Array, Mask
@@ -446,11 +446,14 @@ def ordinal_neighborhood(
446446
return np.array([seed.choice(neighbors)], dtype=f64)
447447

448448

449+
# HACK: Technically `Any` isn't an `np.number` that the Transformer expects
450+
# as it's type variable. However for a Constant, we can like with this typing
451+
# hack.
449452
@dataclass
450-
class TransformerConstant(Transformer[DType]):
453+
class TransformerConstant(Transformer[Any]):
451454
"""Implementation of a transformer for a constant value."""
452455

453-
value: DType
456+
value: Any
454457
"""The constant value."""
455458

456459
vector_value_yes: f64
@@ -479,23 +482,19 @@ def __post_init__(self) -> None:
479482
self.upper_vectorized = self.vector_value_yes
480483

481484
@override
482-
def to_vector(self, value: Array[DType]) -> Array[f64]:
485+
def to_vector(self, value: ObjectArray) -> Array[f64]:
483486
return np.where(
484487
value == self.value,
485488
self.vector_value_yes,
486489
self.vector_value_no,
487490
)
488491

489492
@override
490-
def to_value(self, vector: Array[f64]) -> Array[DType]:
491-
try:
492-
return np.full_like(vector, self.value, dtype=type(self.value))
493-
except TypeError:
494-
# Let numpy figure it out
495-
return np.array([self.value] * len(vector))
493+
def to_value(self, vector: Array[f64]) -> ObjectArray:
494+
return np.full_like(vector, self.value, dtype=object)
496495

497496
@override
498-
def legal_value(self, value: Array[DType]) -> Mask:
497+
def legal_value(self, value: ObjectArray) -> Mask:
499498
return value == self.value # type: ignore
500499

501500
@override

src/ConfigSpace/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727
Array: TypeAlias = npt.NDArray[DType]
2828
"""Array, a numpy array of a specific dtype."""
2929

30+
ObjectArray: TypeAlias = npt.NDArray[np.object_]
31+
"""Object array, a numpy array of objects."""
32+
3033
f64: TypeAlias = np.float64
3134
"""64-bit floating point number."""
3235

0 commit comments

Comments
 (0)