|
| 1 | +import struct |
| 2 | +import time |
| 3 | + |
| 4 | +import numpy as np |
| 5 | +import psycopg |
| 6 | +from psycopg.adapt import Dumper, Loader |
| 7 | +from psycopg.pq import Format |
| 8 | +from psycopg.types import TypeInfo |
| 9 | + |
| 10 | +from ..base.module import BaseANN |
| 11 | + |
| 12 | + |
| 13 | +class VectorDumper(Dumper): |
| 14 | + format = Format.BINARY |
| 15 | + |
| 16 | + def dump(self, obj): |
| 17 | + return struct.pack(f"<H{len(obj)}f", len(obj), *obj) |
| 18 | + |
| 19 | + |
| 20 | +class VectorLoader(Loader): |
| 21 | + def load(self, buf): |
| 22 | + if isinstance(buf, memoryview): |
| 23 | + buf = bytes(buf) |
| 24 | + dim = struct.unpack_from("<H", buf)[0] |
| 25 | + return np.frombuffer(buf, dtype="<f", count=dim, offset=2) |
| 26 | + |
| 27 | + |
| 28 | +def register_vector(conn: psycopg.Connection): |
| 29 | + info = TypeInfo.fetch(conn=conn, name="vector") |
| 30 | + register_vector_type(conn, info) |
| 31 | + |
| 32 | + |
| 33 | +def register_vector_type(conn: psycopg.Connection, info: TypeInfo): |
| 34 | + if info is None: |
| 35 | + raise ValueError("vector type not found") |
| 36 | + info.register(conn) |
| 37 | + |
| 38 | + class VectorBinaryDumper(VectorDumper): |
| 39 | + oid = info.oid |
| 40 | + |
| 41 | + adapters = conn.adapters |
| 42 | + adapters.register_dumper(list, VectorBinaryDumper) |
| 43 | + adapters.register_dumper(np.ndarray, VectorBinaryDumper) |
| 44 | + adapters.register_loader(info.oid, VectorLoader) |
| 45 | + |
| 46 | + |
| 47 | +class PGVectoRS(BaseANN): |
| 48 | + def __init__(self, metric, method_param) -> None: |
| 49 | + self.metric = metric |
| 50 | + self.m = method_param["M"] |
| 51 | + self.ef_construction = method_param["efConstruction"] |
| 52 | + self.ef_search = 100 |
| 53 | + |
| 54 | + if metric == "angular": |
| 55 | + self.query_sql = "SELECT id FROM items ORDER BY embedding <=> %s LIMIT %s" |
| 56 | + self.index_sql = f"CREATE INDEX ON items USING vectors (embedding vector_cos_ops) WITH (options = $$[indexing.hnsw]\nm = {self.m}\nef_construction = {self.ef_construction}$$)" |
| 57 | + elif metric == "euclidean": |
| 58 | + self.query_sql = "SELECT id FROM items ORDER BY embedding <-> %s LIMIT %s" |
| 59 | + self.index_sql = f"CREATE INDEX ON items USING vectors (embedding vector_l2_ops) WITH (options = $$[indexing.hnsw]\nm = {self.m}\nef_construction = {self.ef_construction}$$)" |
| 60 | + else: |
| 61 | + raise RuntimeError(f"unknown metric {metric}") |
| 62 | + |
| 63 | + self.connect = psycopg.connect(user="postgres", password="password", autocommit=True) |
| 64 | + self.connect.execute("SET search_path = \"$user\", public, vectors") |
| 65 | + self.connect.execute("CREATE EXTENSION IF NOT EXISTS vectors") |
| 66 | + register_vector(self.connect) |
| 67 | + |
| 68 | + def fit(self, X): |
| 69 | + dim = X.shape[1] |
| 70 | + |
| 71 | + cur = self.connect.cursor() |
| 72 | + cur.execute("DROP TABLE IF EXISTS items") |
| 73 | + cur.execute(f"CREATE TABLE items (id int, embedding vector({dim}))") |
| 74 | + with cur.copy("COPY items (id, embedding) FROM STDIN WITH (FORMAT BINARY)") as copy: |
| 75 | + copy.set_types(["int4", "vector"]) |
| 76 | + for i, emb in enumerate(X): |
| 77 | + copy.write_row((i, emb)) |
| 78 | + |
| 79 | + cur.execute(self.index_sql) |
| 80 | + print("waiting for indexing to finish...") |
| 81 | + for _ in range(3600): |
| 82 | + cur.execute("SELECT idx_indexing FROM vectors.pg_vector_index_stat WHERE tablename='items'") |
| 83 | + if not cur.fetchone()[0]: |
| 84 | + break |
| 85 | + time.sleep(10) |
| 86 | + |
| 87 | + def set_query_arguments(self, ef_search): |
| 88 | + self.ef_search = ef_search |
| 89 | + self.connect.execute(f"SET vectors.hnsw_ef_search = {ef_search}") |
| 90 | + |
| 91 | + def query(self, vec, num): |
| 92 | + cur = self.connect.execute(self.query_sql, (vec, num), binary=True, prepare=True) |
| 93 | + return [id for (id,) in cur.fetchall()] |
| 94 | + |
| 95 | + def __str__(self): |
| 96 | + return ( |
| 97 | + f"PGVectoRS(metric={self.metric}, m={self.m}, " |
| 98 | + f"ef_construction={self.ef_construction}, ef_search={self.ef_search})" |
| 99 | + ) |
0 commit comments