|
12 | 12 | normalize, |
13 | 13 | scale, |
14 | 14 | ) |
15 | | -from ConfigSpace.types import DType, f64, i64 |
| 15 | +from ConfigSpace.types import DType, ObjectArray, f64, i64 |
16 | 16 |
|
17 | 17 | if TYPE_CHECKING: |
18 | 18 | from ConfigSpace.types import Array, Mask |
@@ -446,11 +446,14 @@ def ordinal_neighborhood( |
446 | 446 | return np.array([seed.choice(neighbors)], dtype=f64) |
447 | 447 |
|
448 | 448 |
|
| 449 | +# HACK: Technically `Any` isn't an `np.number` that the Transformer expects |
| 450 | +# as it's type variable. However for a Constant, we can like with this typing |
| 451 | +# hack. |
449 | 452 | @dataclass |
450 | | -class TransformerConstant(Transformer[DType]): |
| 453 | +class TransformerConstant(Transformer[Any]): |
451 | 454 | """Implementation of a transformer for a constant value.""" |
452 | 455 |
|
453 | | - value: DType |
| 456 | + value: Any |
454 | 457 | """The constant value.""" |
455 | 458 |
|
456 | 459 | vector_value_yes: f64 |
@@ -479,23 +482,19 @@ def __post_init__(self) -> None: |
479 | 482 | self.upper_vectorized = self.vector_value_yes |
480 | 483 |
|
481 | 484 | @override |
482 | | - def to_vector(self, value: Array[DType]) -> Array[f64]: |
| 485 | + def to_vector(self, value: ObjectArray) -> Array[f64]: |
483 | 486 | return np.where( |
484 | 487 | value == self.value, |
485 | 488 | self.vector_value_yes, |
486 | 489 | self.vector_value_no, |
487 | 490 | ) |
488 | 491 |
|
489 | 492 | @override |
490 | | - def to_value(self, vector: Array[f64]) -> Array[DType]: |
491 | | - try: |
492 | | - return np.full_like(vector, self.value, dtype=type(self.value)) |
493 | | - except TypeError: |
494 | | - # Let numpy figure it out |
495 | | - return np.array([self.value] * len(vector)) |
| 493 | + def to_value(self, vector: Array[f64]) -> ObjectArray: |
| 494 | + return np.full_like(vector, self.value, dtype=object) |
496 | 495 |
|
497 | 496 | @override |
498 | | - def legal_value(self, value: Array[DType]) -> Mask: |
| 497 | + def legal_value(self, value: ObjectArray) -> Mask: |
499 | 498 | return value == self.value # type: ignore |
500 | 499 |
|
501 | 500 | @override |
|
0 commit comments