diff --git a/embedding-stores/langchain4j-community-oceanbase/README.md b/embedding-stores/langchain4j-community-oceanbase/README.md new file mode 100644 index 00000000..e53aa9ea --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/README.md @@ -0,0 +1,208 @@ +# OceanBase Vector Store for LangChain4j + +This module implements an `EmbeddingStore` backed by an OceanBase database. + +- [Product Documentation](https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002826816) + +The **OceanBase for LangChain4j** package provides a first-class experience for connecting to OceanBase instances from the LangChain4j ecosystem while providing the following benefits: + +- **Simplified Vector Storage**: Utilize OceanBase's vector data types and indexing capabilities for efficient similarity searches. +- **Improved Metadata Handling**: Store metadata in JSON columns instead of strings, resulting in significant performance improvements. +- **Clear Separation**: Clearly separate table and extension creation, allowing for distinct permissions and streamlined workflows. +- **Better Integration with OceanBase**: Built-in methods to take advantage of OceanBase's advanced indexing and scalability capabilities. + +## Quick Start + +In order to use this library, you first need to go through the following steps: + +1. [Install OceanBase Database](https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012734) +2. Create a database and user +3. Configure vector memory limits (if needed) + +### Maven Dependency + +```xml + + dev.langchain4j + langchain4j-community-oceanbase + 1.2.0-beta8-SNAPSHOT + +``` + +### Supported Java Versions + +Java >= 17 + +## OceanBaseEmbeddingStore Usage + +`OceanBaseEmbeddingStore` is used to store text embedded data and perform vector search. Instances can be created by configuring the provided `Builder`, which requires: + +- A `DataSource` instance (connected to an OceanBase database) +- Table name +- Table configuration (optional, uses standard configuration by default) +- Exact search option (optional, uses approximate search by default) + +Example usage: + +```java +import dev.langchain4j.community.store.embedding.oceanbase.OceanBaseEmbeddingStore; +import dev.langchain4j.community.store.embedding.oceanbase.EmbeddingTable; +import dev.langchain4j.community.store.embedding.oceanbase.CreateOption; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; + +import javax.sql.DataSource; +import java.util.*; + +// Create a data source +DataSource dataSource = createDataSource(); // You need to implement this method + +// Create a vector store +OceanBaseEmbeddingStore store = OceanBaseEmbeddingStore.builder(dataSource) + .embeddingTable( + EmbeddingTable.builder("my_embeddings") + .vectorDimension(384) // Set vector dimension + .createOption(CreateOption.CREATE_IF_NOT_EXISTS) + .build()) + .build(); + +// Add embeddings +List testTexts = Arrays.asList("apple", "banana", "car", "truck"); +List embeddings = new ArrayList<>(); +List textSegments = new ArrayList<>(); + +for (String text : testTexts) { + Map metaMap = new HashMap<>(); + metaMap.put("category", text.length() <= 5 ? "fruit" : "vehicle"); + Metadata metadata = Metadata.from(metaMap); + textSegments.add(TextSegment.from(text, metadata)); + embeddings.add(myEmbeddingModel.embed(text).content()); // Use your embedding model +} + +// Batch add embeddings and text segments +List ids = store.addAll(embeddings, textSegments); + +// Search for similar vectors +EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddings.get(0)) // Search for content similar to "apple" + .maxResults(10) + .minScore(0.7) + .build(); + +List> results = store.search(request).matches(); + +// Use metadata filtering +import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; + +// Search only in the "fruit" category +EmbeddingSearchRequest filteredRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddings.get(0)) + .maxResults(10) + .filter(MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit")) + .build(); + +List> filteredResults = store.search(filteredRequest).matches(); + +// Remove embeddings +store.remove(ids.get(0)); // Remove a single vector +store.removeAll(Arrays.asList(ids.get(1), ids.get(2))); // Remove multiple vectors +store.removeAll(MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit")); // Remove by metadata +store.removeAll(); // Remove all vectors +``` + +## EmbeddingTable Configuration + +The `EmbeddingTable` class is used to configure the structure and creation options of the vector table: + +```java +import dev.langchain4j.community.store.embedding.oceanbase.EmbeddingTable; +import dev.langchain4j.community.store.embedding.oceanbase.CreateOption; + +// Basic configuration +EmbeddingTable table = EmbeddingTable.builder("my_embeddings") + .vectorDimension(384) // Set vector dimension + .createOption(CreateOption.CREATE_IF_NOT_EXISTS) + .build(); + +// Advanced configuration +EmbeddingTable advancedTable = EmbeddingTable.builder("advanced_embeddings") + .idColumn("custom_id") // Custom ID column name + .embeddingColumn("vector_data") // Custom vector column name + .textColumn("content") // Custom text column name + .metadataColumn("meta_info") // Custom metadata column name + .vectorDimension(768) // Set vector dimension + .vectorIndexName("idx_vector_search") // Custom index name + .distanceMetric("L2") // Set distance metric (L2, IP, COSINE) + .indexType("hnsw") // Set index type (hnsw, flat) + .createOption(CreateOption.CREATE_OR_REPLACE) // Table creation option + .build(); +``` + +### Table Creation Options + +The `CreateOption` enum provides the following options: + +- `CREATE_NONE`: Do not create a table, assumes the table already exists +- `CREATE_IF_NOT_EXISTS`: Create the table if it does not exist (default) +- `CREATE_OR_REPLACE`: Create the table, replacing it if it exists + +## Search Options + +`OceanBaseEmbeddingStore` supports two search modes: + +- **Approximate Search** (default): Faster but may not be 100% accurate +- **Exact Search**: Slower but 100% accurate + +```java +// Use exact search +OceanBaseEmbeddingStore exactStore = OceanBaseEmbeddingStore.builder(dataSource) + .embeddingTable("my_embeddings") + .exactSearch(true) // Enable exact search + .build(); +``` + +## Metadata Filtering + +Complex metadata filtering conditions can be built using `MetadataFilterBuilder`: + +```java +import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; + +// Basic filtering +EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(queryEmbedding) + .filter(MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit")) + .build(); + +// Combined filtering +EmbeddingSearchRequest complexRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(queryEmbedding) + .filter( + MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit") + .and(MetadataFilterBuilder.metadataKey("color").isEqualTo("red")) + ) + .build(); +``` + +## Performance Optimization + +For optimal performance, consider the following recommendations: + +1. Adjust OceanBase's vector memory limit appropriately for large vector collections: + ```sql + ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30; + ``` + +2. Choose the index type that suits your use case: + - `hnsw`: Suitable for most scenarios, provides a good balance of performance/accuracy + - `flat`: Use when highest accuracy is required + +3. Add embeddings in batches to improve performance: + ```java + store.addAll(embeddings, textSegments); + ``` + +4. For frequent similarity searches, consider using approximate search mode (default). diff --git a/embedding-stores/langchain4j-community-oceanbase/pom.xml b/embedding-stores/langchain4j-community-oceanbase/pom.xml new file mode 100644 index 00000000..6e9f8aa4 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/pom.xml @@ -0,0 +1,111 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-community + 1.2.0-beta8-SNAPSHOT + ../../pom.xml + + + langchain4j-community-oceanbase + LangChain4j :: Community :: Integration :: OceanBase + Community Integration with OceanBase Vector Storage + + + 2.4.12 + + + + + dev.langchain4j + langchain4j-core + ${langchain4j.core.version} + + + + + com.oceanbase + oceanbase-client + ${oceanbase-jdbc.version} + + + + org.slf4j + slf4j-api + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + dev.langchain4j + langchain4j-core + ${langchain4j.core.version} + tests + test-jar + test + + + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + + + org.junit.jupiter + junit-jupiter + test + + + + org.assertj + assertj-core + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + junit-jupiter + test + + + + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog + test + + + + + + + org.honton.chas + license-maven-plugin + + + true + + + + + diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/CreateOption.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/CreateOption.java new file mode 100644 index 00000000..9cb2aa58 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/CreateOption.java @@ -0,0 +1,21 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +/** + * Option for creating a table. + */ +public enum CreateOption { + /** + * Do not create a table. This assumes that the table already exists. + */ + CREATE_NONE, + + /** + * Create the table if it does not exist. + */ + CREATE_IF_NOT_EXISTS, + + /** + * Create the table, replacing an existing table if it exists. + */ + CREATE_OR_REPLACE +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java new file mode 100644 index 00000000..1062c537 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java @@ -0,0 +1,386 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; +import javax.sql.DataSource; + +/** + * Represents a database table where embeddings, text, and metadata are stored. The columns of this table are listed + */ +public final class EmbeddingTable { + + /** + * Option which configures how the {@link #create(DataSource)} method creates this table + */ + private final CreateOption createOption; + + /** + * The name of this table + */ + private final String name; + + /** + * Name of a column which stores an id. + */ + private final String idColumn; + + /** + * Name of a column which stores an embedding. + */ + private final String embeddingColumn; + + /** + * Name of a column which stores text. + */ + private final String textColumn; + + /** + * Name of a column which stores metadata. + */ + private final String metadataColumn; + + /** + * The vector dimension + */ + private final int vectorDimension; + + /** + * The vector index name + */ + private final String vectorIndexName; + + /** + * The distance metric for vector similarity search + */ + private final String distanceMetric; + + /** + * The index type (e.g., hnsw, flat) + */ + private final String indexType; + + private EmbeddingTable(Builder builder) { + createOption = builder.createOption; + name = builder.name; + idColumn = builder.idColumn; + embeddingColumn = builder.embeddingColumn; + textColumn = builder.textColumn; + metadataColumn = builder.metadataColumn; + vectorDimension = builder.vectorDimension; + vectorIndexName = builder.vectorIndexName; + distanceMetric = builder.distanceMetric; + indexType = builder.indexType; + } + + /** + * Creates a table configured by the {@link Builder} of this EmbeddingTable. No table is created if the Builder was + * configured with {@link CreateOption#CREATE_NONE}. + * + * @param dataSource Data source that connects to an OceanBase Database where the table is (possibly) created. + * @throws SQLException If an error prevents the table from being created. + */ + void create(DataSource dataSource) throws SQLException { + if (createOption == CreateOption.CREATE_NONE) return; + + try (Connection connection = dataSource.getConnection(); + Statement statement = connection.createStatement()) { + if (createOption == CreateOption.CREATE_OR_REPLACE) { + statement.addBatch("DROP TABLE IF EXISTS " + name); + } + + statement.addBatch( + """ + CREATE TABLE IF NOT EXISTS %s( + %s VARCHAR(36) NOT NULL, + %s VECTOR(%d), + %s VARCHAR(4000), + %s JSON, + PRIMARY KEY (%s), + VECTOR INDEX %s(%s) WITH (distance=%s, type=%s) + ) + """ + .formatted( + name, + idColumn, + embeddingColumn, + vectorDimension, + textColumn, + metadataColumn, + idColumn, + vectorIndexName, + embeddingColumn, + distanceMetric, + indexType)); + statement.executeBatch(); + } + } + + /** + * Maps a metadata key to a JSON path expression for use in SQL queries. + * + * @param key Name of a metadata key. Not null. + * @return A JSON path expression that returns the value of the key. + */ + String mapMetadataKey(String key) { + return "JSON_VALUE(" + metadataColumn + ", '$." + key + "')"; + } + + /** + * Returns a MetadataKeyMapper that maps metadata keys to JSON path expressions. + * + * @return A metadata key mapper. Not null. + */ + public MetadataKeyMapper getMetadataKeyMapper() { + return key -> "JSON_VALUE(" + metadataColumn + ", '$." + key + "')"; + } + + /** + * Returns the name of this table. + * + * @return Table name. Not null. + */ + public String name() { + return name; + } + + /** + * Returns the name of the id column. + * + * @return Id column name. Not null. + */ + public String idColumn() { + return idColumn; + } + + /** + * Returns the name of the embedding column. + * + * @return Embedding column name. Not null. + */ + public String embeddingColumn() { + return embeddingColumn; + } + + /** + * Returns the name of the text column. + * + * @return Text column name. Not null. + */ + public String textColumn() { + return textColumn; + } + + /** + * Returns the name of the metadata column. + * + * @return Metadata column name. Not null. + */ + public String metadataColumn() { + return metadataColumn; + } + + /** + * Returns the dimension of the vector. + * + * @return Vector dimension. + */ + public int vectorDimension() { + return vectorDimension; + } + + /** + * Returns the name of the vector index. + * + * @return Vector index name. Not null. + */ + public String vectorIndexName() { + return vectorIndexName; + } + + /** + * Returns the distance metric used for vector similarity search. + * + * @return Distance metric. Not null. + */ + public String distanceMetric() { + return distanceMetric; + } + + /** + * Returns the index type. + * + * @return Index type. Not null. + */ + public String indexType() { + return indexType; + } + + /** + * Returns a builder for configuring an embedding table. + * + * @param tableName Name of the table. Not null. + * @return A builder. + */ + public static Builder builder(String tableName) { + return new Builder(tableName); + } + + /** + * Builder for an EmbeddingTable. + */ + public static final class Builder { + + private final String name; + private String idColumn = "id"; + private String embeddingColumn = "embedding"; + private String textColumn = "text"; + private String metadataColumn = "metadata"; + private int vectorDimension = 1536; + private String vectorIndexName = "idx_vector"; + private String distanceMetric = "cosine"; + private String indexType = "hnsw"; + private CreateOption createOption = CreateOption.CREATE_IF_NOT_EXISTS; + + /** + * Creates a builder for a table with a given name. + * + * @param tableName Name of the table. Not null. + * @throws IllegalArgumentException If the table name is null. + */ + private Builder(String tableName) { + ensureNotNull(tableName, "tableName"); + this.name = tableName; + } + + /** + * Sets the column name for the id column. The default value is "id". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder idColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.idColumn = columnName; + return this; + } + + /** + * Sets the column name for the embedding column. The default value is "embedding". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder embeddingColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.embeddingColumn = columnName; + return this; + } + + /** + * Sets the column name for the text column. The default value is "text". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder textColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.textColumn = columnName; + return this; + } + + /** + * Sets the column name for the metadata column. The default value is "metadata". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder metadataColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.metadataColumn = columnName; + return this; + } + + /** + * Sets the dimensionality of the vector. The default value is 1536. + * + * @param dimension Vector dimension + * @return This builder. + * @throws IllegalArgumentException If the dimension is negative or zero. + */ + public Builder vectorDimension(int dimension) { + if (dimension <= 0) { + throw new IllegalArgumentException("Vector dimension must be positive"); + } + this.vectorDimension = dimension; + return this; + } + + /** + * Sets the name of the vector index. The default value is "idx_vector". + * + * @param indexName Index name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the index name is null. + */ + public Builder vectorIndexName(String indexName) { + ensureNotNull(indexName, "indexName"); + this.vectorIndexName = indexName; + return this; + } + + /** + * Sets the distance metric for vector similarity search. The default value is "L2". + * + * @param metric Distance metric. Not null. + * @return This builder. + * @throws IllegalArgumentException If the metric is null. + */ + public Builder distanceMetric(String metric) { + ensureNotNull(metric, "metric"); + this.distanceMetric = metric; + return this; + } + + /** + * Sets the index type. The default value is "hnsw". + * + * @param type Index type. Not null. + * @return This builder. + * @throws IllegalArgumentException If the type is null. + */ + public Builder indexType(String type) { + ensureNotNull(type, "type"); + this.indexType = type; + return this; + } + + /** + * Sets the create option for the table. + * + * @param createOption Create option. Not null. + * @return This builder. + * @throws IllegalArgumentException If the create option is null. + */ + public Builder createOption(CreateOption createOption) { + ensureNotNull(createOption, "createOption"); + this.createOption = createOption; + return this; + } + + /** + * Creates an EmbeddingTable with the configuration defined by this builder. + * + * @return A new EmbeddingTable. Not null. + */ + public EmbeddingTable build() { + return new EmbeddingTable(this); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java new file mode 100644 index 00000000..1a60fc9d --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java @@ -0,0 +1,21 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +/** + * Interface for mapping metadata keys to SQL expressions. + * This allows for different mapping strategies without exposing internal implementations. + */ +public interface MetadataKeyMapper { + + /** + * Maps a metadata key to a SQL expression that can be used in queries. + * + * @param key The metadata key to map + * @return A SQL expression that references the given metadata key + */ + String mapKey(String key); + + /** + * Default implementation that uses JSON_VALUE function + */ + MetadataKeyMapper DEFAULT = key -> "JSON_VALUE(metadata, '$." + key + "')"; +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java new file mode 100644 index 00000000..4368c83a --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java @@ -0,0 +1,466 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.langchain4j.community.store.embedding.oceanbase.search.OceanBaseSearchTemplate; +import dev.langchain4j.community.store.embedding.oceanbase.search.SearchTemplate; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilterFactory; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; +import java.sql.*; +import java.util.*; +import javax.sql.DataSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * OceanBase Vector EmbeddingStore Implementation + */ +public final class OceanBaseEmbeddingStore implements EmbeddingStore { + + private static final Logger log = LoggerFactory.getLogger(OceanBaseEmbeddingStore.class); + + /** + * DataSource configured to connect with an OceanBase Database. + */ + private final DataSource dataSource; + + /** + * Table where embeddings are stored. + */ + private final EmbeddingTable table; + + /** + * true if search should use an exact search, or false if + * it should use approximate search. + */ + private final boolean isExactSearch; + + /** + * JSON Object Mapper for metadata serialization/deserialization + */ + private final ObjectMapper objectMapper; + + /** + * Search template for vector search operations + */ + private final SearchTemplate searchTemplate; + + /** + * Constructs embedding store configured by a builder. + * + * @param builder Builder that configures the embedding store. Not null. + * @throws IllegalArgumentException If the configuration is not valid. + */ + private OceanBaseEmbeddingStore(Builder builder) { + dataSource = builder.dataSource; + table = builder.embeddingTable; + isExactSearch = builder.isExactSearch; + objectMapper = new ObjectMapper(); + searchTemplate = new OceanBaseSearchTemplate(table, isExactSearch); + + try { + table.create(dataSource); + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public String add(Embedding embedding) { + ensureNotNull(embedding, "embedding"); + + String id = randomUUID(); + add(id, embedding); + return id; + } + + @Override + public void add(String id, Embedding embedding) { + ensureNotBlank(id, "id"); + ensureNotNull(embedding, "embedding"); + + String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() + + ", " + table.embeddingColumn() + + ") VALUES (?, ?)"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, id); + statement.setString(2, vectorToString(embedding.vector())); + statement.executeUpdate(); + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public List addAll(List embeddings) { + ensureNotNull(embeddings, "embeddings"); + + if (embeddings.isEmpty()) { + return List.of(); + } + + List ids = generateIds(embeddings.size()); + + String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() + + ", " + table.embeddingColumn() + + ") VALUES (?, ?)"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + + for (int i = 0; i < ids.size(); i++) { + statement.setString(1, ids.get(i)); + statement.setString(2, vectorToString(embeddings.get(i).vector())); + statement.addBatch(); + } + + statement.executeBatch(); + return ids; + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public String add(Embedding embedding, TextSegment textSegment) { + ensureNotNull(embedding, "embedding"); + ensureNotNull(textSegment, "textSegment"); + + String id = randomUUID(); + + String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() + + ", " + table.embeddingColumn() + + ", " + table.textColumn() + + ", " + table.metadataColumn() + + ") VALUES (?, ?, ?, ?)"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, id); + statement.setString(2, vectorToString(embedding.vector())); + statement.setString(3, textSegment.text()); + statement.setString(4, metadataToJson(textSegment.metadata())); + statement.executeUpdate(); + return id; + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public List addAll(List embeddings, List textSegments) { + ensureNotNull(embeddings, "embeddings"); + ensureNotNull(textSegments, "textSegments"); + + if (embeddings.isEmpty()) { + return Collections.emptyList(); + } + + if (embeddings.size() != textSegments.size()) { + throw new IllegalArgumentException(String.format( + "Number of embeddings (%d) and text segments (%d) must be the same", + embeddings.size(), textSegments.size())); + } + + List ids = generateIds(embeddings.size()); + addAll(ids, embeddings, textSegments); + return ids; + } + + @Override + public void addAll(List ids, List embeddings, List embedded) { + ensureNotNull(ids, "ids"); + ensureNotNull(embeddings, "embeddings"); + ensureNotNull(embedded, "embedded"); + + if (ids.isEmpty() || embeddings.isEmpty() || embedded.isEmpty()) { + return; + } + + if (ids.size() != embeddings.size() || ids.size() != embedded.size()) { + throw new IllegalArgumentException(String.format( + "Number of ids (%d), embeddings (%d), and embedded objects (%d) must be the same", + ids.size(), embeddings.size(), embedded.size())); + } + + String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() + + ", " + table.embeddingColumn() + + ", " + table.textColumn() + + ", " + table.metadataColumn() + + ") VALUES (?, ?, ?, ?)"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + + for (int i = 0; i < ids.size(); i++) { + String id = ensureIndexNotNull(ids, i, "ids"); + Embedding embedding = ensureIndexNotNull(embeddings, i, "embeddings"); + TextSegment segment = ensureIndexNotNull(embedded, i, "embedded"); + + statement.setString(1, id); + statement.setString(2, vectorToString(embedding.vector())); + statement.setString(3, segment.text()); + statement.setString(4, metadataToJson(segment.metadata())); + statement.addBatch(); + } + + statement.executeBatch(); + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public EmbeddingSearchResult search(EmbeddingSearchRequest request) { + // Use the search template (Template Method pattern) + return searchTemplate.search(dataSource, request); + } + + @Override + public void remove(String id) { + ensureNotBlank(id, "id"); + + String sql = "DELETE FROM " + table.name() + " WHERE " + table.idColumn() + " = ?"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + statement.setString(1, id); + statement.executeUpdate(); + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public void removeAll(Collection ids) { + ensureNotEmpty(ids, "ids"); + + // Use placeholders for each ID in the IN clause + String placeholders = String.join(", ", Collections.nCopies(ids.size(), "?")); + String sql = "DELETE FROM " + table.name() + " WHERE " + table.idColumn() + " IN (" + placeholders + ")"; + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(sql)) { + + int index = 1; + for (String id : ids) { + statement.setString(index++, id); + } + + statement.executeUpdate(); + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public void removeAll(Filter filter) { + ensureNotNull(filter, "filter"); + + SQLFilter sqlFilter = SQLFilterFactory.create(filter, (key, value) -> table.mapMetadataKey(key)); + + if (sqlFilter.matchesNoRows()) { + return; + } + + String sql; + if (sqlFilter.matchesAllRows()) { + sql = "DELETE FROM " + table.name(); + } else { + sql = "DELETE FROM " + table.name() + " WHERE " + sqlFilter.toSql(); + } + + try (Connection connection = dataSource.getConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate(sql); + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + @Override + public void removeAll() { + String sql = "DELETE FROM " + table.name(); + + try (Connection connection = dataSource.getConnection(); + Statement statement = connection.createStatement()) { + statement.executeUpdate(sql); + } catch (SQLException sqlException) { + throw new UncheckSQLException(sqlException); + } + } + + /** + * Converts a vector to a string representation. + * + * @param vector The vector to convert. + * @return String representation of the vector in OceanBase format. + */ + private String vectorToString(float[] vector) { + if (vector == null || vector.length == 0) { + return "[]"; + } + + StringBuilder sb = new StringBuilder(); + sb.append("["); + for (int i = 0; i < vector.length; i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(vector[i]); + } + sb.append("]"); + return sb.toString(); + } + + /** + * Converts Metadata to a JSON string. + * + * @param metadata The metadata to convert. + * @return JSON string representation of the metadata. + */ + private String metadataToJson(Metadata metadata) { + if (metadata == null || metadata.toMap().isEmpty()) { + return null; + } + + try { + return objectMapper.writeValueAsString(metadata.toMap()); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to convert metadata to JSON", e); + } + } + + /** + * Returns a null-safe item from a List. + * @param list List to get an item from. + * @param index Index of the item to get. + * @param name Name of the parameter, for error messages. + * @param Type of the list items. + * @return The item at the given index. + * @throws IllegalArgumentException If the item is null. + */ + public static T ensureIndexNotNull(List list, int index, String name) { + if (index < 0 || index >= list.size()) { + throw new IllegalArgumentException( + String.format("Index %d is out of bounds for %s list of size %d", index, name, list.size())); + } + T item = list.get(index); + if (item == null) { + throw new IllegalArgumentException(String.format("%s[%d] is null", name, index)); + } + return item; + } + + /** + * Returns a builder for configuring an OceanBaseEmbeddingStore. + * + * @param dataSource DataSource that connects to an OceanBase Database. + * @return A builder. + */ + public static Builder builder(DataSource dataSource) { + return new Builder(dataSource); + } + + /** + * Builder for an OceanBaseEmbeddingStore. + */ + public static final class Builder { + + private final DataSource dataSource; + private EmbeddingTable embeddingTable; + private boolean isExactSearch = false; + + /** + * Constructs a builder for an OceanBaseEmbeddingStore. + * + * @param dataSource DataSource that connects to an OceanBase Database. + * @throws IllegalArgumentException If the DataSource is null. + */ + private Builder(DataSource dataSource) { + ensureNotNull(dataSource, "dataSource"); + this.dataSource = dataSource; + } + + /** + * Sets the embedding table with a default table configuration. + * + * @param tableName Name of the embedding table. + * @return This builder. + * @throws IllegalArgumentException If the table name is null. + */ + public Builder embeddingTable(String tableName) { + ensureNotNull(tableName, "tableName"); + this.embeddingTable = EmbeddingTable.builder(tableName).build(); + return this; + } + + /** + * Sets the embedding table with a default table configuration and a create option. + * + * @param tableName Name of the embedding table. + * @param createOption Option for table creation. + * @return This builder. + * @throws IllegalArgumentException If the table name or create option is null. + */ + public Builder embeddingTable(String tableName, CreateOption createOption) { + ensureNotNull(tableName, "tableName"); + ensureNotNull(createOption, "createOption"); + this.embeddingTable = + EmbeddingTable.builder(tableName).createOption(createOption).build(); + return this; + } + + /** + * Sets the embedding table with a custom configuration. + * + * @param embeddingTable Embedding table configuration. + * @return This builder. + * @throws IllegalArgumentException If the embedding table is null. + */ + public Builder embeddingTable(EmbeddingTable embeddingTable) { + ensureNotNull(embeddingTable, "embeddingTable"); + this.embeddingTable = embeddingTable; + return this; + } + + /** + * Sets whether to use exact search instead of approximate search. + * + * @param isExactSearch true to use exact search, false to use approximate search. + * @return This builder. + */ + public Builder exactSearch(boolean isExactSearch) { + this.isExactSearch = isExactSearch; + return this; + } + + /** + * Builds an OceanBaseEmbeddingStore with the configuration defined by this builder. + * + * @return A new OceanBaseEmbeddingStore. + * @throws IllegalArgumentException If required parameters are missing. + */ + public OceanBaseEmbeddingStore build() { + if (embeddingTable == null) { + throw new IllegalArgumentException("embeddingTable must be configured"); + } + + return new OceanBaseEmbeddingStore(this); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/UncheckSQLException.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/UncheckSQLException.java new file mode 100644 index 00000000..4a2d8d2c --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/UncheckSQLException.java @@ -0,0 +1,21 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import dev.langchain4j.exception.LangChain4jException; + +/** + * Exception for OceanBase SQL related errors + */ +public class UncheckSQLException extends LangChain4jException { + + public UncheckSQLException(String message, Throwable cause) { + super(message, cause); + } + + public UncheckSQLException(String message) { + super(message); + } + + public UncheckSQLException(Throwable cause) { + super("SQL error: " + cause.getMessage(), cause); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java new file mode 100644 index 00000000..57af43ab --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java @@ -0,0 +1,13 @@ +package dev.langchain4j.community.store.embedding.oceanbase.distance; + +/** + * Converts cosine distance to similarity score. + * For cosine distance, similarity = 1 - distance + */ +public class CosineDistanceConverter implements DistanceConverter { + + @Override + public double toSimilarity(double distance) { + return 1.0 - distance; + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java new file mode 100644 index 00000000..39d34f75 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java @@ -0,0 +1,16 @@ +package dev.langchain4j.community.store.embedding.oceanbase.distance; + +/** + * Strategy interface for converting distance values to similarity scores. + * Different distance metrics require different conversion formulas. + */ +public interface DistanceConverter { + + /** + * Converts a distance value to a similarity score. + * + * @param distance The distance value to convert + * @return A similarity score between 0 and 1, where 1 means exact match + */ + double toSimilarity(double distance); +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java new file mode 100644 index 00000000..62490e76 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java @@ -0,0 +1,33 @@ +package dev.langchain4j.community.store.embedding.oceanbase.distance; + +/** + * Factory for creating appropriate distance converters based on the metric type. + * Implements the Factory Method pattern. + */ +public class DistanceConverterFactory { + + /** + * Returns the appropriate converter for the given distance metric. + * + * @param metric The distance metric name (e.g., "cosine", "manhattan", "euclidean") + * @return A DistanceConverter appropriate for the given metric + */ + public static DistanceConverter getConverter(String metric) { + if (metric == null) { + return new CosineDistanceConverter(); + } + + metric = metric.toLowerCase(); + + switch (metric) { + case "l1": + case "manhattan": + return new ManhattanDistanceConverter(); + case "l2": + case "euclidean": + return new EuclideanDistanceConverter(); + default: + return new CosineDistanceConverter(); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java new file mode 100644 index 00000000..75b821a3 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java @@ -0,0 +1,14 @@ +package dev.langchain4j.community.store.embedding.oceanbase.distance; + +/** + * Euclidean distance to similarity converter. + * Uses exponential decay: similarity = e^(-distance) + * This works well for most distance metrics where smaller distance means higher similarity. + */ +public class EuclideanDistanceConverter implements DistanceConverter { + + @Override + public double toSimilarity(double distance) { + return Math.exp(-distance); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/ManhattanDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/ManhattanDistanceConverter.java new file mode 100644 index 00000000..3d6d4ec8 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/ManhattanDistanceConverter.java @@ -0,0 +1,14 @@ +package dev.langchain4j.community.store.embedding.oceanbase.distance; + +/** + * Manhattan distance to similarity converter. + * Uses a simple inverse function: similarity = 1 / (1 + distance) + * This works well for Manhattan distance, which is the sum of absolute differences. + */ +public class ManhattanDistanceConverter implements DistanceConverter { + + @Override + public double toSimilarity(double distance) { + return 1.0 / (1.0 + distance); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java new file mode 100644 index 00000000..b864f016 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java @@ -0,0 +1,219 @@ +package dev.langchain4j.community.store.embedding.oceanbase.search; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.langchain4j.community.store.embedding.oceanbase.EmbeddingTable; +import dev.langchain4j.community.store.embedding.oceanbase.distance.DistanceConverter; +import dev.langchain4j.community.store.embedding.oceanbase.distance.DistanceConverterFactory; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilterFactory; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SqlQueryBuilder; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.filter.Filter; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Concrete implementation of SearchTemplate for OceanBase. + */ +public class OceanBaseSearchTemplate extends SearchTemplate { + + private final EmbeddingTable table; + private final boolean isExactSearch; + private final ObjectMapper objectMapper; + + /** + * Creates a new OceanBaseSearchTemplate. + * + * @param table The embedding table configuration + * @param isExactSearch Whether to use exact search + */ + public OceanBaseSearchTemplate(EmbeddingTable table, boolean isExactSearch) { + this.table = table; + this.isExactSearch = isExactSearch; + this.objectMapper = new ObjectMapper(); + } + + /** + * Converts a vector to a string representation. + * + * @param vector The vector to convert + * @return String representation of the vector in OceanBase format + */ + protected String vectorToString(float[] vector) { + if (vector == null || vector.length == 0) { + return "[]"; + } + + StringBuilder sb = new StringBuilder(); + sb.append("["); + for (int i = 0; i < vector.length; i++) { + if (i > 0) { + sb.append(", "); + } + sb.append(vector[i]); + } + sb.append("]"); + return sb.toString(); + } + + /** + * Converts a string representation of a vector back to a float array. + * + * @param vectorStr The string representation of the vector + * @return The vector as a float array + */ + protected float[] stringToVector(String vectorStr) { + if (vectorStr == null || vectorStr.isEmpty() || vectorStr.equals("[]")) { + return new float[0]; + } + + String trimmed = vectorStr.trim(); + if (trimmed.startsWith("[")) { + trimmed = trimmed.substring(1); + } + if (trimmed.endsWith("]")) { + trimmed = trimmed.substring(0, trimmed.length() - 1); + } + + String[] parts = trimmed.split(","); + float[] result = new float[parts.length]; + + for (int i = 0; i < parts.length; i++) { + result[i] = Float.parseFloat(parts[i].trim()); + } + + return result; + } + + /** + * Converts Metadata to a JSON string. + * + * @param metadata The metadata to convert + * @return JSON string representation of the metadata + */ + protected String metadataToJson(Metadata metadata) { + if (metadata == null || metadata.toMap().isEmpty()) { + return null; + } + + try { + return objectMapper.writeValueAsString(metadata.toMap()); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to convert metadata to JSON", e); + } + } + + /** + * Converts a JSON string to Metadata. + * + * @param json The JSON string to convert + * @return Metadata object + */ + @SuppressWarnings("unchecked") + protected Metadata jsonToMetadata(String json) { + if (json == null || json.isEmpty()) { + return Metadata.from(java.util.Collections.emptyMap()); + } + + try { + Map map = objectMapper.readValue(json, java.util.Map.class); + return Metadata.from(map); + } catch (JsonProcessingException e) { + throw new RuntimeException("Failed to convert JSON to metadata", e); + } + } + + @Override + protected String buildSearchQuery(EmbeddingSearchRequest request) { + // Use SqlQueryBuilder (Builder pattern) to construct the query + SqlQueryBuilder builder = SqlQueryBuilder.select() + .columns(table.idColumn(), table.embeddingColumn(), table.textColumn(), table.metadataColumn()) + .function(table.distanceMetric().toLowerCase() + "_distance", table.embeddingColumn(), "?") + .from(table.name()); + + // Add WHERE clause for filtering + Filter filter = request.filter(); + if (filter != null) { + String whereClause = buildWhereClause(filter); + if (whereClause != null && !whereClause.trim().isEmpty()) { + builder.where(whereClause); + } + } + + // Add ORDER BY clause + builder.orderByFunction(table.distanceMetric().toLowerCase() + "_distance", table.embeddingColumn(), "?"); + + // Use APPROXIMATE keyword if not exact search + if (!isExactSearch) { + builder.approximate(); + } + + // Add LIMIT clause + builder.limit(request.maxResults()); + + return builder.build(); + } + + @Override + protected void setParameters(PreparedStatement statement, EmbeddingSearchRequest request) throws SQLException { + String queryVector = vectorToString(request.queryEmbedding().vector()); + statement.setString(1, queryVector); // For distance calculation in SELECT + statement.setString(2, queryVector); // For distance calculation in ORDER BY + } + + @Override + protected List> processResults(ResultSet resultSet, EmbeddingSearchRequest request) + throws SQLException { + List> matches = new ArrayList<>(); + + while (resultSet.next()) { + String id = resultSet.getString(1); + String vectorStr = resultSet.getString(2); + String text = resultSet.getString(3); + String metadataJson = resultSet.getString(4); + double distance = resultSet.getDouble(5); + + // Convert distance to similarity score using appropriate converter + DistanceConverter converter = DistanceConverterFactory.getConverter(table.distanceMetric()); + double score = converter.toSimilarity(distance); + + // Apply minimum score filter if specified + if (score < request.minScore()) { + continue; + } + + Embedding embedding = new Embedding(stringToVector(vectorStr)); + Metadata metadata = jsonToMetadata(metadataJson); + TextSegment segment = (text != null) ? TextSegment.from(text, metadata) : null; + + matches.add(new EmbeddingMatch<>(score, id, embedding, segment)); + } + + return matches; + } + + @Override + protected String buildWhereClause(Filter filter) { + if (filter == null) { + return null; + } + + SQLFilter sqlFilter = SQLFilterFactory.create( + filter, (key, value) -> table.getMetadataKeyMapper().mapKey(key)); + + if (sqlFilter.matchesAllRows()) { + return null; + } + + return sqlFilter.toSql(); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java new file mode 100644 index 00000000..3e6ddb73 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java @@ -0,0 +1,111 @@ +package dev.langchain4j.community.store.embedding.oceanbase.search; + +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.filter.Filter; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import javax.sql.DataSource; + +/** + * Template for the search operation. + * Implements the Template Method pattern to define the skeleton of the search algorithm, + * deferring some steps to subclasses. + * + * @param The type of embedded objects + */ +public abstract class SearchTemplate { + + /** + * The main template method that defines the search algorithm. + * + * @param dataSource The data source to use + * @param request The search request + * @return The search result + */ + public final EmbeddingSearchResult search(DataSource dataSource, EmbeddingSearchRequest request) { + validateRequest(request); + + if (request.maxResults() <= 0) { + return new EmbeddingSearchResult<>(new ArrayList<>()); + } + + String query = buildSearchQuery(request); + + try (Connection connection = dataSource.getConnection(); + PreparedStatement statement = connection.prepareStatement(query)) { + + setParameters(statement, request); + + try (ResultSet resultSet = statement.executeQuery()) { + List> matches = processResults(resultSet, request); + return new EmbeddingSearchResult<>(matches); + } + } catch (SQLException e) { + throw new RuntimeException("Error executing search query: " + e.getMessage(), e); + } + } + + /** + * Validates the search request. + * This is a hook that may be overridden by subclasses. + * + * @param request The search request to validate + */ + protected void validateRequest(EmbeddingSearchRequest request) { + if (request == null) { + throw new IllegalArgumentException("Search request cannot be null"); + } + if (request.queryEmbedding() == null) { + throw new IllegalArgumentException("Query embedding cannot be null"); + } + } + + /** + * Builds the search query. + * This is an abstract method that must be implemented by subclasses. + * + * @param request The search request + * @return The SQL query string + */ + protected abstract String buildSearchQuery(EmbeddingSearchRequest request); + + /** + * Sets the parameters for the prepared statement. + * This is an abstract method that must be implemented by subclasses. + * + * @param statement The prepared statement + * @param request The search request + * @throws SQLException If an error occurs setting the parameters + */ + protected abstract void setParameters(PreparedStatement statement, EmbeddingSearchRequest request) + throws SQLException; + + /** + * Processes the result set and creates embedding matches. + * This is an abstract method that must be implemented by subclasses. + * + * @param resultSet The result set from the query + * @param request The search request + * @return A list of embedding matches + * @throws SQLException If an error occurs processing the result set + */ + protected abstract List> processResults(ResultSet resultSet, EmbeddingSearchRequest request) + throws SQLException; + + /** + * Converts a filter to a WHERE clause. + * This is a hook that may be overridden by subclasses. + * + * @param filter The filter to convert + * @return The WHERE clause, or null if no filter is applied + */ + protected String buildWhereClause(Filter filter) { + return null; + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/FilterVisitor.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/FilterVisitor.java new file mode 100644 index 00000000..701f6be3 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/FilterVisitor.java @@ -0,0 +1,109 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotIn; +import dev.langchain4j.store.embedding.filter.logical.And; +import dev.langchain4j.store.embedding.filter.logical.Not; +import dev.langchain4j.store.embedding.filter.logical.Or; + +/** + * Visitor interface for Filter objects. + * Implements the Visitor pattern to handle different filter types + * without using instanceof checks and type casting. + */ +public interface FilterVisitor { + + /** + * Visit an IsEqualTo filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsEqualTo filter); + + /** + * Visit an IsNotEqualTo filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsNotEqualTo filter); + + /** + * Visit an IsGreaterThan filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsGreaterThan filter); + + /** + * Visit an IsGreaterThanOrEqualTo filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsGreaterThanOrEqualTo filter); + + /** + * Visit an IsLessThan filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsLessThan filter); + + /** + * Visit an IsLessThanOrEqualTo filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsLessThanOrEqualTo filter); + + /** + * Visit an IsIn filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsIn filter); + + /** + * Visit an IsNotIn filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(IsNotIn filter); + + /** + * Visit an And filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(And filter); + + /** + * Visit an Or filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(Or filter); + + /** + * Visit a Not filter. + * + * @param filter The filter to visit + * @return The resulting SQL filter + */ + SQLFilter visit(Not filter); +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilter.java new file mode 100644 index 00000000..5451119c --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilter.java @@ -0,0 +1,37 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +import dev.langchain4j.store.embedding.filter.Filter; + +/** + * Transforms a metadata {@link Filter} into a SQL expression that can be used in a WHERE clause. + */ +public interface SQLFilter { + + /** + * Returns a SQL expression that can be used in a WHERE clause to filter results based on metadata values. + * The returned SQL expression should evaluate to true when a row meets the criteria of the metadata filter, + * and false otherwise. + * + * @return A SQL expression, or null if this filter is equivalent to "always include" all rows. + */ + String toSql(); + + /** + * Returns true if the SQL filter is guaranteed to match all rows, or false otherwise. + * When a SQL filter matches all rows (e.g., as if there were no filter), the result of + * {@link #toSql()} may be null or an expression which always evaluates to true, such as "1=1". + *

+ * When {@link #toSql()} is null, the SQL filter MUST match all rows, and this method MUST return true. + * + * @return True if this SQL filter matches all rows, or false if it might exclude some rows. + */ + boolean matchesAllRows(); + + /** + * Returns true if this SQL filter is guaranteed to match no rows. When a SQL filter matches no rows, + * the result of {@link #toSql()} should be an expression which always evaluates to false, such as "1=0". + * + * @return True if this SQL filter matches no rows, or false if it might include some rows. + */ + boolean matchesNoRows(); +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java new file mode 100644 index 00000000..a9fbfb34 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java @@ -0,0 +1,27 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +import dev.langchain4j.store.embedding.filter.Filter; +import java.util.function.BiFunction; + +/** + * Factory for creating SQL filters. + * Implements the Factory Method pattern. + */ +public class SQLFilterFactory { + + /** + * Creates a SQLFilter from a Filter and key mapper. + * + * @param filter The filter to convert + * @param keyMapper Function that maps a key and value to a SQL column expression + * @return A SQLFilter representing the given filter + */ + public static SQLFilter create(Filter filter, BiFunction keyMapper) { + if (filter == null) { + return SQLFilters.matchAllRows(); + } + + SQLFilterVisitor visitor = new SQLFilterVisitor(keyMapper); + return visitor.process(filter); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java new file mode 100644 index 00000000..31a628c5 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java @@ -0,0 +1,245 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchAllSQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchNoSQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.SimpleSQLFilter; +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotIn; +import dev.langchain4j.store.embedding.filter.logical.And; +import dev.langchain4j.store.embedding.filter.logical.Not; +import dev.langchain4j.store.embedding.filter.logical.Or; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.function.BiFunction; + +/** + * Concrete visitor implementation that converts Filter objects to SQLFilter objects. + * Implements the Visitor pattern. + */ +public class SQLFilterVisitor implements FilterVisitor { + private final BiFunction keyMapper; + + /** + * Creates a new SQLFilterVisitor with the given key mapper. + * + * @param keyMapper Function that maps a key and value to a SQL column expression + */ + public SQLFilterVisitor(BiFunction keyMapper) { + this.keyMapper = keyMapper; + } + + /** + * Helper method to convert a value to its SQL string representation. + * + * @param value The value to convert + * @return SQL string representation of the value + */ + private String valueToSql(Object value) { + if (value == null) { + return "NULL"; + } else if (value instanceof String) { + return "'" + ((String) value).replace("'", "''") + "'"; + } else if (value instanceof Number || value instanceof Boolean) { + return value.toString(); + } else { + return "'" + value.toString().replace("'", "''") + "'"; + } + } + + @Override + public SQLFilter visit(IsEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + Object value = filter.comparisonValue(); + + if (value == null) { + return new SimpleSQLFilter(expr + " IS NULL"); + } else { + return new SimpleSQLFilter(expr + " = " + valueToSql(value)); + } + } + + @Override + public SQLFilter visit(IsNotEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + Object value = filter.comparisonValue(); + + if (value == null) { + return new SimpleSQLFilter(expr + " IS NOT NULL"); + } else { + return new SimpleSQLFilter("(" + expr + " <> " + valueToSql(value) + " OR " + expr + " IS NULL)"); + } + } + + @Override + public SQLFilter visit(IsGreaterThan filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " > " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsGreaterThanOrEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " >= " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsLessThan filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " < " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsLessThanOrEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " <= " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsIn filter) { + Collection values = filter.comparisonValues(); + if (values == null || values.isEmpty()) { + return new MatchNoSQLFilter(); + } + + List valueStrings = new ArrayList<>(values.size()); + for (Object value : values) { + valueStrings.add(valueToSql(value)); + } + + String expr = keyMapper.apply(filter.key(), values.iterator().next()); + return new SimpleSQLFilter(expr + " IN (" + String.join(", ", valueStrings) + ")"); + } + + @Override + public SQLFilter visit(IsNotIn filter) { + Collection values = filter.comparisonValues(); + if (values == null || values.isEmpty()) { + return new MatchAllSQLFilter(); + } + + List valueStrings = new ArrayList<>(values.size()); + for (Object value : values) { + valueStrings.add(valueToSql(value)); + } + + String expr = keyMapper.apply(filter.key(), values.iterator().next()); + return new SimpleSQLFilter( + "(" + expr + " NOT IN (" + String.join(", ", valueStrings) + ") OR " + expr + " IS NULL)"); + } + + @Override + public SQLFilter visit(And filter) { + SQLFilter leftFilter = process(filter.left()); + + // If left side matches no rows, the entire AND also matches no rows + if (leftFilter.matchesNoRows()) { + return new MatchNoSQLFilter(); + } + + // If left side matches all rows, the result depends on the right side + if (leftFilter.matchesAllRows()) { + return process(filter.right()); + } + + SQLFilter rightFilter = process(filter.right()); + + // If right side matches no rows, the entire AND also matches no rows + if (rightFilter.matchesNoRows()) { + return new MatchNoSQLFilter(); + } + + // If right side matches all rows, the result depends on the left side + if (rightFilter.matchesAllRows()) { + return leftFilter; + } + + // Both sides need to participate in the AND operation + return new SimpleSQLFilter("(" + leftFilter.toSql() + ") AND (" + rightFilter.toSql() + ")"); + } + + @Override + public SQLFilter visit(Or filter) { + SQLFilter leftFilter = process(filter.left()); + + // If left side matches all rows, the entire OR also matches all rows + if (leftFilter.matchesAllRows()) { + return new MatchAllSQLFilter(); + } + + // If left side matches no rows, the result depends on the right side + if (leftFilter.matchesNoRows()) { + return process(filter.right()); + } + + SQLFilter rightFilter = process(filter.right()); + + // If right side matches all rows, the entire OR also matches all rows + if (rightFilter.matchesAllRows()) { + return new MatchAllSQLFilter(); + } + + // If right side matches no rows, the result depends on the left side + if (rightFilter.matchesNoRows()) { + return leftFilter; + } + + // Both sides need to participate in the OR operation + return new SimpleSQLFilter("(" + leftFilter.toSql() + ") OR (" + rightFilter.toSql() + ")"); + } + + @Override + public SQLFilter visit(Not filter) { + SQLFilter expressionFilter = process(filter.expression()); + + if (expressionFilter.matchesAllRows()) { + return new MatchNoSQLFilter(); + } else if (expressionFilter.matchesNoRows()) { + return new MatchAllSQLFilter(); + } else { + return new SimpleSQLFilter("NOT (" + expressionFilter.toSql() + ")"); + } + } + + /** + * Process any Filter object by making it accept this visitor. + * + * @param filter The filter to process + * @return The resulting SQLFilter + */ + public SQLFilter process(Filter filter) { + if (filter instanceof IsEqualTo) { + return visit((IsEqualTo) filter); + } else if (filter instanceof IsNotEqualTo) { + return visit((IsNotEqualTo) filter); + } else if (filter instanceof IsGreaterThan) { + return visit((IsGreaterThan) filter); + } else if (filter instanceof IsGreaterThanOrEqualTo) { + return visit((IsGreaterThanOrEqualTo) filter); + } else if (filter instanceof IsLessThan) { + return visit((IsLessThan) filter); + } else if (filter instanceof IsLessThanOrEqualTo) { + return visit((IsLessThanOrEqualTo) filter); + } else if (filter instanceof IsIn) { + return visit((IsIn) filter); + } else if (filter instanceof IsNotIn) { + return visit((IsNotIn) filter); + } else if (filter instanceof And) { + return visit((And) filter); + } else if (filter instanceof Or) { + return visit((Or) filter); + } else if (filter instanceof Not) { + return visit((Not) filter); + } else { + throw new IllegalArgumentException( + "Unsupported filter type: " + filter.getClass().getName()); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java new file mode 100644 index 00000000..1e2cafd3 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java @@ -0,0 +1,106 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +/** + * Factory class for creating SQL filter instances. + */ +public class SQLFilters { + + /** + * Creates a SQLFilter that matches all rows. + * + * @return A SQLFilter that matches all rows + */ + public static SQLFilter matchAllRows() { + return new MatchAllSQLFilter(); + } + + /** + * Creates a SQLFilter that matches no rows. + * + * @return A SQLFilter that matches no rows + */ + public static SQLFilter matchNoRows() { + return new MatchNoSQLFilter(); + } + + /** + * Creates a SQLFilter with a simple SQL expression. + * + * @param sql The SQL expression + * @return A SQLFilter with the given SQL expression + */ + public static SQLFilter simple(String sql) { + return new SimpleSQLFilter(sql); + } + + /** + * Implementation of SQLFilter that matches all rows. + */ + public static class MatchAllSQLFilter implements SQLFilter { + @Override + public String toSql() { + return "1=1"; + } + + @Override + public boolean matchesAllRows() { + return true; + } + + @Override + public boolean matchesNoRows() { + return false; + } + } + + /** + * Implementation of SQLFilter that matches no rows. + */ + public static class MatchNoSQLFilter implements SQLFilter { + @Override + public String toSql() { + return "1=0"; + } + + @Override + public boolean matchesAllRows() { + return false; + } + + @Override + public boolean matchesNoRows() { + return true; + } + } + + /** + * Implementation of SQLFilter with a simple SQL expression. + */ + public static class SimpleSQLFilter implements SQLFilter { + private final String sql; + + /** + * Creates a new SimpleSQLFilter with the given SQL expression. + * + * @param sql The SQL expression + */ + public SimpleSQLFilter(String sql) { + this.sql = sql; + } + + @Override + public String toSql() { + return sql; + } + + @Override + public boolean matchesAllRows() { + return "1=1".equals(sql); + } + + @Override + public boolean matchesNoRows() { + return "1=0".equals(sql); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SqlQueryBuilder.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SqlQueryBuilder.java new file mode 100644 index 00000000..140fb2b1 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SqlQueryBuilder.java @@ -0,0 +1,173 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +import java.util.ArrayList; +import java.util.List; + +/** + * Builder for SQL queries. + * Implements the Builder pattern to create SQL queries in a fluent, readable way. + */ +public class SqlQueryBuilder { + private StringBuilder queryBuilder = new StringBuilder(); + private List selectColumns = new ArrayList<>(); + private String fromTable; + private String whereClause; + private List orderByColumns = new ArrayList<>(); + private boolean approximate = false; + private Integer limit; + + /** + * Start building a SELECT query. + * + * @return this builder + */ + public static SqlQueryBuilder select() { + return new SqlQueryBuilder(); + } + + /** + * Add columns to the SELECT clause. + * + * @param columns The columns to select + * @return this builder + */ + public SqlQueryBuilder columns(String... columns) { + for (String column : columns) { + selectColumns.add(column); + } + return this; + } + + /** + * Add a function-based column to the SELECT clause. + * + * @param functionName The name of the function + * @param args The function arguments + * @return this builder + */ + public SqlQueryBuilder function(String functionName, String... args) { + String column = functionName + "(" + String.join(", ", args) + ")"; + selectColumns.add(column); + return this; + } + + /** + * Add an alias to the last added column. + * + * @param alias The alias name + * @return this builder + */ + public SqlQueryBuilder as(String alias) { + if (!selectColumns.isEmpty()) { + int lastIndex = selectColumns.size() - 1; + String column = selectColumns.get(lastIndex) + " AS " + alias; + selectColumns.set(lastIndex, column); + } + return this; + } + + /** + * Set the FROM clause. + * + * @param table The table name + * @return this builder + */ + public SqlQueryBuilder from(String table) { + this.fromTable = table; + return this; + } + + /** + * Set the WHERE clause. + * + * @param condition The WHERE condition + * @return this builder + */ + public SqlQueryBuilder where(String condition) { + this.whereClause = condition; + return this; + } + + /** + * Add an ORDER BY clause. + * + * @param columns The columns to order by + * @return this builder + */ + public SqlQueryBuilder orderBy(String... columns) { + for (String column : columns) { + orderByColumns.add(column); + } + return this; + } + + /** + * Add an ORDER BY clause with a function. + * + * @param functionName The name of the function + * @param args The function arguments + * @return this builder + */ + public SqlQueryBuilder orderByFunction(String functionName, String... args) { + String column = functionName + "(" + String.join(", ", args) + ")"; + orderByColumns.add(column); + return this; + } + + /** + * Use approximate search (for vector search in OceanBase). + * + * @return this builder + */ + public SqlQueryBuilder approximate() { + this.approximate = true; + return this; + } + + /** + * Set the LIMIT clause. + * + * @param limit The maximum number of rows to return + * @return this builder + */ + public SqlQueryBuilder limit(int limit) { + this.limit = limit; + return this; + } + + /** + * Build the SQL query string. + * + * @return The complete SQL query + */ + public String build() { + if (selectColumns.isEmpty()) { + throw new IllegalStateException("No columns specified for SELECT"); + } + if (fromTable == null) { + throw new IllegalStateException("No table specified for FROM"); + } + + queryBuilder.append("SELECT ").append(String.join(", ", selectColumns)); + queryBuilder.append(" FROM ").append(fromTable); + + if (whereClause != null && !whereClause.isEmpty()) { + queryBuilder.append(" WHERE ").append(whereClause); + } + + if (!orderByColumns.isEmpty()) { + queryBuilder.append(" ORDER BY ").append(String.join(", ", orderByColumns)); + + // Add APPROXIMATE keyword if needed + if (approximate) { + queryBuilder.append(" APPROXIMATE"); + } + } + + if (limit != null) { + queryBuilder.append(" LIMIT ").append(limit); + } + + return queryBuilder.toString(); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java new file mode 100644 index 00000000..48107d83 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java @@ -0,0 +1,187 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import dev.langchain4j.model.embedding.EmbeddingModel; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.UUID; +import javax.sql.DataSource; + +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +/** + * Common operations for tests. + */ +public class CommonTestOperations { + + private static final Logger log = LoggerFactory.getLogger(CommonTestOperations.class); + + private static final String TABLE_NAME = + "langchain4j_oceanbase_test_" + UUID.randomUUID().toString().replace("-", ""); + + // OceanBase container configuration + private static final int OCEANBASE_PORT = 2881; // Default OceanBase port + private static GenericContainer oceanBaseContainer; + private static DataSource dataSource; + private static EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + static { + startOceanBaseContainer(); + } + + /** + * Starts the OceanBase container for tests. + */ + private static void startOceanBaseContainer() { + try { + // Using OceanBase's official Docker image + oceanBaseContainer = new GenericContainer<>("oceanbase/oceanbase-ce:4.3.5-lts") + .withExposedPorts(OCEANBASE_PORT) + .withEnv("MODE", "standalone") // For single-node deployment + // Wait for boot success message + .waitingFor(Wait.forLogMessage(".*boot success!.*", 1)); + + // Start the container + oceanBaseContainer.start(); + + // Get the mapped port and host + String jdbcUrl = String.format( + "jdbc:oceanbase://%s:%d/test", + oceanBaseContainer.getHost(), oceanBaseContainer.getMappedPort(OCEANBASE_PORT)); + + log.info("OceanBase container started at {}", jdbcUrl); + + // Create a data source with the container's connection info + dataSource = new SimpleDataSource(jdbcUrl, "root@test", ""); + + // Create test database and setup + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + // Set memory limit for vector columns + stmt.execute("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30;"); + // Create a test database if needed + stmt.execute("CREATE DATABASE IF NOT EXISTS test_example"); + stmt.execute("USE test_example"); + // Add any other initialization SQL you need + } catch (SQLException e) { + log.error("Failed to initialize OceanBase container", e); + throw new RuntimeException("Failed to initialize OceanBase container", e); + } + } catch (Exception e) { + log.error("Failed to start OceanBase container", e); + throw new RuntimeException("Failed to start OceanBase container", e); + } + } + + /** + * Creates a new embedding store for testing. + * + * @return A new embedding store. + */ + public static OceanBaseEmbeddingStore newEmbeddingStore() { + return OceanBaseEmbeddingStore.builder(getDataSource()) + .embeddingTable(EmbeddingTable.builder(TABLE_NAME) + .vectorDimension(embeddingModel.dimension()) + .createOption(CreateOption.CREATE_OR_REPLACE) + .distanceMetric("cosine") + .build()) + .build(); + } + + /** + * Returns the data source for connecting to OceanBase. + * + * @return The data source. + */ + public static DataSource getDataSource() { + if (dataSource == null) { + throw new IllegalStateException( + "DataSource is not initialized. Container might not have started properly."); + } + return dataSource; + } + + /** + * Simple DataSource implementation for tests. + */ + private static class SimpleDataSource implements DataSource { + private final String url; + private final String user; + private final String password; + + public SimpleDataSource(String url, String user, String password) { + this.url = url; + this.user = user; + this.password = password; + } + + @Override + public Connection getConnection() throws SQLException { + return DriverManager.getConnection(url, user, password); + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + return DriverManager.getConnection(url, username, password); + } + + @Override + public java.io.PrintWriter getLogWriter() throws SQLException { + return null; + } + + @Override + public void setLogWriter(java.io.PrintWriter out) throws SQLException {} + + @Override + public void setLoginTimeout(int seconds) throws SQLException {} + + @Override + public int getLoginTimeout() throws SQLException { + return 0; + } + + @Override + public java.util.logging.Logger getParentLogger() { + return null; + } + + @Override + public T unwrap(Class iface) throws SQLException { + throw new SQLException("Unwrapping not supported"); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return false; + } + } + + /** + * Drops the test table. + * + * @throws SQLException If an error occurs. + */ + public static void dropTable() throws SQLException { + try (Connection connection = getDataSource().getConnection(); + Statement statement = connection.createStatement()) { + statement.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + log.info("Dropped table {}", TABLE_NAME); + } + } + + /** + * Stops the OceanBase container. + */ + public static void stopContainer() { + if (oceanBaseContainer != null && oceanBaseContainer.isRunning()) { + oceanBaseContainer.stop(); + log.info("OceanBase container stopped"); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java new file mode 100644 index 00000000..b17bb048 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java @@ -0,0 +1,54 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; +import java.sql.SQLException; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; + +/** + * Integration tests for {@link OceanBaseEmbeddingStore}. + */ +public class OceanBaseEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { + + private OceanBaseEmbeddingStore embeddingStore = CommonTestOperations.newEmbeddingStore(); + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @AfterAll + static void cleanUp() throws SQLException { + try { + CommonTestOperations.dropTable(); + } finally { + CommonTestOperations.stopContainer(); + } + } + + @AfterEach + void tearDown() { + clearStore(); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } + + @Override + protected void clearStore() { + embeddingStore.removeAll(); + } + + @Override + protected boolean supportsContains() { + return true; + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseWithRemovalIT.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseWithRemovalIT.java new file mode 100644 index 00000000..9772d79d --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseWithRemovalIT.java @@ -0,0 +1,51 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.sql.SQLException; + +/** + * Tests for {@link OceanBaseEmbeddingStore} with removal. + */ +public class OceanBaseWithRemovalIT extends EmbeddingStoreWithRemovalIT { + + private OceanBaseEmbeddingStore embeddingStore; + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeEach + void setUp() { + embeddingStore = CommonTestOperations.newEmbeddingStore(); + } + + @AfterAll + static void cleanUp() throws SQLException { + try { + CommonTestOperations.dropTable(); + } finally { + CommonTestOperations.stopContainer(); + } + } + + @AfterEach + void tearDown() { + embeddingStore.removeAll(); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } +} diff --git a/langchain4j-community-bom/pom.xml b/langchain4j-community-bom/pom.xml index 6dc580ba..8daad7eb 100644 --- a/langchain4j-community-bom/pom.xml +++ b/langchain4j-community-bom/pom.xml @@ -109,6 +109,12 @@ ${project.version} + + dev.langchain4j + langchain4j-community-oceanbase + ${project.version} + + dev.langchain4j diff --git a/pom.xml b/pom.xml index 3dfb5423..b4ade995 100644 --- a/pom.xml +++ b/pom.xml @@ -49,6 +49,7 @@ embedding-stores/langchain4j-community-alloydb-pg embedding-stores/langchain4j-community-cloud-sql-pg embedding-stores/langchain4j-community-neo4j + embedding-stores/langchain4j-community-oceanbase content-retrievers/langchain4j-community-lucene