|
25 | 25 | - The type implementation might want to be accompanied by corresponding support |
26 | 26 | for the `KNN_MATCH` function, similar to what the dialect already offers for |
27 | 27 | fulltext search through its `Match` predicate. |
| 28 | +- After dropping support for SQLAlchemy 1.3, use |
| 29 | + `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):` |
28 | 30 |
|
29 | 31 | ## Origin |
30 | 32 | This module is based on the corresponding pgvector implementation |
|
44 | 46 | __all__ = ["FloatVector"] |
45 | 47 |
|
46 | 48 |
|
47 | | -def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]: |
| 49 | +def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]: |
48 | 50 | import numpy as np |
49 | 51 |
|
50 | 52 | # from `pgvector.utils` |
@@ -77,8 +79,7 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: |
77 | 79 | return value |
78 | 80 |
|
79 | 81 |
|
80 | | -class FloatVector(sa.TypeDecorator[t.Sequence[float]]): |
81 | | - |
| 82 | +class FloatVector(sa.TypeDecorator): |
82 | 83 | """ |
83 | 84 | An improved implementation of the `FloatVector` data type for CrateDB, |
84 | 85 | compared to the previous implementation on behalf of the LangChain adapter. |
@@ -146,14 +147,14 @@ def __init__(self, dimensions: int = None): |
146 | 147 | def as_generic(self): |
147 | 148 | return sa.ARRAY |
148 | 149 |
|
149 | | - def bind_processor(self, dialect: sa.Dialect) -> t.Callable: |
| 150 | + def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable: |
150 | 151 | def process(value: t.Iterable) -> t.Optional[t.List]: |
151 | 152 | return to_db(value, self.dimensions) |
152 | 153 |
|
153 | 154 | return process |
154 | 155 |
|
155 | | - def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable: |
156 | | - def process(value: t.Any) -> t.Optional[npt.ArrayLike]: |
| 156 | + def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable: |
| 157 | + def process(value: t.Any) -> t.Optional["npt.ArrayLike"]: |
157 | 158 | return from_db(value) |
158 | 159 |
|
159 | 160 | return process |
0 commit comments