From 640641ba37db0a82d14c01e82f5b71241f792822 Mon Sep 17 00:00:00 2001 From: SimonChou <2484152300@qq.com> Date: Tue, 24 Jun 2025 02:00:35 +0800 Subject: [PATCH 1/6] [FEATURE] Support OceanBase Embedding Store --- .../langchain4j-community-oceanbase/README.md | 208 ++++++++ .../langchain4j-community-oceanbase/pom.xml | 82 +++ .../embedding/oceanbase/CreateOption.java | 21 + .../embedding/oceanbase/EmbeddingTable.java | 375 +++++++++++++ .../oceanbase/MetadataKeyMapper.java | 21 + .../oceanbase/OceanBaseEmbeddingStore.java | 503 ++++++++++++++++++ .../distance/CosineDistanceConverter.java | 13 + .../distance/DefaultDistanceConverter.java | 14 + .../oceanbase/distance/DistanceConverter.java | 16 + .../distance/DistanceConverterFactory.java | 31 ++ .../distance/EuclideanDistanceConverter.java | 13 + .../search/OceanBaseSearchTemplate.java | 232 ++++++++ .../oceanbase/search/SearchTemplate.java | 112 ++++ .../oceanbase/sql/FilterVisitor.java | 109 ++++ .../embedding/oceanbase/sql/SQLFilter.java | 37 ++ .../oceanbase/sql/SQLFilterFactory.java | 28 + .../oceanbase/sql/SQLFilterVisitor.java | 244 +++++++++ .../embedding/oceanbase/sql/SQLFilters.java | 107 ++++ .../oceanbase/sql/SqlQueryBuilder.java | 173 ++++++ .../oceanbase/CommonTestOperations.java | 189 +++++++ .../oceanbase/OceanBaseEmbeddingStoreIT.java | 242 +++++++++ .../store/embedding/oceanbase/TestData.java | 81 +++ langchain4j-community-bom/pom.xml | 6 + pom.xml | 1 + 24 files changed, 2858 insertions(+) create mode 100644 embedding-stores/langchain4j-community-oceanbase/README.md create mode 100644 embedding-stores/langchain4j-community-oceanbase/pom.xml create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/CreateOption.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java create mode 100644 
embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/FilterVisitor.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilter.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java create mode 100644 
embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SqlQueryBuilder.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java diff --git a/embedding-stores/langchain4j-community-oceanbase/README.md b/embedding-stores/langchain4j-community-oceanbase/README.md new file mode 100644 index 00000000..e53aa9ea --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/README.md @@ -0,0 +1,208 @@ +# OceanBase Vector Store for LangChain4j + +This module implements an `EmbeddingStore` backed by an OceanBase database. + +- [Product Documentation](https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002826816) + +The **OceanBase for LangChain4j** package provides a first-class experience for connecting to OceanBase instances from the LangChain4j ecosystem while providing the following benefits: + +- **Simplified Vector Storage**: Utilize OceanBase's vector data types and indexing capabilities for efficient similarity searches. +- **Improved Metadata Handling**: Store metadata in JSON columns instead of strings, resulting in significant performance improvements. 
+- **Clear Separation**: Clearly separate table and extension creation, allowing for distinct permissions and streamlined workflows. +- **Better Integration with OceanBase**: Built-in methods to take advantage of OceanBase's advanced indexing and scalability capabilities. + +## Quick Start + +In order to use this library, you first need to go through the following steps: + +1. [Install OceanBase Database](https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012734) +2. Create a database and user +3. Configure vector memory limits (if needed) + +### Maven Dependency + +```xml +<dependency> + <groupId>dev.langchain4j</groupId> + <artifactId>langchain4j-community-oceanbase</artifactId> + <version>1.2.0-beta8-SNAPSHOT</version> +</dependency> +``` + +### Supported Java Versions + +Java >= 17 + +## OceanBaseEmbeddingStore Usage + +`OceanBaseEmbeddingStore` stores embeddings together with their text and metadata, and performs vector search. Instances can be created by configuring the provided `Builder`, which requires: + +- A `DataSource` instance (connected to an OceanBase database) +- Table name +- Table configuration (optional, uses standard configuration by default) +- Exact search option (optional, uses approximate search by default) + +Example usage: + +```java +import dev.langchain4j.community.store.embedding.oceanbase.OceanBaseEmbeddingStore; +import dev.langchain4j.community.store.embedding.oceanbase.EmbeddingTable; +import dev.langchain4j.community.store.embedding.oceanbase.CreateOption; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; + +import javax.sql.DataSource; +import java.util.*; + +// Create a data source +DataSource dataSource = createDataSource(); // You need to implement this method + +// Create a vector store +OceanBaseEmbeddingStore store = OceanBaseEmbeddingStore.builder(dataSource) + .embeddingTable( 
+ EmbeddingTable.builder("my_embeddings") + .vectorDimension(384) // Set vector dimension + .createOption(CreateOption.CREATE_IF_NOT_EXISTS) + .build()) + .build(); + +// Add embeddings +List<String> testTexts = Arrays.asList("apple", "banana", "car", "truck"); +List<Embedding> embeddings = new ArrayList<>(); +List<TextSegment> textSegments = new ArrayList<>(); + +for (String text : testTexts) { + Map<String, Object> metaMap = new HashMap<>(); + metaMap.put("category", Set.of("apple", "banana").contains(text) ? "fruit" : "vehicle"); + Metadata metadata = Metadata.from(metaMap); + textSegments.add(TextSegment.from(text, metadata)); + embeddings.add(myEmbeddingModel.embed(text).content()); // Use your embedding model +} + +// Batch add embeddings and text segments +List<String> ids = store.addAll(embeddings, textSegments); + +// Search for similar vectors +EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddings.get(0)) // Search for content similar to "apple" + .maxResults(10) + .minScore(0.7) + .build(); + +List<EmbeddingMatch<TextSegment>> results = store.search(request).matches(); + +// Use metadata filtering +import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; + +// Search only in the "fruit" category +EmbeddingSearchRequest filteredRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddings.get(0)) + .maxResults(10) + .filter(MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit")) + .build(); + +List<EmbeddingMatch<TextSegment>> filteredResults = store.search(filteredRequest).matches(); + +// Remove embeddings +store.remove(ids.get(0)); // Remove a single vector +store.removeAll(Arrays.asList(ids.get(1), ids.get(2))); // Remove multiple vectors +store.removeAll(MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit")); // Remove by metadata +store.removeAll(); // Remove all vectors +``` + +## EmbeddingTable Configuration + +The `EmbeddingTable` class is used to configure the structure and creation options of the vector table: + +```java +import dev.langchain4j.community.store.embedding.oceanbase.EmbeddingTable; 
+import dev.langchain4j.community.store.embedding.oceanbase.CreateOption; + +// Basic configuration +EmbeddingTable table = EmbeddingTable.builder("my_embeddings") + .vectorDimension(384) // Set vector dimension + .createOption(CreateOption.CREATE_IF_NOT_EXISTS) + .build(); + +// Advanced configuration +EmbeddingTable advancedTable = EmbeddingTable.builder("advanced_embeddings") + .idColumn("custom_id") // Custom ID column name + .embeddingColumn("vector_data") // Custom vector column name + .textColumn("content") // Custom text column name + .metadataColumn("meta_info") // Custom metadata column name + .vectorDimension(768) // Set vector dimension + .vectorIndexName("idx_vector_search") // Custom index name + .distanceMetric("L2") // Set distance metric (L2, IP, COSINE) + .indexType("hnsw") // Set index type (hnsw, flat) + .createOption(CreateOption.CREATE_OR_REPLACE) // Table creation option + .build(); +``` + +### Table Creation Options + +The `CreateOption` enum provides the following options: + +- `CREATE_NONE`: Do not create a table, assumes the table already exists +- `CREATE_IF_NOT_EXISTS`: Create the table if it does not exist (default) +- `CREATE_OR_REPLACE`: Create the table, replacing it if it exists + +## Search Options + +`OceanBaseEmbeddingStore` supports two search modes: + +- **Approximate Search** (default): Faster but may not be 100% accurate +- **Exact Search**: Slower but 100% accurate + +```java +// Use exact search +OceanBaseEmbeddingStore exactStore = OceanBaseEmbeddingStore.builder(dataSource) + .embeddingTable("my_embeddings") + .exactSearch(true) // Enable exact search + .build(); +``` + +## Metadata Filtering + +Complex metadata filtering conditions can be built using `MetadataFilterBuilder`: + +```java +import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; + +// Basic filtering +EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() + .queryEmbedding(queryEmbedding) + 
.filter(MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit")) + .build(); + +// Combined filtering +EmbeddingSearchRequest complexRequest = EmbeddingSearchRequest.builder() + .queryEmbedding(queryEmbedding) + .filter( + MetadataFilterBuilder.metadataKey("category").isEqualTo("fruit") + .and(MetadataFilterBuilder.metadataKey("color").isEqualTo("red")) + ) + .build(); +``` + +## Performance Optimization + +For optimal performance, consider the following recommendations: + +1. Adjust OceanBase's vector memory limit appropriately for large vector collections: + ```sql + ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30; + ``` + +2. Choose the index type that suits your use case: + - `hnsw`: Suitable for most scenarios, provides a good balance of performance/accuracy + - `flat`: Use when highest accuracy is required + +3. Add embeddings in batches to improve performance: + ```java + store.addAll(embeddings, textSegments); + ``` + +4. For frequent similarity searches, consider using approximate search mode (default). 
diff --git a/embedding-stores/langchain4j-community-oceanbase/pom.xml b/embedding-stores/langchain4j-community-oceanbase/pom.xml new file mode 100644 index 00000000..20a7afb4 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/pom.xml @@ -0,0 +1,82 @@ + + + 4.0.0 + + dev.langchain4j + langchain4j-community + 1.2.0-beta8-SNAPSHOT + ../../pom.xml + + + langchain4j-community-oceanbase + LangChain4j :: Community :: Integration :: OceanBase + Community Integration with OceanBase Vector Storage + + + + dev.langchain4j + langchain4j-core + ${langchain4j.core.version} + + + + + com.oceanbase + oceanbase-client + 2.4.5 + + + + org.slf4j + slf4j-api + + + + + com.fasterxml.jackson.core + jackson-databind + + + + + dev.langchain4j + langchain4j-core + 1.1.0-SNAPSHOT + tests + test-jar + test + + + + org.junit.jupiter + junit-jupiter + test + + + + org.assertj + assertj-core + test + + + + org.testcontainers + testcontainers + test + + + + org.testcontainers + junit-jupiter + test + + + + ch.qos.logback + logback-classic + test + + + diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/CreateOption.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/CreateOption.java new file mode 100644 index 00000000..9cb2aa58 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/CreateOption.java @@ -0,0 +1,21 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +/** + * Option for creating a table. + */ +public enum CreateOption { + /** + * Do not create a table. This assumes that the table already exists. + */ + CREATE_NONE, + + /** + * Create the table if it does not exist. + */ + CREATE_IF_NOT_EXISTS, + + /** + * Create the table, replacing an existing table if it exists. 
+ */ + CREATE_OR_REPLACE +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java new file mode 100644 index 00000000..d54dede2 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java @@ -0,0 +1,375 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; +import javax.sql.DataSource; + +/** + * Represents a database table where embeddings, text, and metadata are stored. The columns of this table are the id, embedding, text, and metadata columns. + */ +public final class EmbeddingTable { + + /** + * Option which configures how the {@link #create(DataSource)} method creates this table + */ + private final CreateOption createOption; + + /** + * The name of this table + */ + private final String name; + + /** + * Name of a column which stores an id. + */ + private final String idColumn; + + /** + * Name of a column which stores an embedding. + */ + private final String embeddingColumn; + + /** + * Name of a column which stores text. + */ + private final String textColumn; + + /** + * Name of a column which stores metadata. 
+ */ + private final String metadataColumn; + + /** + * The vector dimension + */ + private final int vectorDimension; + + /** + * The vector index name + */ + private final String vectorIndexName; + + /** + * The distance metric for vector similarity search + */ + private final String distanceMetric; + + /** + * The index type (e.g., hnsw, flat) + */ + private final String indexType; + + private EmbeddingTable(Builder builder) { + createOption = builder.createOption; + name = builder.name; + idColumn = builder.idColumn; + embeddingColumn = builder.embeddingColumn; + textColumn = builder.textColumn; + metadataColumn = builder.metadataColumn; + vectorDimension = builder.vectorDimension; + vectorIndexName = builder.vectorIndexName; + distanceMetric = builder.distanceMetric; + indexType = builder.indexType; + } + + /** + * Creates a table configured by the {@link Builder} of this EmbeddingTable. No table is created if the Builder was + * configured with {@link CreateOption#CREATE_NONE}. + * + * @param dataSource Data source that connects to an OceanBase Database where the table is (possibly) created. + * @throws SQLException If an error prevents the table from being created. 
+ */ + void create(DataSource dataSource) throws SQLException { + if (createOption == CreateOption.CREATE_NONE) return; + + try (Connection connection = dataSource.getConnection(); + Statement statement = connection.createStatement()) { + if (createOption == CreateOption.CREATE_OR_REPLACE) { + statement.addBatch("DROP TABLE IF EXISTS " + name); + } + + StringBuilder createTableSql = new StringBuilder(); + createTableSql.append("CREATE TABLE IF NOT EXISTS ").append(name) + .append("(").append(idColumn).append(" VARCHAR(36) NOT NULL, ") + .append(embeddingColumn).append(" VECTOR(").append(vectorDimension).append("), ") + .append(textColumn).append(" VARCHAR(4000), ") + .append(metadataColumn).append(" JSON, ") + .append("PRIMARY KEY (").append(idColumn).append("), ") + .append("VECTOR INDEX ").append(vectorIndexName).append("(") + .append(embeddingColumn).append(") WITH (distance=").append(distanceMetric) + .append(", type=").append(indexType).append("))"); + + statement.addBatch(createTableSql.toString()); + statement.executeBatch(); + } + } + + /** + * Maps a metadata key to a JSON path expression for use in SQL queries. + * + * @param key Name of a metadata key. Not null. + * @return A JSON path expression that returns the value of the key. + */ + String mapMetadataKey(String key) { + return "JSON_VALUE(" + metadataColumn + ", '$." + key + "')"; + } + + /** + * Returns a MetadataKeyMapper that maps metadata keys to JSON path expressions. + * + * @return A metadata key mapper. Not null. + */ + public MetadataKeyMapper getMetadataKeyMapper() { + return key -> "JSON_VALUE(" + metadataColumn + ", '$." + key + "')"; + } + + /** + * Returns the name of this table. + * + * @return Table name. Not null. + */ + public String name() { + return name; + } + + /** + * Returns the name of the id column. + * + * @return Id column name. Not null. + */ + public String idColumn() { + return idColumn; + } + + /** + * Returns the name of the embedding column. 
+ * + * @return Embedding column name. Not null. + */ + public String embeddingColumn() { + return embeddingColumn; + } + + /** + * Returns the name of the text column. + * + * @return Text column name. Not null. + */ + public String textColumn() { + return textColumn; + } + + /** + * Returns the name of the metadata column. + * + * @return Metadata column name. Not null. + */ + public String metadataColumn() { + return metadataColumn; + } + + /** + * Returns the dimension of the vector. + * + * @return Vector dimension. + */ + public int vectorDimension() { + return vectorDimension; + } + + /** + * Returns the name of the vector index. + * + * @return Vector index name. Not null. + */ + public String vectorIndexName() { + return vectorIndexName; + } + + /** + * Returns the distance metric used for vector similarity search. + * + * @return Distance metric. Not null. + */ + public String distanceMetric() { + return distanceMetric; + } + + /** + * Returns the index type. + * + * @return Index type. Not null. + */ + public String indexType() { + return indexType; + } + + /** + * Returns a builder for configuring an embedding table. + * + * @param tableName Name of the table. Not null. + * @return A builder. + */ + public static Builder builder(String tableName) { + return new Builder(tableName); + } + + /** + * Builder for an EmbeddingTable. + */ + public static final class Builder { + + private final String name; + private String idColumn = "id"; + private String embeddingColumn = "embedding"; + private String textColumn = "text"; + private String metadataColumn = "metadata"; + private int vectorDimension = 1536; + private String vectorIndexName = "idx_vector"; + private String distanceMetric = "L2"; + private String indexType = "hnsw"; + private CreateOption createOption = CreateOption.CREATE_IF_NOT_EXISTS; + + /** + * Creates a builder for a table with a given name. + * + * @param tableName Name of the table. Not null. 
+ * @throws IllegalArgumentException If the table name is null. + */ + private Builder(String tableName) { + ensureNotNull(tableName, "tableName"); + this.name = tableName; + } + + /** + * Sets the column name for the id column. The default value is "id". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder idColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.idColumn = columnName; + return this; + } + + /** + * Sets the column name for the embedding column. The default value is "embedding". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder embeddingColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.embeddingColumn = columnName; + return this; + } + + /** + * Sets the column name for the text column. The default value is "text". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder textColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.textColumn = columnName; + return this; + } + + /** + * Sets the column name for the metadata column. The default value is "metadata". + * + * @param columnName Column name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the column name is null. + */ + public Builder metadataColumn(String columnName) { + ensureNotNull(columnName, "columnName"); + this.metadataColumn = columnName; + return this; + } + + /** + * Sets the dimensionality of the vector. The default value is 1536. + * + * @param dimension Vector dimension + * @return This builder. + * @throws IllegalArgumentException If the dimension is negative or zero. 
+ */ + public Builder vectorDimension(int dimension) { + if (dimension <= 0) { + throw new IllegalArgumentException("Vector dimension must be positive"); + } + this.vectorDimension = dimension; + return this; + } + + /** + * Sets the name of the vector index. The default value is "idx_vector". + * + * @param indexName Index name. Not null. + * @return This builder. + * @throws IllegalArgumentException If the index name is null. + */ + public Builder vectorIndexName(String indexName) { + ensureNotNull(indexName, "indexName"); + this.vectorIndexName = indexName; + return this; + } + + /** + * Sets the distance metric for vector similarity search. The default value is "L2". + * + * @param metric Distance metric. Not null. + * @return This builder. + * @throws IllegalArgumentException If the metric is null. + */ + public Builder distanceMetric(String metric) { + ensureNotNull(metric, "metric"); + this.distanceMetric = metric; + return this; + } + + /** + * Sets the index type. The default value is "hnsw". + * + * @param type Index type. Not null. + * @return This builder. + * @throws IllegalArgumentException If the type is null. + */ + public Builder indexType(String type) { + ensureNotNull(type, "type"); + this.indexType = type; + return this; + } + + /** + * Sets the create option for the table. + * + * @param createOption Create option. Not null. + * @return This builder. + * @throws IllegalArgumentException If the create option is null. + */ + public Builder createOption(CreateOption createOption) { + ensureNotNull(createOption, "createOption"); + this.createOption = createOption; + return this; + } + + /** + * Creates an EmbeddingTable with the configuration defined by this builder. + * + * @return A new EmbeddingTable. Not null. 
+ */ + public EmbeddingTable build() { + return new EmbeddingTable(this); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java new file mode 100644 index 00000000..b06a5538 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java @@ -0,0 +1,21 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +/** + * Interface for mapping metadata keys to SQL expressions. + * This allows for different mapping strategies without exposing internal implementations. + */ +public interface MetadataKeyMapper { + + /** + * Maps a metadata key to a SQL expression that can be used in queries. + * + * @param key The metadata key to map + * @return A SQL expression that references the given metadata key + */ + String mapKey(String key); + + /** + * Default implementation that uses JSON_VALUE function + */ + MetadataKeyMapper DEFAULT = key -> "JSON_VALUE(metadata, '$." 
+ key + "')"; +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java new file mode 100644 index 00000000..9da74db3 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java @@ -0,0 +1,503 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import static dev.langchain4j.internal.Utils.randomUUID; +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilterFactory; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.community.store.embedding.oceanbase.search.OceanBaseSearchTemplate; +import dev.langchain4j.community.store.embedding.oceanbase.search.SearchTemplate; + +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.*; +import java.util.*; +import javax.sql.DataSource; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +/** + * OceanBase Vector EmbeddingStore Implementation + */ +public final class OceanBaseEmbeddingStore implements EmbeddingStore { + + private static final Logger log = LoggerFactory.getLogger(OceanBaseEmbeddingStore.class); + + /** + * DataSource configured to connect with an 
OceanBase Database. + */ + private final DataSource dataSource; + + /** + * Table where embeddings are stored. + */ + private final EmbeddingTable table; + + /** + * true if search should use an exact search, or false if + * it should use approximate search. + */ + private final boolean isExactSearch; + + /** + * JSON Object Mapper for metadata serialization/deserialization + */ + private final ObjectMapper objectMapper; + + /** + * Search template for vector search operations + */ + private final SearchTemplate searchTemplate; + + /** + * Constructs embedding store configured by a builder. + * + * @param builder Builder that configures the embedding store. Not null. + * @throws IllegalArgumentException If the configuration is not valid. + */ + private OceanBaseEmbeddingStore(Builder builder) { + dataSource = builder.dataSource; + table = builder.embeddingTable; + isExactSearch = builder.isExactSearch; + objectMapper = new ObjectMapper(); + searchTemplate = new OceanBaseSearchTemplate(table, isExactSearch); + + try { + table.create(dataSource); + } catch (SQLException sqlException) { + throw uncheckSQLException(sqlException); + } + } + + /** + * Transforms a SQLException into a RuntimeException. + * @param sqlException SQLException to transform. Not null. + * @return A RuntimeException that wraps the SQLException. + */ + private static RuntimeException uncheckSQLException(SQLException sqlException) { + return new RuntimeException("SQL error: " + sqlException.getMessage(), sqlException); + } + + /** + * Returns a null-safe item from a List. + * @param list List to get an item from. + * @param index Index of the item to get. + * @param name Name of the parameter, for error messages. + * @param Type of the list items. + * @return The item at the given index. + * @throws IllegalArgumentException If the item is null. 
+     */
+    private static <T> T ensureIndexNotNull(List<T> list, int index, String name) {
+        if (index < 0 || index >= list.size()) {
+            throw new IllegalArgumentException(String.format(
+                    "Index %d is out of bounds for %s list of size %d", index, name, list.size()));
+        }
+        T item = list.get(index);
+        if (item == null) {
+            throw new IllegalArgumentException(String.format("%s[%d] is null", name, index));
+        }
+        return item;
+    }
+
+    /**
+     * Ensures that a string is neither null nor blank.
+     *
+     * @param value String to check.
+     * @param name Name of the parameter, for error messages.
+     * @return The checked value.
+     * @throws IllegalArgumentException If the value is null or contains only whitespace.
+     */
+    private static String ensureNotEmpty(String value, String name) {
+        ensureNotNull(value, name);
+        if (value.trim().isEmpty()) {
+            throw new IllegalArgumentException(name + " cannot be empty");
+        }
+        return value;
+    }
+
+    /**
+     * Stores an embedding under a newly generated id.
+     *
+     * @param embedding Embedding to store. Not null.
+     * @return The generated id.
+     */
+    @Override
+    public String add(Embedding embedding) {
+        ensureNotNull(embedding, "embedding");
+
+        String id = randomUUID();
+        add(id, embedding);
+        return id;
+    }
+
+    /**
+     * Stores an embedding under the given id, without text or metadata.
+     *
+     * @param id Id of the new row. Not null or blank.
+     * @param embedding Embedding to store. Not null.
+     * @throws RuntimeException If the insert fails with a SQLException.
+     */
+    @Override
+    public void add(String id, Embedding embedding) {
+        ensureNotEmpty(id, "id");
+        ensureNotNull(embedding, "embedding");
+
+        String sql = "INSERT INTO " + table.name() + " (" +
+                table.idColumn() + ", " +
+                table.embeddingColumn() +
+                ") VALUES (?, ?)";
+
+        try (Connection connection = dataSource.getConnection();
+             PreparedStatement statement = connection.prepareStatement(sql)) {
+            statement.setString(1, id);
+            statement.setString(2, vectorToString(embedding.vector()));
+            statement.executeUpdate();
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Generates {@code n} random UUID strings.
+     *
+     * @param n Number of ids to generate.
+     * @return A list of {@code n} ids.
+     */
+    @Override
+    public List<String> generateIds(final int n) {
+        List<String> ids = new ArrayList<>(n);
+        for (int i = 0; i < n; i++) {
+            ids.add(randomUUID());
+        }
+        return ids;
+    }
+
+    /**
+     * Stores embeddings under newly generated ids in a single JDBC batch.
+     *
+     * @param embeddings Embeddings to store. Not null, no null elements.
+     * @return The generated ids, in the order of {@code embeddings}.
+     * @throws IllegalArgumentException If the list or one of its elements is null.
+     * @throws RuntimeException If the insert fails with a SQLException.
+     */
+    @Override
+    public List<String> addAll(List<Embedding> embeddings) {
+        ensureNotNull(embeddings, "embeddings");
+
+        if (embeddings.isEmpty()) {
+            return Collections.emptyList();
+        }
+
+        List<String> ids = generateIds(embeddings.size());
+
+        String sql = "INSERT INTO " + table.name() + " (" +
+                table.idColumn() + ", " +
+                table.embeddingColumn() +
+                ") VALUES (?, ?)";
+
+        try (Connection connection = dataSource.getConnection();
+             PreparedStatement statement = connection.prepareStatement(sql)) {
+
+            for (int i = 0; i < ids.size(); i++) {
+                // Report a null element as an argument error instead of a NullPointerException.
+                Embedding embedding = ensureIndexNotNull(embeddings, i, "embeddings");
+                statement.setString(1, ids.get(i));
+                statement.setString(2, vectorToString(embedding.vector()));
+                statement.addBatch();
+            }
+
+            statement.executeBatch();
+            return ids;
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Stores an embedding with its text segment (text and metadata) under a newly generated id.
+     *
+     * @param embedding Embedding to store. Not null.
+     * @param textSegment Segment the embedding was computed from. Not null.
+     * @return The generated id.
+     * @throws RuntimeException If the insert fails with a SQLException.
+     */
+    @Override
+    public String add(Embedding embedding, TextSegment textSegment) {
+        ensureNotNull(embedding, "embedding");
+        ensureNotNull(textSegment, "textSegment");
+
+        String id = randomUUID();
+
+        String sql = "INSERT INTO " + table.name() + " (" +
+                table.idColumn() + ", " +
+                table.embeddingColumn() + ", " +
+                table.textColumn() + ", " +
+                table.metadataColumn() +
+                ") VALUES (?, ?, ?, ?)";
+
+        try (Connection connection = dataSource.getConnection();
+             PreparedStatement statement = connection.prepareStatement(sql)) {
+            statement.setString(1, id);
+            statement.setString(2, vectorToString(embedding.vector()));
+            statement.setString(3, textSegment.text());
+            statement.setString(4, metadataToJson(textSegment.metadata()));
+            statement.executeUpdate();
+            return id;
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Stores embeddings with their text segments under newly generated ids.
+     *
+     * @param embeddings Embeddings to store. Not null.
+     * @param textSegments Segments matching {@code embeddings} by position. Not null.
+     * @return The generated ids, in the order of {@code embeddings}.
+     * @throws IllegalArgumentException If the two lists differ in size.
+     */
+    @Override
+    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> textSegments) {
+        ensureNotNull(embeddings, "embeddings");
+        ensureNotNull(textSegments, "textSegments");
+
+        // Compare sizes before the empty shortcut so that a mismatch is never silently ignored.
+        if (embeddings.size() != textSegments.size()) {
+            throw new IllegalArgumentException(String.format(
+                    "Number of embeddings (%d) and text segments (%d) must be the same",
+                    embeddings.size(), textSegments.size()));
+        }
+
+        if (embeddings.isEmpty()) {
+            return Collections.emptyList();
+        }
+
+        List<String> ids = generateIds(embeddings.size());
+        addAll(ids, embeddings, textSegments);
+        return ids;
+    }
+
+    /**
+     * Stores embeddings with their text segments under the given ids in a single JDBC batch.
+     *
+     * @param ids Ids of the new rows. Not null, no null elements.
+     * @param embeddings Embeddings matching {@code ids} by position. Not null, no null elements.
+     * @param embedded Segments matching {@code ids} by position. Not null, no null elements.
+     * @throws IllegalArgumentException If the lists differ in size or contain a null element.
+     * @throws RuntimeException If the insert fails with a SQLException.
+     */
+    @Override
+    public void addAll(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
+        ensureNotNull(ids, "ids");
+        ensureNotNull(embeddings, "embeddings");
+        ensureNotNull(embedded, "embedded");
+
+        // Compare sizes before the empty shortcut so that a mismatch is never silently ignored.
+        if (ids.size() != embeddings.size() || ids.size() != embedded.size()) {
+            throw new IllegalArgumentException(String.format(
+                    "Number of ids (%d), embeddings (%d), and embedded objects (%d) must be the same",
+                    ids.size(), embeddings.size(), embedded.size()));
+        }
+
+        if (ids.isEmpty()) {
+            return;
+        }
+
+        String sql = "INSERT INTO " + table.name() + " (" +
+                table.idColumn() + ", " +
+                table.embeddingColumn() + ", " +
+                table.textColumn() + ", " +
+                table.metadataColumn() +
+                ") VALUES (?, ?, ?, ?)";
+
+        try (Connection connection = dataSource.getConnection();
+             PreparedStatement statement = connection.prepareStatement(sql)) {
+
+            for (int i = 0; i < ids.size(); i++) {
+                String id = ensureIndexNotNull(ids, i, "ids");
+                Embedding embedding = ensureIndexNotNull(embeddings, i, "embeddings");
+                TextSegment segment = ensureIndexNotNull(embedded, i, "embedded");
+
+                statement.setString(1, id);
+                statement.setString(2, vectorToString(embedding.vector()));
+                statement.setString(3, segment.text());
+                statement.setString(4, metadataToJson(segment.metadata()));
+                statement.addBatch();
+            }
+
+            statement.executeBatch();
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Searches for the embeddings closest to the query embedding of the request.
+     *
+     * @param request Search request.
+     * @return The matches found by the search template.
+     */
+    @Override
+    public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
+        // Use the search template (Template Method pattern)
+        return searchTemplate.search(dataSource, request);
+    }
+
+    /**
+     * Deletes the row with the given id, if any.
+     *
+     * @param id Id of the row to delete. Not null or blank, consistent with {@link #add(String, Embedding)}.
+     * @throws RuntimeException If the delete fails with a SQLException.
+     */
+    @Override
+    public void remove(String id) {
+        ensureNotEmpty(id, "id");
+
+        String sql = "DELETE FROM " + table.name() + " WHERE " + table.idColumn() + " = ?";
+
+        try (Connection connection = dataSource.getConnection();
+             PreparedStatement statement = connection.prepareStatement(sql)) {
+            statement.setString(1, id);
+            statement.executeUpdate();
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Deletes the rows with the given ids. An empty collection is a no-op.
+     *
+     * @param ids Ids of the rows to delete. Not null.
+     * @throws RuntimeException If the delete fails with a SQLException.
+     */
+    @Override
+    public void removeAll(Collection<String> ids) {
+        ensureNotNull(ids, "ids");
+
+        if (ids.isEmpty()) {
+            return;
+        }
+
+        // Use placeholders for each ID in the IN clause
+        String placeholders = String.join(", ", Collections.nCopies(ids.size(), "?"));
+        String sql = "DELETE FROM " + table.name() + " WHERE " + table.idColumn() + " IN (" + placeholders + ")";
+
+        try (Connection connection = dataSource.getConnection();
+             PreparedStatement statement = connection.prepareStatement(sql)) {
+
+            int index = 1;
+            for (String id : ids) {
+                statement.setString(index++, id);
+            }
+
+            statement.executeUpdate();
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Deletes the rows whose metadata matches the filter.
+     *
+     * @param filter Metadata filter. Not null.
+     * @throws RuntimeException If the delete fails with a SQLException.
+     */
+    @Override
+    public void removeAll(Filter filter) {
+        ensureNotNull(filter, "filter");
+
+        SQLFilter sqlFilter = SQLFilterFactory.create(
+                filter, (key, value) -> table.mapMetadataKey(key));
+
+        if (sqlFilter.matchesNoRows()) {
+            return;
+        }
+
+        if (sqlFilter.matchesAllRows()) {
+            removeAll();
+            return;
+        }
+
+        // The filter is rendered as SQL text by SQLFilter.toSql(); no bind parameters are available here.
+        String sql = "DELETE FROM " + table.name() + " WHERE " + sqlFilter.toSql();
+
+        try (Connection connection = dataSource.getConnection();
+             Statement statement = connection.createStatement()) {
+            statement.executeUpdate(sql);
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Deletes every row of the embedding table.
+     *
+     * @throws RuntimeException If the delete fails with a SQLException.
+     */
+    @Override
+    public void removeAll() {
+        String sql = "DELETE FROM " + table.name();
+
+        try (Connection connection = dataSource.getConnection();
+             Statement statement = connection.createStatement()) {
+            statement.executeUpdate(sql);
+        } catch (SQLException sqlException) {
+            throw uncheckSQLException(sqlException);
+        }
+    }
+
+    /**
+     * Converts a vector to a string representation.
+     *
+     * @param vector The vector to convert.
+     * @return String representation of the vector in OceanBase format.
+     */
+    private String vectorToString(float[] vector) {
+        boolean hasValues = vector != null && vector.length > 0;
+        if (!hasValues) {
+            return "[]";
+        }
+
+        // Start with the first component so that every later one is preceded by a separator.
+        StringBuilder text = new StringBuilder("[").append(vector[0]);
+        for (int i = 1; i < vector.length; i++) {
+            text.append(", ").append(vector[i]);
+        }
+        return text.append("]").toString();
+    }
+
+    /**
+     * Serializes metadata as a JSON object.
+     *
+     * @param metadata The metadata to serialize. May be null.
+     * @return The JSON text, or null when the metadata is null or has no entries.
+     * @throws RuntimeException If the metadata cannot be serialized.
+     */
+    private String metadataToJson(Metadata metadata) {
+        boolean hasEntries = metadata != null && !metadata.toMap().isEmpty();
+        if (!hasEntries) {
+            return null;
+        }
+
+        try {
+            String json = objectMapper.writeValueAsString(metadata.toMap());
+            return json;
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException("Failed to convert metadata to JSON", e);
+        }
+    }
+
+    /**
+     * Creates a builder that configures an OceanBaseEmbeddingStore.
+     *
+     * @param dataSource DataSource that connects to an OceanBase Database.
+     * @return A new builder bound to the given DataSource.
+     */
+    public static Builder builder(DataSource dataSource) {
+        Builder storeBuilder = new Builder(dataSource);
+        return storeBuilder;
+    }
+
+    /**
+     * Builder for an OceanBaseEmbeddingStore.
+     */
+    public static final class Builder {
+
+        private final DataSource dataSource;
+        private EmbeddingTable embeddingTable;
+        private boolean isExactSearch = false;
+
+        /**
+         * Creates a builder bound to a DataSource.
+         *
+         * @param dataSource DataSource that connects to an OceanBase Database.
+         * @throws IllegalArgumentException If the DataSource is null.
+         */
+        private Builder(DataSource dataSource) {
+            ensureNotNull(dataSource, "dataSource");
+            this.dataSource = dataSource;
+        }
+
+        /**
+         * Sets the embedding table with a default table configuration.
+         *
+         * @param tableName Name of the embedding table.
+         * @return This builder.
+         * @throws IllegalArgumentException If the table name is null.
+         */
+        public Builder embeddingTable(String tableName) {
+            ensureNotNull(tableName, "tableName");
+            EmbeddingTable defaultTable = EmbeddingTable.builder(tableName).build();
+            this.embeddingTable = defaultTable;
+            return this;
+        }
+
+        /**
+         * Uses a table with the default layout and the given create option.
+         *
+         * @param tableName Name of the embedding table.
+         * @param createOption Option for table creation.
+         * @return This builder.
+         * @throws IllegalArgumentException If the table name or create option is null.
+         */
+        public Builder embeddingTable(String tableName, CreateOption createOption) {
+            ensureNotNull(tableName, "tableName");
+            ensureNotNull(createOption, "createOption");
+            EmbeddingTable configuredTable = EmbeddingTable.builder(tableName)
+                    .createOption(createOption)
+                    .build();
+            this.embeddingTable = configuredTable;
+            return this;
+        }
+
+        /**
+         * Uses a fully customized table configuration.
+         *
+         * @param embeddingTable Embedding table configuration.
+         * @return This builder.
+         * @throws IllegalArgumentException If the embedding table is null.
+         */
+        public Builder embeddingTable(EmbeddingTable embeddingTable) {
+            ensureNotNull(embeddingTable, "embeddingTable");
+            this.embeddingTable = embeddingTable;
+            return this;
+        }
+
+        /**
+         * Chooses between exact and approximate search.
+         *
+         * @param isExactSearch true to use exact search, false to use approximate search.
+         * @return This builder.
+         */
+        public Builder exactSearch(boolean isExactSearch) {
+            this.isExactSearch = isExactSearch;
+            return this;
+        }
+
+        /**
+         * Builds an OceanBaseEmbeddingStore with the configuration defined by this builder.
+         *
+         * @return A new OceanBaseEmbeddingStore.
+         * @throws IllegalArgumentException If required parameters are missing.
+         */
+        public OceanBaseEmbeddingStore build() {
+            // The table is the only setting without a default, so it must have been configured.
+            if (embeddingTable == null) {
+                throw new IllegalArgumentException("embeddingTable must be configured");
+            }
+
+            return new OceanBaseEmbeddingStore(this);
+        }
+    }
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java
new file mode 100644
index 00000000..349d0c38
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java
@@ -0,0 +1,20 @@
+package dev.langchain4j.community.store.embedding.oceanbase.distance;
+
+/**
+ * Converts cosine distance to similarity score.
+ * <p>
+ * With distance = 1 - cosine similarity, the distance lies in [0, 2]. The score is
+ * 1 - distance / 2, which maps that range onto the [0, 1] range promised by
+ * {@link DistanceConverter#toSimilarity(double)}: same direction gives 1, opposite direction gives 0.
+ */
+public class CosineDistanceConverter implements DistanceConverter {
+
+    /**
+     * @param distance The cosine distance, expected in [0, 2]
+     * @return The similarity score in [0, 1], where 1 means exact match
+     */
+    @Override
+    public double toSimilarity(double distance) {
+        return 1.0 - distance / 2.0;
+    }
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java
new file mode 100644
index 00000000..413060f1
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java
@@ -0,0 +1,14 @@
+package dev.langchain4j.community.store.embedding.oceanbase.distance;
+
+/**
+ * Default distance to similarity converter.
+ * Uses exponential decay: similarity = e^(-distance)
+ * This works well for most distance metrics where smaller distance means higher similarity.
+ */
+public class DefaultDistanceConverter implements DistanceConverter {
+
+    @Override
+    public double toSimilarity(double distance) {
+        return Math.exp(-distance); // 1 for a zero distance, decaying towards 0 as the distance grows
+    }
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java
new file mode 100644
index 00000000..b45a8cda
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java
@@ -0,0 +1,16 @@
+package dev.langchain4j.community.store.embedding.oceanbase.distance;
+
+/**
+ * Strategy interface for converting distance values to similarity scores.
+ * Different distance metrics require different conversion formulas, one implementation per metric.
+ */
+public interface DistanceConverter {
+
+    /**
+     * Converts a distance value to a similarity score; a smaller distance gives a higher score.
+     *
+     * @param distance The distance value to convert
+     * @return A similarity score between 0 and 1, where 1 means exact match
+     */
+    double toSimilarity(double distance);
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java
new file mode 100644
index 00000000..f824760a
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java
@@ -0,0 +1,31 @@
+package dev.langchain4j.community.store.embedding.oceanbase.distance;
+
+/**
+ * Factory for creating appropriate distance converters based on the metric type.
+ * Implements the Factory Method pattern.
+ */
+public class DistanceConverterFactory {
+
+    /**
+     * Returns the converter for the given distance metric name, compared case-insensitively.
+     *
+     * @param metric The distance metric name (e.g., "cosine", "euclidean"); may be null
+     * @return The cosine or euclidean converter, or the default converter for a null or unknown name
+     */
+    public static DistanceConverter getConverter(String metric) {
+        if (metric == null) {
+            return new DefaultDistanceConverter();
+        }
+
+        // Lower-case with Locale.ROOT: under the default locale the result could differ, e.g. a
+        // Turkish locale maps 'I' to a dotless 'ı', so "COSINE" would no longer match "cosine".
+        switch (metric.toLowerCase(java.util.Locale.ROOT)) {
+            case "cosine":
+                return new CosineDistanceConverter();
+            case "euclidean":
+                return new EuclideanDistanceConverter();
+            default:
+                return new DefaultDistanceConverter();
+        }
+    }
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java
new file mode 100644
index 00000000..8979db87
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java
@@ -0,0 +1,13 @@
+package dev.langchain4j.community.store.embedding.oceanbase.distance;
+
+/**
+ * Converts euclidean distance to similarity score.
+ * For euclidean distance, similarity = 1 / (1 + distance), which is 1 at distance 0 and falls towards 0.
+ */
+public class EuclideanDistanceConverter implements DistanceConverter {
+
+    @Override
+    public double toSimilarity(double distance) {
+        return 1.0 / (1.0 + distance);
+    }
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java
new file mode 100644
index 00000000..6b775406
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java
@@ -0,0 +1,232 @@
+package dev.langchain4j.community.store.embedding.oceanbase.search;
+
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilterFactory;
+import dev.langchain4j.data.document.Metadata;
+import dev.langchain4j.data.embedding.Embedding;
+import dev.langchain4j.community.store.embedding.oceanbase.EmbeddingTable;
+import dev.langchain4j.community.store.embedding.oceanbase.distance.DistanceConverter;
+import dev.langchain4j.community.store.embedding.oceanbase.distance.DistanceConverterFactory;
+import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilter;
+import dev.langchain4j.community.store.embedding.oceanbase.sql.SqlQueryBuilder;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import dev.langchain4j.data.segment.TextSegment;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.filter.Filter;
+import dev.langchain4j.store.embedding.EmbeddingMatch;
+
+
+/**
+ * Concrete implementation of SearchTemplate for OceanBase.
+ */
+public class OceanBaseSearchTemplate extends SearchTemplate<TextSegment> {
+
+    private final EmbeddingTable table;
+    private final boolean isExactSearch;
+    private final ObjectMapper objectMapper;
+
+    /**
+     * Creates a new OceanBaseSearchTemplate.
+     *
+     * @param table The embedding table configuration
+     * @param isExactSearch Whether to use exact search
+     */
+    public OceanBaseSearchTemplate(EmbeddingTable table, boolean isExactSearch) {
+        this.table = table;
+        this.isExactSearch = isExactSearch;
+        this.objectMapper = new ObjectMapper();
+    }
+
+    /**
+     * Converts a vector to a string representation.
+     *
+     * @param vector The vector to convert
+     * @return String representation of the vector in OceanBase format, "[]" for a null or empty vector
+     */
+    protected String vectorToString(float[] vector) {
+        if (vector == null || vector.length == 0) {
+            return "[]";
+        }
+
+        StringBuilder sb = new StringBuilder();
+        sb.append("[");
+        for (int i = 0; i < vector.length; i++) {
+            if (i > 0) {
+                sb.append(", ");
+            }
+            sb.append(vector[i]);
+        }
+        sb.append("]");
+        return sb.toString();
+    }
+
+    /**
+     * Converts a string representation of a vector back to a float array.
+     * Surrounding whitespace and the enclosing brackets are ignored.
+     *
+     * @param vectorStr The string representation of the vector
+     * @return The vector as a float array; empty when the input is null, blank or has no components
+     * @throws NumberFormatException If a component is not a valid float
+     */
+    protected float[] stringToVector(String vectorStr) {
+        if (vectorStr == null) {
+            return new float[0];
+        }
+
+        String trimmed = vectorStr.trim();
+        if (trimmed.startsWith("[")) {
+            trimmed = trimmed.substring(1);
+        }
+        if (trimmed.endsWith("]")) {
+            trimmed = trimmed.substring(0, trimmed.length() - 1);
+        }
+
+        // Check after stripping the brackets: "", "[]", " [] " and "[ ]" all mean "no components".
+        // Splitting an empty string would yield one empty part and fail in Float.parseFloat.
+        trimmed = trimmed.trim();
+        if (trimmed.isEmpty()) {
+            return new float[0];
+        }
+
+        String[] parts = trimmed.split(",");
+        float[] result = new float[parts.length];
+
+        for (int i = 0; i < parts.length; i++) {
+            result[i] = Float.parseFloat(parts[i].trim());
+        }
+
+        return result;
+    }
+
+    /**
+     * Converts Metadata to a JSON string.
+     *
+     * @param metadata The metadata to convert
+     * @return JSON string representation of the metadata, or null when it is null or empty
+     */
+    protected String metadataToJson(Metadata metadata) {
+        if (metadata == null || metadata.toMap().isEmpty()) {
+            return null;
+        }
+
+        try {
+            return objectMapper.writeValueAsString(metadata.toMap());
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException("Failed to convert metadata to JSON", e);
+        }
+    }
+
+    /**
+     * Converts a JSON string to Metadata.
+     *
+     * @param json The JSON string to convert
+     * @return Metadata object; empty when the JSON is null or empty
+     */
+    @SuppressWarnings("unchecked")
+    protected Metadata jsonToMetadata(String json) {
+        if (json == null || json.isEmpty()) {
+            return Metadata.from(java.util.Collections.emptyMap());
+        }
+
+        try {
+            // Unchecked: Jackson returns a raw Map for Map.class.
+            Map<String, Object> map = objectMapper.readValue(json, java.util.Map.class);
+            return Metadata.from(map);
+        } catch (JsonProcessingException e) {
+            throw new RuntimeException("Failed to convert JSON to metadata", e);
+        }
+    }
+
+    /**
+     * Builds the search statement. It has two placeholders for the query vector: the first in
+     * the distance expression of the SELECT list, the second in the ORDER BY clause.
+     */
+    @Override
+    protected String buildSearchQuery(EmbeddingSearchRequest request) {
+        // Locale.ROOT keeps the function name stable whatever the default locale is.
+        String distanceFunction = table.distanceMetric().toLowerCase(java.util.Locale.ROOT) + "_distance";
+
+        // Use SqlQueryBuilder (Builder pattern) to construct the query
+        SqlQueryBuilder builder = SqlQueryBuilder.select()
+                .columns(table.idColumn(), table.embeddingColumn(), table.textColumn(), table.metadataColumn())
+                .function(distanceFunction, table.embeddingColumn(), "?")
+                .from(table.name());
+
+        // Add WHERE clause for filtering
+        Filter filter = request.filter();
+        if (filter != null) {
+            String whereClause = buildWhereClause(filter);
+            if (whereClause != null && !whereClause.trim().isEmpty()) {
+                builder.where(whereClause);
+            }
+        }
+
+        // Add ORDER BY clause
+        builder.orderByFunction(distanceFunction, table.embeddingColumn(), "?");
+
+        // Use APPROXIMATE keyword if not exact search
+        if (!isExactSearch) {
+            builder.approximate();
+        }
+
+        // Add LIMIT clause
+        builder.limit(request.maxResults());
+
+        return builder.build();
+    }
+
+    @Override
+    protected void setParameters(PreparedStatement statement, EmbeddingSearchRequest request) throws SQLException {
+        String queryVector = vectorToString(request.queryEmbedding().vector());
+        statement.setString(1, queryVector); // For distance calculation in SELECT
+        statement.setString(2, queryVector); // For distance calculation in ORDER BY
+    }
+
+    /**
+     * Reads id, vector, text, metadata and distance from columns 1 to 5 of each row, converts the
+     * distance to a score and keeps the rows whose score is at least the requested minimum.
+     */
+    @Override
+    protected List<EmbeddingMatch<TextSegment>> processResults(ResultSet resultSet, EmbeddingSearchRequest request) throws SQLException {
+        List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
+
+        // The metric does not change between rows, so resolve the converter once.
+        DistanceConverter converter = DistanceConverterFactory.getConverter(table.distanceMetric());
+
+        while (resultSet.next()) {
+            String id = resultSet.getString(1);
+            String vectorStr = resultSet.getString(2);
+            String text = resultSet.getString(3);
+            String metadataJson = resultSet.getString(4);
+            double distance = resultSet.getDouble(5);
+
+            double score = converter.toSimilarity(distance);
+
+            // Apply minimum score filter if specified
+            if (score < request.minScore()) {
+                continue;
+            }
+
+            Embedding embedding = new Embedding(stringToVector(vectorStr));
+            Metadata metadata = jsonToMetadata(metadataJson);
+            TextSegment segment = (text != null) ?
+                    TextSegment.from(text, metadata) :
+                    null;
+
+            matches.add(new EmbeddingMatch<>(score, id, embedding, segment));
+        }
+
+        return matches;
+    }
+
+    /**
+     * Renders the filter as a SQL condition, or returns null when no condition is needed.
+     */
+    @Override
+    protected String buildWhereClause(Filter filter) {
+        if (filter == null) {
+            return null;
+        }
+
+        SQLFilter sqlFilter = SQLFilterFactory.create(filter,
+                (key, value) -> table.getMetadataKeyMapper().mapKey(key));
+
+        if (sqlFilter.matchesAllRows()) {
+            return null;
+        }
+
+        return sqlFilter.toSql();
+    }
+}
+
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java
new file mode 100644
index 00000000..17785c29
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java
@@ -0,0 +1,112 @@
+package dev.langchain4j.community.store.embedding.oceanbase.search;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+
+import dev.langchain4j.store.embedding.EmbeddingMatch;
+import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
+import dev.langchain4j.store.embedding.EmbeddingSearchResult;
+import dev.langchain4j.store.embedding.filter.Filter;
+
+import javax.sql.DataSource;
+
+
+/**
+ * Template for the search operation.
+ * Implements the Template Method pattern to define the skeleton of the search algorithm,
+ * deferring some steps to subclasses.
+ *
+ * @param <T> The type of embedded objects
+ */
+public abstract class SearchTemplate<T> {
+
+    /**
+     * The main template method that defines the search algorithm.
+     *
+     * @param dataSource The data source to use
+     * @param request The search request
+     * @return The search result; empty when the request asks for no results
+     */
+    public final EmbeddingSearchResult<T> search(DataSource dataSource, EmbeddingSearchRequest request) {
+        validateRequest(request);
+
+        if (request.maxResults() <= 0) {
+            return new EmbeddingSearchResult<>(new ArrayList<>());
+        }
+
+        String query = buildSearchQuery(request);
+
+        try (Connection connection = dataSource.getConnection();
+             PreparedStatement statement = connection.prepareStatement(query)) {
+
+            setParameters(statement, request);
+
+            try (ResultSet resultSet = statement.executeQuery()) {
+                List<EmbeddingMatch<T>> matches = processResults(resultSet, request);
+                return new EmbeddingSearchResult<>(matches);
+            }
+        } catch (SQLException e) {
+            throw new RuntimeException("Error executing search query: " + e.getMessage(), e);
+        }
+    }
+
+    /**
+     * Validates the search request: the request and its query embedding must not be null.
+     * This is a hook that may be overridden by subclasses.
+     *
+     * @param request The search request to validate
+     */
+    protected void validateRequest(EmbeddingSearchRequest request) {
+        if (request == null) {
+            throw new IllegalArgumentException("Search request cannot be null");
+        }
+        if (request.queryEmbedding() == null) {
+            throw new IllegalArgumentException("Query embedding cannot be null");
+        }
+    }
+
+    /**
+     * Builds the search query.
+     * This is an abstract method that must be implemented by subclasses.
+     *
+     * @param request The search request
+     * @return The SQL query string, prepared and executed by the search method
+     */
+    protected abstract String buildSearchQuery(EmbeddingSearchRequest request);
+
+    /**
+     * Sets the parameters for the prepared statement.
+     * This is an abstract method that must be implemented by subclasses.
+     *
+     * @param statement The prepared statement
+     * @param request The search request
+     * @throws SQLException If an error occurs setting the parameters
+     */
+    protected abstract void setParameters(PreparedStatement statement, EmbeddingSearchRequest request) throws SQLException;
+
+    /**
+     * Processes the result set and creates embedding matches.
+     * This is an abstract method that must be implemented by subclasses.
+     *
+     * @param resultSet The result set from the query
+     * @param request The search request
+     * @return A list of embedding matches
+     * @throws SQLException If an error occurs processing the result set
+     */
+    protected abstract List<EmbeddingMatch<T>> processResults(ResultSet resultSet, EmbeddingSearchRequest request) throws SQLException;
+
+    /**
+     * Converts a filter to a WHERE clause.
+     * This is a hook that may be overridden by subclasses; the default applies no filter.
+     *
+     * @param filter The filter to convert
+     * @return The WHERE clause, or null if no filter is applied
+     */
+    protected String buildWhereClause(Filter filter) {
+        return null;
+    }
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/FilterVisitor.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/FilterVisitor.java
new file mode 100644
index 00000000..701f6be3
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/FilterVisitor.java
@@ -0,0 +1,109 @@
+package dev.langchain4j.community.store.embedding.oceanbase.sql;
+
+import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo;
+import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan;
+import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo;
+import dev.langchain4j.store.embedding.filter.comparison.IsIn;
+import dev.langchain4j.store.embedding.filter.comparison.IsLessThan;
+import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo;
+import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo;
+import dev.langchain4j.store.embedding.filter.comparison.IsNotIn;
+import dev.langchain4j.store.embedding.filter.logical.And;
+import dev.langchain4j.store.embedding.filter.logical.Not;
+import dev.langchain4j.store.embedding.filter.logical.Or;
+
+/**
+ * Visitor interface for Filter objects.
+ * Implements the Visitor pattern to handle different filter types
+ * without using instanceof checks and type casting.
+ */
+public interface FilterVisitor {
+
+    /**
+     * Visit an IsEqualTo filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsEqualTo filter);
+
+    /**
+     * Visit an IsNotEqualTo filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsNotEqualTo filter);
+
+    /**
+     * Visit an IsGreaterThan filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsGreaterThan filter);
+
+    /**
+     * Visit an IsGreaterThanOrEqualTo filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsGreaterThanOrEqualTo filter);
+
+    /**
+     * Visit an IsLessThan filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsLessThan filter);
+
+    /**
+     * Visit an IsLessThanOrEqualTo filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsLessThanOrEqualTo filter);
+
+    /**
+     * Visit an IsIn filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsIn filter);
+
+    /**
+     * Visit an IsNotIn filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(IsNotIn filter);
+
+    /**
+     * Visit an And filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(And filter);
+
+    /**
+     * Visit an Or filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(Or filter);
+
+    /**
+     * Visit a Not filter.
+     *
+     * @param filter The filter to visit
+     * @return The resulting SQL filter
+     */
+    SQLFilter visit(Not filter);
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilter.java
new file mode 100644
index 00000000..5451119c
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilter.java
@@ -0,0 +1,37 @@
+package dev.langchain4j.community.store.embedding.oceanbase.sql;
+
+import dev.langchain4j.store.embedding.filter.Filter;
+
+/**
+ * Transforms a metadata {@link Filter} into a SQL expression that can be used in a WHERE clause.
+ */
+public interface SQLFilter {
+
+    /**
+     * Returns a SQL expression that can be used in a WHERE clause to filter results based on metadata values.
+     * The returned SQL expression should evaluate to true when a row meets the criteria of the metadata filter,
+     * and false otherwise.
+     *
+     * @return A SQL expression, or null if this filter is equivalent to "always include" all rows.
+     */
+    String toSql();
+
+    /**
+     * Returns true if the SQL filter is guaranteed to match all rows, or false otherwise.
+     * When a SQL filter matches all rows (e.g., as if there were no filter), the result of
+     * {@link #toSql()} may be null or an expression which always evaluates to true, such as "1=1".
+     * <p>
+     * When {@link #toSql()} is null, the SQL filter MUST match all rows, and this method MUST return true.
+     *
+     * @return True if this SQL filter matches all rows, or false if it might exclude some rows.
+     */
+    boolean matchesAllRows();
+
+    /**
+     * Returns true if this SQL filter is guaranteed to match no rows. When a SQL filter matches no rows,
+     * the result of {@link #toSql()} should be an expression which always evaluates to false, such as "1=0".
+     *
+     * @return True if this SQL filter matches no rows, or false if it might include some rows.
+     */
+    boolean matchesNoRows();
+}
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java
new file mode 100644
index 00000000..5b418860
--- /dev/null
+++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java
@@ -0,0 +1,28 @@
+package dev.langchain4j.community.store.embedding.oceanbase.sql;
+
+import java.util.function.BiFunction;
+
+import dev.langchain4j.store.embedding.filter.Filter;
+
+/**
+ * Factory for creating SQL filters.
+ * Implements the Factory Method pattern.
+ */
+public class SQLFilterFactory {
+
+    /**
+     * Creates a SQLFilter from a Filter and key mapper.
+ * + * @param filter The filter to convert + * @param keyMapper Function that maps a key and value to a SQL column expression + * @return A SQLFilter representing the given filter + */ + public static SQLFilter create(Filter filter, BiFunction keyMapper) { + if (filter == null) { + return SQLFilters.matchAllRows(); + } + + SQLFilterVisitor visitor = new SQLFilterVisitor(keyMapper); + return visitor.process(filter); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java new file mode 100644 index 00000000..c2de1d00 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java @@ -0,0 +1,244 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.function.BiFunction; + +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; +import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsIn; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThan; +import dev.langchain4j.store.embedding.filter.comparison.IsLessThanOrEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotEqualTo; +import dev.langchain4j.store.embedding.filter.comparison.IsNotIn; +import dev.langchain4j.store.embedding.filter.logical.And; +import dev.langchain4j.store.embedding.filter.logical.Not; +import dev.langchain4j.store.embedding.filter.logical.Or; +import 
dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchAllSQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchNoSQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.SimpleSQLFilter; + +/** + * Concrete visitor implementation that converts Filter objects to SQLFilter objects. + * Implements the Visitor pattern. + */ +public class SQLFilterVisitor implements FilterVisitor { + private final BiFunction keyMapper; + + /** + * Creates a new SQLFilterVisitor with the given key mapper. + * + * @param keyMapper Function that maps a key and value to a SQL column expression + */ + public SQLFilterVisitor(BiFunction keyMapper) { + this.keyMapper = keyMapper; + } + + /** + * Helper method to convert a value to its SQL string representation. + * + * @param value The value to convert + * @return SQL string representation of the value + */ + private String valueToSql(Object value) { + if (value == null) { + return "NULL"; + } else if (value instanceof String) { + return "'" + ((String) value).replace("'", "''") + "'"; + } else if (value instanceof Number || value instanceof Boolean) { + return value.toString(); + } else { + return "'" + value.toString().replace("'", "''") + "'"; + } + } + + @Override + public SQLFilter visit(IsEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + Object value = filter.comparisonValue(); + + if (value == null) { + return new SimpleSQLFilter(expr + " IS NULL"); + } else { + return new SimpleSQLFilter(expr + " = " + valueToSql(value)); + } + } + + @Override + public SQLFilter visit(IsNotEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + Object value = filter.comparisonValue(); + + if (value == null) { + return new SimpleSQLFilter(expr + " IS NOT NULL"); + } else { + return new SimpleSQLFilter("(" + expr + " <> " + valueToSql(value) + " OR " + expr + " IS NULL)"); + } + } + + 
@Override + public SQLFilter visit(IsGreaterThan filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " > " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsGreaterThanOrEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " >= " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsLessThan filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " < " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsLessThanOrEqualTo filter) { + String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); + return new SimpleSQLFilter(expr + " <= " + valueToSql(filter.comparisonValue())); + } + + @Override + public SQLFilter visit(IsIn filter) { + Collection values = filter.comparisonValues(); + if (values == null || values.isEmpty()) { + return new MatchNoSQLFilter(); + } + + List valueStrings = new ArrayList<>(values.size()); + for (Object value : values) { + valueStrings.add(valueToSql(value)); + } + + String expr = keyMapper.apply(filter.key(), values.iterator().next()); + return new SimpleSQLFilter(expr + " IN (" + String.join(", ", valueStrings) + ")"); + } + + @Override + public SQLFilter visit(IsNotIn filter) { + Collection values = filter.comparisonValues(); + if (values == null || values.isEmpty()) { + return new MatchAllSQLFilter(); + } + + List valueStrings = new ArrayList<>(values.size()); + for (Object value : values) { + valueStrings.add(valueToSql(value)); + } + + String expr = keyMapper.apply(filter.key(), values.iterator().next()); + return new SimpleSQLFilter("(" + expr + " NOT IN (" + String.join(", ", valueStrings) + ") OR " + expr + " IS NULL)"); + } + + @Override + public SQLFilter visit(And filter) { + SQLFilter leftFilter = 
process(filter.left()); + + // If left side matches no rows, the entire AND also matches no rows + if (leftFilter.matchesNoRows()) { + return new MatchNoSQLFilter(); + } + + // If left side matches all rows, the result depends on the right side + if (leftFilter.matchesAllRows()) { + return process(filter.right()); + } + + SQLFilter rightFilter = process(filter.right()); + + // If right side matches no rows, the entire AND also matches no rows + if (rightFilter.matchesNoRows()) { + return new MatchNoSQLFilter(); + } + + // If right side matches all rows, the result depends on the left side + if (rightFilter.matchesAllRows()) { + return leftFilter; + } + + // Both sides need to participate in the AND operation + return new SimpleSQLFilter("(" + leftFilter.toSql() + ") AND (" + rightFilter.toSql() + ")"); + } + + @Override + public SQLFilter visit(Or filter) { + SQLFilter leftFilter = process(filter.left()); + + // If left side matches all rows, the entire OR also matches all rows + if (leftFilter.matchesAllRows()) { + return new MatchAllSQLFilter(); + } + + // If left side matches no rows, the result depends on the right side + if (leftFilter.matchesNoRows()) { + return process(filter.right()); + } + + SQLFilter rightFilter = process(filter.right()); + + // If right side matches all rows, the entire OR also matches all rows + if (rightFilter.matchesAllRows()) { + return new MatchAllSQLFilter(); + } + + // If right side matches no rows, the result depends on the left side + if (rightFilter.matchesNoRows()) { + return leftFilter; + } + + // Both sides need to participate in the OR operation + return new SimpleSQLFilter("(" + leftFilter.toSql() + ") OR (" + rightFilter.toSql() + ")"); + } + + @Override + public SQLFilter visit(Not filter) { + SQLFilter expressionFilter = process(filter.expression()); + + if (expressionFilter.matchesAllRows()) { + return new MatchNoSQLFilter(); + } else if (expressionFilter.matchesNoRows()) { + return new MatchAllSQLFilter(); + } else { 
+ return new SimpleSQLFilter("NOT (" + expressionFilter.toSql() + ")"); + } + } + + /** + * Process any Filter object by making it accept this visitor. + * + * @param filter The filter to process + * @return The resulting SQLFilter + */ + public SQLFilter process(Filter filter) { + if (filter instanceof IsEqualTo) { + return visit((IsEqualTo) filter); + } else if (filter instanceof IsNotEqualTo) { + return visit((IsNotEqualTo) filter); + } else if (filter instanceof IsGreaterThan) { + return visit((IsGreaterThan) filter); + } else if (filter instanceof IsGreaterThanOrEqualTo) { + return visit((IsGreaterThanOrEqualTo) filter); + } else if (filter instanceof IsLessThan) { + return visit((IsLessThan) filter); + } else if (filter instanceof IsLessThanOrEqualTo) { + return visit((IsLessThanOrEqualTo) filter); + } else if (filter instanceof IsIn) { + return visit((IsIn) filter); + } else if (filter instanceof IsNotIn) { + return visit((IsNotIn) filter); + } else if (filter instanceof And) { + return visit((And) filter); + } else if (filter instanceof Or) { + return visit((Or) filter); + } else if (filter instanceof Not) { + return visit((Not) filter); + } else { + throw new IllegalArgumentException("Unsupported filter type: " + filter.getClass().getName()); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java new file mode 100644 index 00000000..49ca90c8 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java @@ -0,0 +1,107 @@ +package dev.langchain4j.community.store.embedding.oceanbase.sql; + +/** + * Factory class for creating SQL filter instances. + */ +public class SQLFilters { + + /** + * Creates a SQLFilter that matches all rows. 
+ * + * @return A SQLFilter that matches all rows + */ + public static SQLFilter matchAllRows() { + return new MatchAllSQLFilter(); + } + + /** + * Creates a SQLFilter that matches no rows. + * + * @return A SQLFilter that matches no rows + */ + public static SQLFilter matchNoRows() { + return new MatchNoSQLFilter(); + } + + /** + * Creates a SQLFilter with a simple SQL expression. + * + * @param sql The SQL expression + * @return A SQLFilter with the given SQL expression + */ + public static SQLFilter simple(String sql) { + return new SimpleSQLFilter(sql); + } + + /** + * Implementation of SQLFilter that matches all rows. + */ + public static class MatchAllSQLFilter implements SQLFilter { + @Override + public String toSql() { + return "1=1"; + } + + @Override + public boolean matchesAllRows() { + return true; + } + + @Override + public boolean matchesNoRows() { + return false; + } + } + + /** + * Implementation of SQLFilter that matches no rows. + */ + public static class MatchNoSQLFilter implements SQLFilter { + @Override + public String toSql() { + return "1=0"; + } + + @Override + public boolean matchesAllRows() { + return false; + } + + @Override + public boolean matchesNoRows() { + return true; + } + } + + /** + * Implementation of SQLFilter with a simple SQL expression. + */ + public static class SimpleSQLFilter implements SQLFilter { + private final String sql; + + /** + * Creates a new SimpleSQLFilter with the given SQL expression. 
/**
 * Builder for SQL queries.
 * Implements the Builder pattern to create SQL queries in a fluent, readable way.
 * <p>
 * The builder only concatenates the fragments it is given; it performs no quoting or
 * validation of identifiers and conditions. {@link #build()} does not modify the builder,
 * so it may be called repeatedly and always renders the current state.
 */
public class SqlQueryBuilder {
    private final List<String> selectColumns = new ArrayList<>();
    private final List<String> orderByColumns = new ArrayList<>();
    private String fromTable;
    private String whereClause;
    private boolean approximate = false;
    private Integer limit;

    /**
     * Start building a SELECT query.
     *
     * @return a new, empty builder
     */
    public static SqlQueryBuilder select() {
        return new SqlQueryBuilder();
    }

    /**
     * Add columns to the SELECT clause.
     *
     * @param columns The columns to select, appended in the given order
     * @return this builder
     */
    public SqlQueryBuilder columns(String... columns) {
        for (String column : columns) {
            selectColumns.add(column);
        }
        return this;
    }

    /**
     * Add a function-based column to the SELECT clause, rendered as {@code name(arg1, arg2, ...)}.
     *
     * @param functionName The name of the function
     * @param args         The function arguments
     * @return this builder
     */
    public SqlQueryBuilder function(String functionName, String... args) {
        selectColumns.add(functionCall(functionName, args));
        return this;
    }

    /**
     * Add an alias to the last added column. Does nothing when no column has been added yet.
     *
     * @param alias The alias name
     * @return this builder
     */
    public SqlQueryBuilder as(String alias) {
        if (!selectColumns.isEmpty()) {
            int lastIndex = selectColumns.size() - 1;
            selectColumns.set(lastIndex, selectColumns.get(lastIndex) + " AS " + alias);
        }
        return this;
    }

    /**
     * Set the FROM clause.
     *
     * @param table The table name
     * @return this builder
     */
    public SqlQueryBuilder from(String table) {
        this.fromTable = table;
        return this;
    }

    /**
     * Set the WHERE clause. A {@code null} or empty condition results in no WHERE clause.
     *
     * @param condition The WHERE condition
     * @return this builder
     */
    public SqlQueryBuilder where(String condition) {
        this.whereClause = condition;
        return this;
    }

    /**
     * Add an ORDER BY clause.
     *
     * @param columns The columns to order by, appended in the given order
     * @return this builder
     */
    public SqlQueryBuilder orderBy(String... columns) {
        for (String column : columns) {
            orderByColumns.add(column);
        }
        return this;
    }

    /**
     * Add an ORDER BY clause with a function, rendered as {@code name(arg1, arg2, ...)}.
     *
     * @param functionName The name of the function
     * @param args         The function arguments
     * @return this builder
     */
    public SqlQueryBuilder orderByFunction(String functionName, String... args) {
        orderByColumns.add(functionCall(functionName, args));
        return this;
    }

    /**
     * Use approximate search (for vector search in OceanBase).
     * The keyword is only emitted when an ORDER BY clause is present.
     *
     * @return this builder
     */
    public SqlQueryBuilder approximate() {
        this.approximate = true;
        return this;
    }

    /**
     * Set the LIMIT clause.
     *
     * @param limit The maximum number of rows to return
     * @return this builder
     */
    public SqlQueryBuilder limit(int limit) {
        this.limit = limit;
        return this;
    }

    /**
     * Build the SQL query string.
     *
     * @return The complete SQL query
     * @throws IllegalStateException if no SELECT column or no FROM table has been specified
     */
    public String build() {
        if (selectColumns.isEmpty()) {
            throw new IllegalStateException("No columns specified for SELECT");
        }
        if (fromTable == null) {
            throw new IllegalStateException("No table specified for FROM");
        }

        // A fresh buffer per call: reusing a shared one made a second build() return the query twice.
        StringBuilder sql = new StringBuilder();
        sql.append("SELECT ").append(String.join(", ", selectColumns));
        sql.append(" FROM ").append(fromTable);

        if (whereClause != null && !whereClause.isEmpty()) {
            sql.append(" WHERE ").append(whereClause);
        }

        if (!orderByColumns.isEmpty()) {
            sql.append(" ORDER BY ").append(String.join(", ", orderByColumns));

            // APPROXIMATE qualifies the ORDER BY, so it is placed directly after it.
            if (approximate) {
                sql.append(" APPROXIMATE");
            }
        }

        if (limit != null) {
            sql.append(" LIMIT ").append(limit);
        }

        return sql.toString();
    }

    /** Renders {@code name(arg1, arg2, ...)}. */
    private static String functionCall(String functionName, String... args) {
        return functionName + "(" + String.join(", ", args) + ")";
    }
}
org.testcontainers.containers.GenericContainer; +import org.testcontainers.containers.wait.strategy.Wait; + +/** + * Common operations for tests. + */ +public class CommonTestOperations { + + private static final Logger log = LoggerFactory.getLogger(CommonTestOperations.class); + + private static final String TABLE_NAME = "langchain4j_oceanbase_test_" + UUID.randomUUID().toString().replace("-", ""); + private static final int VECTOR_DIM = 3; + + // OceanBase container configuration + private static final int OCEANBASE_PORT = 2881; // Default OceanBase port + private static GenericContainer oceanBaseContainer; + private static DataSource dataSource; + private static EmbeddingModel embeddingModel; + + static { + startOceanBaseContainer(); + } + + /** + * Starts the OceanBase container for tests. + */ + private static void startOceanBaseContainer() { + try { + // Using OceanBase's official Docker image + oceanBaseContainer = new GenericContainer<>("oceanbase/oceanbase-ce:4.3.5-lts") + .withExposedPorts(OCEANBASE_PORT) + .withEnv("MODE", "standalone") // For single-node deployment + // Wait for boot success message + .waitingFor(Wait.forLogMessage(".*boot success!.*", 1)); + + // Start the container + oceanBaseContainer.start(); + + // Get the mapped port and host + String jdbcUrl = String.format("jdbc:oceanbase://%s:%d/test", + oceanBaseContainer.getHost(), + oceanBaseContainer.getMappedPort(OCEANBASE_PORT)); + + log.info("OceanBase container started at {}", jdbcUrl); + + // Create a data source with the container's connection info + dataSource = new SimpleDataSource(jdbcUrl, "root@test", ""); + + // Create test database and setup + try (Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement()) { + // Set memory limit for vector columns + stmt.execute("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30;"); + // Create a test database if needed + stmt.execute("CREATE DATABASE IF NOT EXISTS test_example"); + stmt.execute("USE 
test_example"); + // Add any other initialization SQL you need + } catch (SQLException e) { + log.error("Failed to initialize OceanBase container", e); + throw new RuntimeException("Failed to initialize OceanBase container", e); + } + } catch (Exception e) { + log.error("Failed to start OceanBase container", e); + throw new RuntimeException("Failed to start OceanBase container", e); + } + } + + /** + * Creates a new embedding store for testing. + * + * @return A new embedding store. + */ + public static OceanBaseEmbeddingStore newEmbeddingStore() { + return OceanBaseEmbeddingStore.builder(getDataSource()) + .embeddingTable( + EmbeddingTable.builder(TABLE_NAME) + .vectorDimension(VECTOR_DIM) + .createOption(CreateOption.CREATE_OR_REPLACE) + .build()) + .build(); + } + + + /** + * Returns the data source for connecting to OceanBase. + * + * @return The data source. + */ + public static DataSource getDataSource() { + if (dataSource == null) { + throw new IllegalStateException("DataSource is not initialized. Container might not have started properly."); + } + return dataSource; + } + + /** + * Simple DataSource implementation for tests. 
+ */ + private static class SimpleDataSource implements DataSource { + private final String url; + private final String user; + private final String password; + + public SimpleDataSource(String url, String user, String password) { + this.url = url; + this.user = user; + this.password = password; + } + + @Override + public Connection getConnection() throws SQLException { + return DriverManager.getConnection(url, user, password); + } + + @Override + public Connection getConnection(String username, String password) throws SQLException { + return DriverManager.getConnection(url, username, password); + } + + @Override + public java.io.PrintWriter getLogWriter() throws SQLException { + return null; + } + + @Override + public void setLogWriter(java.io.PrintWriter out) throws SQLException { + } + + @Override + public void setLoginTimeout(int seconds) throws SQLException { + } + + @Override + public int getLoginTimeout() throws SQLException { + return 0; + } + + @Override + public java.util.logging.Logger getParentLogger() { + return null; + } + + @Override + public T unwrap(Class iface) throws SQLException { + throw new SQLException("Unwrapping not supported"); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + return false; + } + } + + /** + * Drops the test table. + * + * @throws SQLException If an error occurs. + */ + public static void dropTable() throws SQLException { + try (Connection connection = getDataSource().getConnection(); + Statement statement = connection.createStatement()) { + statement.execute("DROP TABLE IF EXISTS " + TABLE_NAME); + log.info("Dropped table {}", TABLE_NAME); + } + } + + /** + * Stops the OceanBase container. 
+ */ + public static void stopContainer() { + if (oceanBaseContainer != null && oceanBaseContainer.isRunning()) { + oceanBaseContainer.stop(); + log.info("OceanBase container stopped"); + } + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java new file mode 100644 index 00000000..7fb20c78 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java @@ -0,0 +1,242 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.BDDAssertions.within; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; + +/** + * Integration test for OceanBaseEmbeddingStore. 
+ */ +public class OceanBaseEmbeddingStoreIT { + + private OceanBaseEmbeddingStore embeddingStore; + + @BeforeEach + void setUp() { + embeddingStore = CommonTestOperations.newEmbeddingStore(); + } + + @BeforeAll + static void setUpAll() { + // The container is started automatically in CommonTestOperations static block + } + + @AfterAll + static void cleanUp() throws SQLException { + try { + CommonTestOperations.dropTable(); + } finally { + CommonTestOperations.stopContainer(); + } + } + + @Test + void should_add_embedding_and_find_it_by_similarity() { + // Given + Embedding embedding = TestData.randomEmbedding(); + + // When + String id = embeddingStore.add(embedding); + + // Then + List> matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(1) + .build()) + .matches(); + + assertThat(matches).hasSize(1); + assertThat(matches.get(0).embeddingId()).isEqualTo(id); + assertThat(matches.get(0).embedding().vector()).usingComparatorWithPrecision(0.0001f).containsExactly(embedding.vector()); + assertThat(matches.get(0).score()).isCloseTo(1.0, within(0.01)); + } + + @Test + void should_add_embedding_with_text_segment_and_find_it_by_similarity() { + // Given + Embedding embedding = TestData.randomEmbedding(); + TextSegment segment = TextSegment.from("Test text", Metadata.from("key", "value")); + + // When + String id = embeddingStore.add(embedding, segment); + + // Then + List> matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(1) + .build()) + .matches(); + + assertThat(matches).hasSize(1); + assertThat(matches.get(0).embeddingId()).isEqualTo(id); + assertThat(matches.get(0).embedding().vector()).usingComparatorWithPrecision(0.0001f).containsExactly(embedding.vector()); + assertThat(matches.get(0).embedded().text()).isEqualTo("Test text"); + assertThat(matches.get(0).embedded().metadata().getString("key")).isEqualTo("value"); + } + + @Test + void 
should_add_multiple_embeddings_and_find_them_by_similarity() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + + // When + List ids = embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + // Then + List> matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(3) + .build()) + .matches(); + + assertThat(matches).hasSize(3); + + // The query vector [0.9, 1.0, 0.9] should be closest to fruits + // Order should be: orange, banana, apple + assertThat(matches.get(0).embedded().text()).isEqualTo("橙子"); + assertThat(matches.get(1).embedded().text()).isEqualTo("香蕉"); + assertThat(matches.get(2).embedded().text()).isEqualTo("苹果"); + } + + @Test + void should_remove_embeddings_by_collection() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + List ids = embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + // When - remove first 3 embeddings (fruits) + embeddingStore.removeAll(ids.subList(0, 3)); + + // Then - only vegetables should remain + List> matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .build()) + .matches(); + + assertThat(matches).hasSize(3); + assertThat(matches).extracting(match -> match.embedded().metadata().getString("type")) + .containsOnly("vegetable"); + } + + @Test + void should_remove_embeddings_by_filter() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + // When - remove all fruits + embeddingStore.removeAll(MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit")); + + // Then - only vegetables should remain + List> matches = 
embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .build()) + .matches(); + + assertThat(matches).hasSize(3); + assertThat(matches).extracting(match -> match.embedded().metadata().getString("type")) + .containsOnly("vegetable"); + } + + @Test + void should_filter_by_metadata() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + // When: filter by fruits only + List> matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(6) + .filter(MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit")) + .build()) + .matches(); + + // Then + assertThat(matches).hasSize(3); + assertThat(matches).extracting(match -> match.embedded().metadata().getString("type")) + .containsOnly("fruit"); + + // When: filter by red color + matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(6) + .filter(MetadataFilterBuilder.metadataKey("color").isEqualTo("red")) + .build()) + .matches(); + + // Then + assertThat(matches).hasSize(2); + assertThat(matches).extracting(match -> match.embedded().metadata().getString("color")) + .containsOnly("red"); + } + + @Test + void should_remove_embedding() { + // Given + Embedding embedding = TestData.randomEmbedding(); + String id = embeddingStore.add(embedding); + + // When + embeddingStore.remove(id); + + // Then + List> matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(1) + .build()) + .matches(); + + assertThat(matches).isEmpty(); + } + + @Test + void should_remove_all_embeddings() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + 
embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + // When + embeddingStore.removeAll(); + + // Then + List> matches = embeddingStore.search( + EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .build()) + .matches(); + + assertThat(matches).isEmpty(); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java new file mode 100644 index 00000000..11e664e5 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java @@ -0,0 +1,81 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; + +import java.util.Map; +import java.util.Random; + +/** + * Utility class supplying the fixed and seeded pseudo-random data used by the OceanBase embedding store tests. + */ +public class TestData { + + private static final Random random = new Random(666); + + /** + * Creates an embedding of the given dimension filled with values from the shared generator (seeded with 666). + * + * @param dimension Number of components in the vector. + * @return A new embedding with {@code dimension} pseudo-random components. + */ + public static Embedding randomEmbedding(int dimension) { + float[] vector = new float[dimension]; + for (int i = 0; i < dimension; i++) { + vector[i] = random.nextFloat(); + } + return new Embedding(vector); + } + + /** + * Creates a random embedding with default dimension (3). + * + * @return A random embedding with default dimension. + */ + public static Embedding randomEmbedding() { + return randomEmbedding(3); + } + + /** + * Creates the six fixed, three-dimensional sample embeddings (embeddings only, no text segments). + * The fruit and vegetable text segments are supplied separately by {@link #sampleTextSegments()}.
+ * + * @return An array of sample embeddings. + */ + public static Embedding[] sampleEmbeddings() { + return new Embedding[]{ + new Embedding(new float[]{1.2f, 0.7f, 1.1f}), + new Embedding(new float[]{0.6f, 1.2f, 0.8f}), + new Embedding(new float[]{1.1f, 1.1f, 0.9f}), + new Embedding(new float[]{5.3f, 4.8f, 5.4f}), + new Embedding(new float[]{4.9f, 5.3f, 4.8f}), + new Embedding(new float[]{5.2f, 4.9f, 5.1f}) + }; + } + + /** + * Creates six sample text segments named in Chinese (apple, banana, orange, carrot, spinach, tomato) with type and color metadata. + * + * @return An array of sample text segments. + */ + public static TextSegment[] sampleTextSegments() { + return new TextSegment[]{ + TextSegment.from("苹果", Metadata.from(Map.of("type", "fruit", "color", "red"))), + TextSegment.from("香蕉", Metadata.from(Map.of("type", "fruit", "color", "yellow"))), + TextSegment.from("橙子", Metadata.from(Map.of("type", "fruit", "color", "orange"))), + TextSegment.from("胡萝卜", Metadata.from(Map.of("type", "vegetable", "color", "orange"))), + TextSegment.from("菠菜", Metadata.from(Map.of("type", "vegetable", "color", "green"))), + TextSegment.from("西红柿", Metadata.from(Map.of("type", "vegetable", "color", "red"))) + }; + } + + /** + * Creates the fixed query embedding (0.9, 1.0, 0.9), whose components are close to the first three sample embeddings. + * + * @return A query embedding.
+ */ + public static Embedding queryEmbedding() { + return new Embedding(new float[]{0.9f, 1.0f, 0.9f}); + } +} diff --git a/langchain4j-community-bom/pom.xml b/langchain4j-community-bom/pom.xml index 43819eae..caeba6b4 100644 --- a/langchain4j-community-bom/pom.xml +++ b/langchain4j-community-bom/pom.xml @@ -103,6 +103,12 @@ ${project.version} + + dev.langchain4j + langchain4j-community-oceanbase + ${project.version} + + dev.langchain4j diff --git a/pom.xml b/pom.xml index 9d254fb0..9e2d5cf6 100644 --- a/pom.xml +++ b/pom.xml @@ -48,6 +48,7 @@ embedding-stores/langchain4j-community-alloydb-pg embedding-stores/langchain4j-community-cloud-sql-pg embedding-stores/langchain4j-community-neo4j + embedding-stores/langchain4j-community-oceanbase content-retrievers/langchain4j-community-lucene From bbba7de13b9dd38c66149d9445e845979323cae7 Mon Sep 17 00:00:00 2001 From: Martin7-1 Date: Tue, 24 Jun 2025 23:55:46 +0800 Subject: [PATCH 2/6] feat: make format --- .../langchain4j-community-oceanbase/pom.xml | 4 +- .../embedding/oceanbase/EmbeddingTable.java | 37 ++++-- .../oceanbase/MetadataKeyMapper.java | 4 +- .../oceanbase/OceanBaseEmbeddingStore.java | 81 ++++++------ .../distance/CosineDistanceConverter.java | 2 +- .../distance/DefaultDistanceConverter.java | 2 +- .../oceanbase/distance/DistanceConverter.java | 4 +- .../distance/DistanceConverterFactory.java | 8 +- .../distance/EuclideanDistanceConverter.java | 2 +- .../search/OceanBaseSearchTemplate.java | 51 +++----- .../oceanbase/search/SearchTemplate.java | 19 ++- .../oceanbase/sql/SQLFilterFactory.java | 3 +- .../oceanbase/sql/SQLFilterVisitor.java | 69 +++++----- .../embedding/oceanbase/sql/SQLFilters.java | 1 - .../oceanbase/CommonTestOperations.java | 36 +++--- .../oceanbase/OceanBaseEmbeddingStoreIT.java | 122 +++++++++--------- .../store/embedding/oceanbase/TestData.java | 31 +++-- 17 files changed, 234 insertions(+), 242 deletions(-) diff --git a/embedding-stores/langchain4j-community-oceanbase/pom.xml
b/embedding-stores/langchain4j-community-oceanbase/pom.xml index 20a7afb4..97c5d573 100644 --- a/embedding-stores/langchain4j-community-oceanbase/pom.xml +++ b/embedding-stores/langchain4j-community-oceanbase/pom.xml @@ -1,7 +1,5 @@ - + 4.0.0 dev.langchain4j diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java index d54dede2..6f4d785f 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java @@ -86,21 +86,38 @@ void create(DataSource dataSource) throws SQLException { if (createOption == CreateOption.CREATE_NONE) return; try (Connection connection = dataSource.getConnection(); - Statement statement = connection.createStatement()) { + Statement statement = connection.createStatement()) { if (createOption == CreateOption.CREATE_OR_REPLACE) { statement.addBatch("DROP TABLE IF EXISTS " + name); } StringBuilder createTableSql = new StringBuilder(); - createTableSql.append("CREATE TABLE IF NOT EXISTS ").append(name) - .append("(").append(idColumn).append(" VARCHAR(36) NOT NULL, ") - .append(embeddingColumn).append(" VECTOR(").append(vectorDimension).append("), ") - .append(textColumn).append(" VARCHAR(4000), ") - .append(metadataColumn).append(" JSON, ") - .append("PRIMARY KEY (").append(idColumn).append("), ") - .append("VECTOR INDEX ").append(vectorIndexName).append("(") - .append(embeddingColumn).append(") WITH (distance=").append(distanceMetric) - .append(", type=").append(indexType).append("))"); + createTableSql + .append("CREATE TABLE IF NOT EXISTS ") + .append(name) + .append("(") + .append(idColumn) + 
.append(" VARCHAR(36) NOT NULL, ") + .append(embeddingColumn) + .append(" VECTOR(") + .append(vectorDimension) + .append("), ") + .append(textColumn) + .append(" VARCHAR(4000), ") + .append(metadataColumn) + .append(" JSON, ") + .append("PRIMARY KEY (") + .append(idColumn) + .append("), ") + .append("VECTOR INDEX ") + .append(vectorIndexName) + .append("(") + .append(embeddingColumn) + .append(") WITH (distance=") + .append(distanceMetric) + .append(", type=") + .append(indexType) + .append("))"); statement.addBatch(createTableSql.toString()); statement.executeBatch(); diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java index b06a5538..1a60fc9d 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/MetadataKeyMapper.java @@ -5,7 +5,7 @@ * This allows for different mapping strategies without exposing internal implementations. */ public interface MetadataKeyMapper { - + /** * Maps a metadata key to a SQL expression that can be used in queries. 
* @@ -13,7 +13,7 @@ public interface MetadataKeyMapper { * @return A SQL expression that references the given metadata key */ String mapKey(String key); - + /** * Default implementation that uses JSON_VALUE function */ diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java index 9da74db3..e8d50323 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java @@ -3,6 +3,11 @@ import static dev.langchain4j.internal.Utils.randomUUID; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import dev.langchain4j.community.store.embedding.oceanbase.search.OceanBaseSearchTemplate; +import dev.langchain4j.community.store.embedding.oceanbase.search.SearchTemplate; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilter; import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilterFactory; import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; @@ -11,19 +16,11 @@ import dev.langchain4j.store.embedding.EmbeddingSearchResult; import dev.langchain4j.store.embedding.EmbeddingStore; import dev.langchain4j.store.embedding.filter.Filter; -import dev.langchain4j.community.store.embedding.oceanbase.search.OceanBaseSearchTemplate; -import dev.langchain4j.community.store.embedding.oceanbase.search.SearchTemplate; - -import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilter; -import 
org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.sql.*; import java.util.*; import javax.sql.DataSource; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * OceanBase Vector EmbeddingStore Implementation @@ -98,8 +95,8 @@ private static RuntimeException uncheckSQLException(SQLException sqlException) { */ private static T ensureIndexNotNull(List list, int index, String name) { if (index < 0 || index >= list.size()) { - throw new IllegalArgumentException(String.format( - "Index %d is out of bounds for %s list of size %d", index, name, list.size())); + throw new IllegalArgumentException( + String.format("Index %d is out of bounds for %s list of size %d", index, name, list.size())); } T item = list.get(index); if (item == null) { @@ -130,13 +127,12 @@ public void add(String id, Embedding embedding) { ensureNotEmpty(id, "id"); ensureNotNull(embedding, "embedding"); - String sql = "INSERT INTO " + table.name() + " (" + - table.idColumn() + ", " + - table.embeddingColumn() + - ") VALUES (?, ?)"; + String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() + + ", " + table.embeddingColumn() + + ") VALUES (?, ?)"; try (Connection connection = dataSource.getConnection(); - PreparedStatement statement = connection.prepareStatement(sql)) { + PreparedStatement statement = connection.prepareStatement(sql)) { statement.setString(1, id); statement.setString(2, vectorToString(embedding.vector())); statement.executeUpdate(); @@ -164,13 +160,12 @@ public List addAll(List embeddings) { List ids = generateIds(embeddings.size()); - String sql = "INSERT INTO " + table.name() + " (" + - table.idColumn() + ", " + - table.embeddingColumn() + - ") VALUES (?, ?)"; + String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() + + ", " + table.embeddingColumn() + + ") VALUES (?, ?)"; try (Connection connection = 
dataSource.getConnection(); - PreparedStatement statement = connection.prepareStatement(sql)) { + PreparedStatement statement = connection.prepareStatement(sql)) { for (int i = 0; i < ids.size(); i++) { statement.setString(1, ids.get(i)); @@ -192,15 +187,14 @@ public String add(Embedding embedding, TextSegment textSegment) { String id = randomUUID(); - String sql = "INSERT INTO " + table.name() + " (" + - table.idColumn() + ", " + - table.embeddingColumn() + ", " + - table.textColumn() + ", " + - table.metadataColumn() + - ") VALUES (?, ?, ?, ?)"; + String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() + + ", " + table.embeddingColumn() + + ", " + table.textColumn() + + ", " + table.metadataColumn() + + ") VALUES (?, ?, ?, ?)"; try (Connection connection = dataSource.getConnection(); - PreparedStatement statement = connection.prepareStatement(sql)) { + PreparedStatement statement = connection.prepareStatement(sql)) { statement.setString(1, id); statement.setString(2, vectorToString(embedding.vector())); statement.setString(3, textSegment.text()); @@ -248,15 +242,14 @@ public void addAll(List ids, List embeddings, List ids) { String sql = "DELETE FROM " + table.name() + " WHERE " + table.idColumn() + " IN (" + placeholders + ")"; try (Connection connection = dataSource.getConnection(); - PreparedStatement statement = connection.prepareStatement(sql)) { + PreparedStatement statement = connection.prepareStatement(sql)) { int index = 1; for (String id : ids) { @@ -327,8 +320,7 @@ public void removeAll(Collection ids) { public void removeAll(Filter filter) { ensureNotNull(filter, "filter"); - SQLFilter sqlFilter = SQLFilterFactory.create( - filter, (key, value) -> table.mapMetadataKey(key)); + SQLFilter sqlFilter = SQLFilterFactory.create(filter, (key, value) -> table.mapMetadataKey(key)); if (sqlFilter.matchesNoRows()) { return; @@ -342,7 +334,7 @@ public void removeAll(Filter filter) { } try (Connection connection = dataSource.getConnection(); - 
Statement statement = connection.createStatement()) { + Statement statement = connection.createStatement()) { statement.executeUpdate(sql); } catch (SQLException sqlException) { throw uncheckSQLException(sqlException); @@ -354,7 +346,7 @@ public void removeAll() { String sql = "DELETE FROM " + table.name(); try (Connection connection = dataSource.getConnection(); - Statement statement = connection.createStatement()) { + Statement statement = connection.createStatement()) { statement.executeUpdate(sql); } catch (SQLException sqlException) { throw uncheckSQLException(sqlException); @@ -456,9 +448,8 @@ public Builder embeddingTable(String tableName) { public Builder embeddingTable(String tableName, CreateOption createOption) { ensureNotNull(tableName, "tableName"); ensureNotNull(createOption, "createOption"); - this.embeddingTable = EmbeddingTable.builder(tableName) - .createOption(createOption) - .build(); + this.embeddingTable = + EmbeddingTable.builder(tableName).createOption(createOption).build(); return this; } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java index 349d0c38..57af43ab 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/CosineDistanceConverter.java @@ -5,7 +5,7 @@ * For cosine distance, similarity = 1 - distance */ public class CosineDistanceConverter implements DistanceConverter { - + @Override public double toSimilarity(double distance) { return 1.0 - distance; diff --git 
a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java index 413060f1..79bef41e 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java @@ -6,7 +6,7 @@ * This works well for most distance metrics where smaller distance means higher similarity. */ public class DefaultDistanceConverter implements DistanceConverter { - + @Override public double toSimilarity(double distance) { return Math.exp(-distance); diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java index b45a8cda..39d34f75 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverter.java @@ -5,10 +5,10 @@ * Different distance metrics require different conversion formulas. */ public interface DistanceConverter { - + /** * Converts a distance value to a similarity score. 
- * + * * @param distance The distance value to convert * @return A similarity score between 0 and 1, where 1 means exact match */ diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java index f824760a..c1e3168b 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java @@ -5,10 +5,10 @@ * Implements the Factory Method pattern. */ public class DistanceConverterFactory { - + /** * Returns the appropriate converter for the given distance metric. - * + * * @param metric The distance metric name (e.g., "cosine", "euclidean") * @return A DistanceConverter appropriate for the given metric */ @@ -16,9 +16,9 @@ public static DistanceConverter getConverter(String metric) { if (metric == null) { return new DefaultDistanceConverter(); } - + metric = metric.toLowerCase(); - + switch (metric) { case "cosine": return new CosineDistanceConverter(); diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java index 8979db87..94495801 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java +++ 
b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java @@ -5,7 +5,7 @@ * For euclidean distance, similarity = 1 / (1 + distance) */ public class EuclideanDistanceConverter implements DistanceConverter { - + @Override public double toSimilarity(double distance) { return 1.0 / (1.0 + distance); diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java index 6b775406..b864f016 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/OceanBaseSearchTemplate.java @@ -1,28 +1,25 @@ package dev.langchain4j.community.store.embedding.oceanbase.search; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilterFactory; -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; import dev.langchain4j.community.store.embedding.oceanbase.EmbeddingTable; import dev.langchain4j.community.store.embedding.oceanbase.distance.DistanceConverter; import dev.langchain4j.community.store.embedding.oceanbase.distance.DistanceConverterFactory; import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilter; +import 
dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilterFactory; import dev.langchain4j.community.store.embedding.oceanbase.sql.SqlQueryBuilder; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; +import dev.langchain4j.data.document.Metadata; +import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.filter.Filter; -import dev.langchain4j.store.embedding.EmbeddingMatch; - +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; /** * Concrete implementation of SearchTemplate for OceanBase. @@ -140,11 +137,7 @@ protected String buildSearchQuery(EmbeddingSearchRequest request) { // Use SqlQueryBuilder (Builder pattern) to construct the query SqlQueryBuilder builder = SqlQueryBuilder.select() .columns(table.idColumn(), table.embeddingColumn(), table.textColumn(), table.metadataColumn()) - .function( - table.distanceMetric().toLowerCase() + "_distance", - table.embeddingColumn(), - "?" - ) + .function(table.distanceMetric().toLowerCase() + "_distance", table.embeddingColumn(), "?") .from(table.name()); // Add WHERE clause for filtering @@ -157,11 +150,7 @@ protected String buildSearchQuery(EmbeddingSearchRequest request) { } // Add ORDER BY clause - builder.orderByFunction( - table.distanceMetric().toLowerCase() + "_distance", - table.embeddingColumn(), - "?" 
- ); + builder.orderByFunction(table.distanceMetric().toLowerCase() + "_distance", table.embeddingColumn(), "?"); // Use APPROXIMATE keyword if not exact search if (!isExactSearch) { @@ -182,7 +171,8 @@ protected void setParameters(PreparedStatement statement, EmbeddingSearchRequest } @Override - protected List> processResults(ResultSet resultSet, EmbeddingSearchRequest request) throws SQLException { + protected List> processResults(ResultSet resultSet, EmbeddingSearchRequest request) + throws SQLException { List> matches = new ArrayList<>(); while (resultSet.next()) { @@ -203,9 +193,7 @@ protected List> processResults(ResultSet resultSet, Embedding embedding = new Embedding(stringToVector(vectorStr)); Metadata metadata = jsonToMetadata(metadataJson); - TextSegment segment = (text != null) ? - TextSegment.from(text, metadata) : - null; + TextSegment segment = (text != null) ? TextSegment.from(text, metadata) : null; matches.add(new EmbeddingMatch<>(score, id, embedding, segment)); } @@ -219,8 +207,8 @@ protected String buildWhereClause(Filter filter) { return null; } - SQLFilter sqlFilter = SQLFilterFactory.create(filter, - (key, value) -> table.getMetadataKeyMapper().mapKey(key)); + SQLFilter sqlFilter = SQLFilterFactory.create( + filter, (key, value) -> table.getMetadataKeyMapper().mapKey(key)); if (sqlFilter.matchesAllRows()) { return null; @@ -229,4 +217,3 @@ protected String buildWhereClause(Filter filter) { return sqlFilter.toSql(); } } - diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java index 17785c29..3e6ddb73 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java +++ 
b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/search/SearchTemplate.java @@ -1,20 +1,17 @@ package dev.langchain4j.community.store.embedding.oceanbase.search; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.filter.Filter; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; - -import dev.langchain4j.store.embedding.EmbeddingMatch; -import dev.langchain4j.store.embedding.EmbeddingSearchRequest; -import dev.langchain4j.store.embedding.EmbeddingSearchResult; -import dev.langchain4j.store.embedding.filter.Filter; - import javax.sql.DataSource; - /** * Template for the search operation. * Implements the Template Method pattern to define the skeleton of the search algorithm, @@ -41,7 +38,7 @@ public final EmbeddingSearchResult search(DataSource dataSource, EmbeddingSea String query = buildSearchQuery(request); try (Connection connection = dataSource.getConnection(); - PreparedStatement statement = connection.prepareStatement(query)) { + PreparedStatement statement = connection.prepareStatement(query)) { setParameters(statement, request); @@ -86,7 +83,8 @@ protected void validateRequest(EmbeddingSearchRequest request) { * @param request The search request * @throws SQLException If an error occurs setting the parameters */ - protected abstract void setParameters(PreparedStatement statement, EmbeddingSearchRequest request) throws SQLException; + protected abstract void setParameters(PreparedStatement statement, EmbeddingSearchRequest request) + throws SQLException; /** * Processes the result set and creates embedding matches. 
@@ -97,7 +95,8 @@ protected void validateRequest(EmbeddingSearchRequest request) { * @return A list of embedding matches * @throws SQLException If an error occurs processing the result set */ - protected abstract List> processResults(ResultSet resultSet, EmbeddingSearchRequest request) throws SQLException; + protected abstract List> processResults(ResultSet resultSet, EmbeddingSearchRequest request) + throws SQLException; /** * Converts a filter to a WHERE clause. diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java index 5b418860..a9fbfb34 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterFactory.java @@ -1,8 +1,7 @@ package dev.langchain4j.community.store.embedding.oceanbase.sql; -import java.util.function.BiFunction; - import dev.langchain4j.store.embedding.filter.Filter; +import java.util.function.BiFunction; /** * Factory for creating SQL filters. 
diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java index c2de1d00..31a628c5 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilterVisitor.java @@ -1,10 +1,8 @@ package dev.langchain4j.community.store.embedding.oceanbase.sql; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.function.BiFunction; - +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchAllSQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchNoSQLFilter; +import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.SimpleSQLFilter; import dev.langchain4j.store.embedding.filter.Filter; import dev.langchain4j.store.embedding.filter.comparison.IsEqualTo; import dev.langchain4j.store.embedding.filter.comparison.IsGreaterThan; @@ -17,9 +15,10 @@ import dev.langchain4j.store.embedding.filter.logical.And; import dev.langchain4j.store.embedding.filter.logical.Not; import dev.langchain4j.store.embedding.filter.logical.Or; -import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchAllSQLFilter; -import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.MatchNoSQLFilter; -import dev.langchain4j.community.store.embedding.oceanbase.sql.SQLFilters.SimpleSQLFilter; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.function.BiFunction; /** * Concrete visitor implementation that converts Filter objects to SQLFilter objects. 
@@ -27,10 +26,10 @@ */ public class SQLFilterVisitor implements FilterVisitor { private final BiFunction keyMapper; - + /** * Creates a new SQLFilterVisitor with the given key mapper. - * + * * @param keyMapper Function that maps a key and value to a SQL column expression */ public SQLFilterVisitor(BiFunction keyMapper) { @@ -39,7 +38,7 @@ public SQLFilterVisitor(BiFunction keyMapper) { /** * Helper method to convert a value to its SQL string representation. - * + * * @param value The value to convert * @return SQL string representation of the value */ @@ -59,7 +58,7 @@ private String valueToSql(Object value) { public SQLFilter visit(IsEqualTo filter) { String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); Object value = filter.comparisonValue(); - + if (value == null) { return new SimpleSQLFilter(expr + " IS NULL"); } else { @@ -71,7 +70,7 @@ public SQLFilter visit(IsEqualTo filter) { public SQLFilter visit(IsNotEqualTo filter) { String expr = keyMapper.apply(filter.key(), filter.comparisonValue()); Object value = filter.comparisonValue(); - + if (value == null) { return new SimpleSQLFilter(expr + " IS NOT NULL"); } else { @@ -109,12 +108,12 @@ public SQLFilter visit(IsIn filter) { if (values == null || values.isEmpty()) { return new MatchNoSQLFilter(); } - + List valueStrings = new ArrayList<>(values.size()); for (Object value : values) { valueStrings.add(valueToSql(value)); } - + String expr = keyMapper.apply(filter.key(), values.iterator().next()); return new SimpleSQLFilter(expr + " IN (" + String.join(", ", valueStrings) + ")"); } @@ -125,42 +124,43 @@ public SQLFilter visit(IsNotIn filter) { if (values == null || values.isEmpty()) { return new MatchAllSQLFilter(); } - + List valueStrings = new ArrayList<>(values.size()); for (Object value : values) { valueStrings.add(valueToSql(value)); } - + String expr = keyMapper.apply(filter.key(), values.iterator().next()); - return new SimpleSQLFilter("(" + expr + " NOT IN (" + String.join(", ", 
valueStrings) + ") OR " + expr + " IS NULL)"); + return new SimpleSQLFilter( + "(" + expr + " NOT IN (" + String.join(", ", valueStrings) + ") OR " + expr + " IS NULL)"); } @Override public SQLFilter visit(And filter) { SQLFilter leftFilter = process(filter.left()); - + // If left side matches no rows, the entire AND also matches no rows if (leftFilter.matchesNoRows()) { return new MatchNoSQLFilter(); } - + // If left side matches all rows, the result depends on the right side if (leftFilter.matchesAllRows()) { return process(filter.right()); } - + SQLFilter rightFilter = process(filter.right()); - + // If right side matches no rows, the entire AND also matches no rows if (rightFilter.matchesNoRows()) { return new MatchNoSQLFilter(); } - + // If right side matches all rows, the result depends on the left side if (rightFilter.matchesAllRows()) { return leftFilter; } - + // Both sides need to participate in the AND operation return new SimpleSQLFilter("(" + leftFilter.toSql() + ") AND (" + rightFilter.toSql() + ")"); } @@ -168,29 +168,29 @@ public SQLFilter visit(And filter) { @Override public SQLFilter visit(Or filter) { SQLFilter leftFilter = process(filter.left()); - + // If left side matches all rows, the entire OR also matches all rows if (leftFilter.matchesAllRows()) { return new MatchAllSQLFilter(); } - + // If left side matches no rows, the result depends on the right side if (leftFilter.matchesNoRows()) { return process(filter.right()); } - + SQLFilter rightFilter = process(filter.right()); - + // If right side matches all rows, the entire OR also matches all rows if (rightFilter.matchesAllRows()) { return new MatchAllSQLFilter(); } - + // If right side matches no rows, the result depends on the left side if (rightFilter.matchesNoRows()) { return leftFilter; } - + // Both sides need to participate in the OR operation return new SimpleSQLFilter("(" + leftFilter.toSql() + ") OR (" + rightFilter.toSql() + ")"); } @@ -198,7 +198,7 @@ public SQLFilter visit(Or 
filter) { @Override public SQLFilter visit(Not filter) { SQLFilter expressionFilter = process(filter.expression()); - + if (expressionFilter.matchesAllRows()) { return new MatchNoSQLFilter(); } else if (expressionFilter.matchesNoRows()) { @@ -207,10 +207,10 @@ public SQLFilter visit(Not filter) { return new SimpleSQLFilter("NOT (" + expressionFilter.toSql() + ")"); } } - + /** * Process any Filter object by making it accept this visitor. - * + * * @param filter The filter to process * @return The resulting SQLFilter */ @@ -238,7 +238,8 @@ public SQLFilter process(Filter filter) { } else if (filter instanceof Not) { return visit((Not) filter); } else { - throw new IllegalArgumentException("Unsupported filter type: " + filter.getClass().getName()); + throw new IllegalArgumentException( + "Unsupported filter type: " + filter.getClass().getName()); } } } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java index 49ca90c8..1e2cafd3 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/sql/SQLFilters.java @@ -104,4 +104,3 @@ public boolean matchesNoRows() { } } } - diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java index 325d4eb1..025e3786 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java 
+++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java @@ -1,14 +1,12 @@ package dev.langchain4j.community.store.embedding.oceanbase; import dev.langchain4j.model.embedding.EmbeddingModel; - -import javax.sql.DataSource; import java.sql.Connection; import java.sql.DriverManager; import java.sql.SQLException; import java.sql.Statement; import java.util.UUID; - +import javax.sql.DataSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testcontainers.containers.GenericContainer; @@ -21,7 +19,8 @@ public class CommonTestOperations { private static final Logger log = LoggerFactory.getLogger(CommonTestOperations.class); - private static final String TABLE_NAME = "langchain4j_oceanbase_test_" + UUID.randomUUID().toString().replace("-", ""); + private static final String TABLE_NAME = + "langchain4j_oceanbase_test_" + UUID.randomUUID().toString().replace("-", ""); private static final int VECTOR_DIM = 3; // OceanBase container configuration @@ -50,9 +49,9 @@ private static void startOceanBaseContainer() { oceanBaseContainer.start(); // Get the mapped port and host - String jdbcUrl = String.format("jdbc:oceanbase://%s:%d/test", - oceanBaseContainer.getHost(), - oceanBaseContainer.getMappedPort(OCEANBASE_PORT)); + String jdbcUrl = String.format( + "jdbc:oceanbase://%s:%d/test", + oceanBaseContainer.getHost(), oceanBaseContainer.getMappedPort(OCEANBASE_PORT)); log.info("OceanBase container started at {}", jdbcUrl); @@ -61,7 +60,7 @@ private static void startOceanBaseContainer() { // Create test database and setup try (Connection conn = dataSource.getConnection(); - Statement stmt = conn.createStatement()) { + Statement stmt = conn.createStatement()) { // Set memory limit for vector columns stmt.execute("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30;"); // Create a test database if needed @@ -85,15 +84,13 @@ private static void startOceanBaseContainer() { 
*/ public static OceanBaseEmbeddingStore newEmbeddingStore() { return OceanBaseEmbeddingStore.builder(getDataSource()) - .embeddingTable( - EmbeddingTable.builder(TABLE_NAME) - .vectorDimension(VECTOR_DIM) - .createOption(CreateOption.CREATE_OR_REPLACE) - .build()) + .embeddingTable(EmbeddingTable.builder(TABLE_NAME) + .vectorDimension(VECTOR_DIM) + .createOption(CreateOption.CREATE_OR_REPLACE) + .build()) .build(); } - /** * Returns the data source for connecting to OceanBase. * @@ -101,7 +98,8 @@ public static OceanBaseEmbeddingStore newEmbeddingStore() { */ public static DataSource getDataSource() { if (dataSource == null) { - throw new IllegalStateException("DataSource is not initialized. Container might not have started properly."); + throw new IllegalStateException( + "DataSource is not initialized. Container might not have started properly."); } return dataSource; } @@ -136,12 +134,10 @@ public java.io.PrintWriter getLogWriter() throws SQLException { } @Override - public void setLogWriter(java.io.PrintWriter out) throws SQLException { - } + public void setLogWriter(java.io.PrintWriter out) throws SQLException {} @Override - public void setLoginTimeout(int seconds) throws SQLException { - } + public void setLoginTimeout(int seconds) throws SQLException {} @Override public int getLoginTimeout() throws SQLException { @@ -171,7 +167,7 @@ public boolean isWrapperFor(Class iface) throws SQLException { */ public static void dropTable() throws SQLException { try (Connection connection = getDataSource().getConnection(); - Statement statement = connection.createStatement()) { + Statement statement = connection.createStatement()) { statement.execute("DROP TABLE IF EXISTS " + TABLE_NAME); log.info("Dropped table {}", TABLE_NAME); } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java 
b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java index 7fb20c78..3844caa1 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java @@ -9,16 +9,14 @@ import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; - +import java.sql.SQLException; +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import java.sql.SQLException; -import java.util.Arrays; -import java.util.List; - /** * Integration test for OceanBaseEmbeddingStore. 
*/ @@ -54,16 +52,18 @@ void should_add_embedding_and_find_it_by_similarity() { String id = embeddingStore.add(embedding); // Then - List> matches = embeddingStore.search( - EmbeddingSearchRequest.builder() - .queryEmbedding(embedding) - .maxResults(1) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(1) + .build()) .matches(); assertThat(matches).hasSize(1); assertThat(matches.get(0).embeddingId()).isEqualTo(id); - assertThat(matches.get(0).embedding().vector()).usingComparatorWithPrecision(0.0001f).containsExactly(embedding.vector()); + assertThat(matches.get(0).embedding().vector()) + .usingComparatorWithPrecision(0.0001f) + .containsExactly(embedding.vector()); assertThat(matches.get(0).score()).isCloseTo(1.0, within(0.01)); } @@ -77,16 +77,18 @@ void should_add_embedding_with_text_segment_and_find_it_by_similarity() { String id = embeddingStore.add(embedding, segment); // Then - List> matches = embeddingStore.search( - EmbeddingSearchRequest.builder() - .queryEmbedding(embedding) - .maxResults(1) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(1) + .build()) .matches(); assertThat(matches).hasSize(1); assertThat(matches.get(0).embeddingId()).isEqualTo(id); - assertThat(matches.get(0).embedding().vector()).usingComparatorWithPrecision(0.0001f).containsExactly(embedding.vector()); + assertThat(matches.get(0).embedding().vector()) + .usingComparatorWithPrecision(0.0001f) + .containsExactly(embedding.vector()); assertThat(matches.get(0).embedded().text()).isEqualTo("Test text"); assertThat(matches.get(0).embedded().metadata().getString("key")).isEqualTo("value"); } @@ -101,11 +103,11 @@ void should_add_multiple_embeddings_and_find_them_by_similarity() { List ids = embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); // Then - List> matches = embeddingStore.search( - 
EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(3) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(3) + .build()) .matches(); assertThat(matches).hasSize(3); @@ -128,15 +130,16 @@ void should_remove_embeddings_by_collection() { embeddingStore.removeAll(ids.subList(0, 3)); // Then - only vegetables should remain - List> matches = embeddingStore.search( - EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .build()) .matches(); assertThat(matches).hasSize(3); - assertThat(matches).extracting(match -> match.embedded().metadata().getString("type")) + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("type")) .containsOnly("vegetable"); } @@ -151,15 +154,16 @@ void should_remove_embeddings_by_filter() { embeddingStore.removeAll(MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit")); // Then - only vegetables should remain - List> matches = embeddingStore.search( - EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .build()) .matches(); assertThat(matches).hasSize(3); - assertThat(matches).extracting(match -> match.embedded().metadata().getString("type")) + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("type")) .containsOnly("vegetable"); } @@ -171,31 +175,33 @@ void should_filter_by_metadata() { embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); // When: filter by fruits only - List> matches = embeddingStore.search( - 
EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(6) - .filter(MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit")) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(6) + .filter(MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit")) + .build()) .matches(); // Then assertThat(matches).hasSize(3); - assertThat(matches).extracting(match -> match.embedded().metadata().getString("type")) + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("type")) .containsOnly("fruit"); // When: filter by red color - matches = embeddingStore.search( - EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(6) - .filter(MetadataFilterBuilder.metadataKey("color").isEqualTo("red")) - .build()) + matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(6) + .filter(MetadataFilterBuilder.metadataKey("color").isEqualTo("red")) + .build()) .matches(); // Then assertThat(matches).hasSize(2); - assertThat(matches).extracting(match -> match.embedded().metadata().getString("color")) + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("color")) .containsOnly("red"); } @@ -209,11 +215,11 @@ void should_remove_embedding() { embeddingStore.remove(id); // Then - List> matches = embeddingStore.search( - EmbeddingSearchRequest.builder() - .queryEmbedding(embedding) - .maxResults(1) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(embedding) + .maxResults(1) + .build()) .matches(); assertThat(matches).isEmpty(); @@ -230,11 +236,11 @@ void should_remove_all_embeddings() { embeddingStore.removeAll(); // Then - List> matches = embeddingStore.search( - EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - 
.maxResults(10) - .build()) + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .build()) .matches(); assertThat(matches).isEmpty(); diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java index 11e664e5..14e5c3c6 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java @@ -3,7 +3,6 @@ import dev.langchain4j.data.document.Metadata; import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; - import java.util.Map; import java.util.Random; @@ -44,13 +43,13 @@ public static Embedding randomEmbedding() { * @return An array of sample embeddings. */ public static Embedding[] sampleEmbeddings() { - return new Embedding[]{ - new Embedding(new float[]{1.2f, 0.7f, 1.1f}), - new Embedding(new float[]{0.6f, 1.2f, 0.8f}), - new Embedding(new float[]{1.1f, 1.1f, 0.9f}), - new Embedding(new float[]{5.3f, 4.8f, 5.4f}), - new Embedding(new float[]{4.9f, 5.3f, 4.8f}), - new Embedding(new float[]{5.2f, 4.9f, 5.1f}) + return new Embedding[] { + new Embedding(new float[] {1.2f, 0.7f, 1.1f}), + new Embedding(new float[] {0.6f, 1.2f, 0.8f}), + new Embedding(new float[] {1.1f, 1.1f, 0.9f}), + new Embedding(new float[] {5.3f, 4.8f, 5.4f}), + new Embedding(new float[] {4.9f, 5.3f, 4.8f}), + new Embedding(new float[] {5.2f, 4.9f, 5.1f}) }; } @@ -60,13 +59,13 @@ public static Embedding[] sampleEmbeddings() { * @return An array of sample text segments. 
*/ public static TextSegment[] sampleTextSegments() { - return new TextSegment[]{ - TextSegment.from("苹果", Metadata.from(Map.of("type", "fruit", "color", "red"))), - TextSegment.from("香蕉", Metadata.from(Map.of("type", "fruit", "color", "yellow"))), - TextSegment.from("橙子", Metadata.from(Map.of("type", "fruit", "color", "orange"))), - TextSegment.from("胡萝卜", Metadata.from(Map.of("type", "vegetable", "color", "orange"))), - TextSegment.from("菠菜", Metadata.from(Map.of("type", "vegetable", "color", "green"))), - TextSegment.from("西红柿", Metadata.from(Map.of("type", "vegetable", "color", "red"))) + return new TextSegment[] { + TextSegment.from("苹果", Metadata.from(Map.of("type", "fruit", "color", "red"))), + TextSegment.from("香蕉", Metadata.from(Map.of("type", "fruit", "color", "yellow"))), + TextSegment.from("橙子", Metadata.from(Map.of("type", "fruit", "color", "orange"))), + TextSegment.from("胡萝卜", Metadata.from(Map.of("type", "vegetable", "color", "orange"))), + TextSegment.from("菠菜", Metadata.from(Map.of("type", "vegetable", "color", "green"))), + TextSegment.from("西红柿", Metadata.from(Map.of("type", "vegetable", "color", "red"))) }; } @@ -76,6 +75,6 @@ public static TextSegment[] sampleTextSegments() { * @return A query embedding. 
*/ public static Embedding queryEmbedding() { - return new Embedding(new float[]{0.9f, 1.0f, 0.9f}); + return new Embedding(new float[] {0.9f, 1.0f, 0.9f}); } } From d617d32411dce8a48932329d8072614a7e986d1e Mon Sep 17 00:00:00 2001 From: SimonChou <2484152300@qq.com> Date: Thu, 26 Jun 2025 00:44:18 +0800 Subject: [PATCH 3/6] [FEATURE] Adjust some suggestions in the ocean embedding store review --- .../langchain4j-community-oceanbase/pom.xml | 14 ++- .../embedding/oceanbase/EmbeddingTable.java | 54 ++++----- .../oceanbase/OceanBaseEmbeddingStore.java | 69 ++--------- .../oceanbase/UncheckSQLException.java | 21 ++++ .../embedding/oceanbase/ValidationUtils.java | 50 ++++++++ .../distance/DefaultDistanceConverter.java | 14 --- .../distance/DistanceConverterFactory.java | 12 +- .../distance/EuclideanDistanceConverter.java | 7 +- .../distance/ManhattanDistanceConverter.java | 14 +++ .../oceanbase/CommonTestOperations.java | 1 + .../oceanbase/OceanBaseEmbeddingStoreIT.java | 111 ++++++++++++++++++ 11 files changed, 254 insertions(+), 113 deletions(-) create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/UncheckSQLException.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java delete mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/ManhattanDistanceConverter.java diff --git a/embedding-stores/langchain4j-community-oceanbase/pom.xml b/embedding-stores/langchain4j-community-oceanbase/pom.xml index 97c5d573..c42504ba 100644 --- a/embedding-stores/langchain4j-community-oceanbase/pom.xml +++ 
b/embedding-stores/langchain4j-community-oceanbase/pom.xml @@ -23,7 +23,7 @@ com.oceanbase oceanbase-client - 2.4.5 + 2.4.12 @@ -41,7 +41,7 @@ dev.langchain4j langchain4j-core - 1.1.0-SNAPSHOT + ${langchain4j.core.version} tests test-jar test @@ -72,8 +72,14 @@ - ch.qos.logback - logback-classic + org.tinylog + tinylog-impl + test + + + + org.tinylog + slf4j-tinylog test diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java index 6f4d785f..1062c537 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/EmbeddingTable.java @@ -91,35 +91,29 @@ void create(DataSource dataSource) throws SQLException { statement.addBatch("DROP TABLE IF EXISTS " + name); } - StringBuilder createTableSql = new StringBuilder(); - createTableSql - .append("CREATE TABLE IF NOT EXISTS ") - .append(name) - .append("(") - .append(idColumn) - .append(" VARCHAR(36) NOT NULL, ") - .append(embeddingColumn) - .append(" VECTOR(") - .append(vectorDimension) - .append("), ") - .append(textColumn) - .append(" VARCHAR(4000), ") - .append(metadataColumn) - .append(" JSON, ") - .append("PRIMARY KEY (") - .append(idColumn) - .append("), ") - .append("VECTOR INDEX ") - .append(vectorIndexName) - .append("(") - .append(embeddingColumn) - .append(") WITH (distance=") - .append(distanceMetric) - .append(", type=") - .append(indexType) - .append("))"); - - statement.addBatch(createTableSql.toString()); + statement.addBatch( + """ + CREATE TABLE IF NOT EXISTS %s( + %s VARCHAR(36) NOT NULL, + %s VECTOR(%d), + %s VARCHAR(4000), + %s JSON, + PRIMARY KEY (%s), + VECTOR INDEX 
%s(%s) WITH (distance=%s, type=%s) + ) + """ + .formatted( + name, + idColumn, + embeddingColumn, + vectorDimension, + textColumn, + metadataColumn, + idColumn, + vectorIndexName, + embeddingColumn, + distanceMetric, + indexType)); statement.executeBatch(); } } @@ -246,7 +240,7 @@ public static final class Builder { private String metadataColumn = "metadata"; private int vectorDimension = 1536; private String vectorIndexName = "idx_vector"; - private String distanceMetric = "L2"; + private String distanceMetric = "cosine"; private String indexType = "hnsw"; private CreateOption createOption = CreateOption.CREATE_IF_NOT_EXISTS; diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java index e8d50323..0771352f 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java @@ -1,5 +1,7 @@ package dev.langchain4j.community.store.embedding.oceanbase; +import static dev.langchain4j.community.store.embedding.oceanbase.ValidationUtils.ensureIndexNotNull; +import static dev.langchain4j.community.store.embedding.oceanbase.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.Utils.randomUUID; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; @@ -71,48 +73,10 @@ private OceanBaseEmbeddingStore(Builder builder) { try { table.create(dataSource); } catch (SQLException sqlException) { - throw uncheckSQLException(sqlException); + throw new UncheckSQLException(sqlException); } } - /** - * Transforms a SQLException into a RuntimeException. 
- * @param sqlException SQLException to transform. Not null. - * @return A RuntimeException that wraps the SQLException. - */ - private static RuntimeException uncheckSQLException(SQLException sqlException) { - return new RuntimeException("SQL error: " + sqlException.getMessage(), sqlException); - } - - /** - * Returns a null-safe item from a List. - * @param list List to get an item from. - * @param index Index of the item to get. - * @param name Name of the parameter, for error messages. - * @param Type of the list items. - * @return The item at the given index. - * @throws IllegalArgumentException If the item is null. - */ - private static T ensureIndexNotNull(List list, int index, String name) { - if (index < 0 || index >= list.size()) { - throw new IllegalArgumentException( - String.format("Index %d is out of bounds for %s list of size %d", index, name, list.size())); - } - T item = list.get(index); - if (item == null) { - throw new IllegalArgumentException(String.format("%s[%d] is null", name, index)); - } - return item; - } - - private static String ensureNotEmpty(String value, String name) { - ensureNotNull(value, name); - if (value.trim().isEmpty()) { - throw new IllegalArgumentException(name + " cannot be empty"); - } - return value; - } - @Override public String add(Embedding embedding) { ensureNotNull(embedding, "embedding"); @@ -137,25 +101,16 @@ public void add(String id, Embedding embedding) { statement.setString(2, vectorToString(embedding.vector())); statement.executeUpdate(); } catch (SQLException sqlException) { - throw uncheckSQLException(sqlException); + throw new UncheckSQLException(sqlException); } } - @Override - public List generateIds(final int n) { - List ids = new ArrayList<>(n); - for (int i = 0; i < n; i++) { - ids.add(randomUUID()); - } - return ids; - } - @Override public List addAll(List embeddings) { ensureNotNull(embeddings, "embeddings"); if (embeddings.isEmpty()) { - return Collections.emptyList(); + return List.of(); } List ids 
= generateIds(embeddings.size()); @@ -176,7 +131,7 @@ public List addAll(List embeddings) { statement.executeBatch(); return ids; } catch (SQLException sqlException) { - throw uncheckSQLException(sqlException); + throw new UncheckSQLException(sqlException); } } @@ -202,7 +157,7 @@ public String add(Embedding embedding, TextSegment textSegment) { statement.executeUpdate(); return id; } catch (SQLException sqlException) { - throw uncheckSQLException(sqlException); + throw new UncheckSQLException(sqlException); } } @@ -265,7 +220,7 @@ public void addAll(List ids, List embeddings, List ids) { statement.executeUpdate(); } catch (SQLException sqlException) { - throw uncheckSQLException(sqlException); + throw new UncheckSQLException(sqlException); } } @@ -337,7 +292,7 @@ public void removeAll(Filter filter) { Statement statement = connection.createStatement()) { statement.executeUpdate(sql); } catch (SQLException sqlException) { - throw uncheckSQLException(sqlException); + throw new UncheckSQLException(sqlException); } } @@ -349,7 +304,7 @@ public void removeAll() { Statement statement = connection.createStatement()) { statement.executeUpdate(sql); } catch (SQLException sqlException) { - throw uncheckSQLException(sqlException); + throw new UncheckSQLException(sqlException); } } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/UncheckSQLException.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/UncheckSQLException.java new file mode 100644 index 00000000..4a2d8d2c --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/UncheckSQLException.java @@ -0,0 +1,21 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import dev.langchain4j.exception.LangChain4jException; + +/** + * Exception for OceanBase SQL related errors + */ +public 
class UncheckSQLException extends LangChain4jException { + + public UncheckSQLException(String message, Throwable cause) { + super(message, cause); + } + + public UncheckSQLException(String message) { + super(message); + } + + public UncheckSQLException(Throwable cause) { + super("SQL error: " + cause.getMessage(), cause); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java new file mode 100644 index 00000000..477f753d --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java @@ -0,0 +1,50 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; + +import java.util.List; + +/** + * Utility class for validating method arguments. + */ +public final class ValidationUtils { + + private ValidationUtils() {} + + /** + * Returns a null-safe item from a List. + * @param list List to get an item from. + * @param index Index of the item to get. + * @param name Name of the parameter, for error messages. + * @param Type of the list items. + * @return The item at the given index. + * @throws IllegalArgumentException If the item is null. + */ + public static T ensureIndexNotNull(List list, int index, String name) { + if (index < 0 || index >= list.size()) { + throw new IllegalArgumentException( + String.format("Index %d is out of bounds for %s list of size %d", index, name, list.size())); + } + T item = list.get(index); + if (item == null) { + throw new IllegalArgumentException(String.format("%s[%d] is null", name, index)); + } + return item; + } + + /** + * Ensures that the given string is not null and not empty. + * + * @param value The string to check. 
+ * @param name The name of the string to be used in the exception message. + * @return The string if it is not null and not empty. + * @throws IllegalArgumentException if the string is null or empty. + */ + public static String ensureNotEmpty(String value, String name) { + ensureNotNull(value, name); + if (value.trim().isEmpty()) { + throw new IllegalArgumentException(name + " cannot be empty"); + } + return value; + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java deleted file mode 100644 index 79bef41e..00000000 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DefaultDistanceConverter.java +++ /dev/null @@ -1,14 +0,0 @@ -package dev.langchain4j.community.store.embedding.oceanbase.distance; - -/** - * Default distance to similarity converter. - * Uses exponential decay: similarity = e^(-distance) - * This works well for most distance metrics where smaller distance means higher similarity. 
- */ -public class DefaultDistanceConverter implements DistanceConverter { - - @Override - public double toSimilarity(double distance) { - return Math.exp(-distance); - } -} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java index c1e3168b..d5c59dbe 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java @@ -9,23 +9,25 @@ public class DistanceConverterFactory { /** * Returns the appropriate converter for the given distance metric. * - * @param metric The distance metric name (e.g., "cosine", "euclidean") + * @param metric The distance metric name (e.g., "cosine", "manhattan", "euclidean") * @return A DistanceConverter appropriate for the given metric */ public static DistanceConverter getConverter(String metric) { if (metric == null) { - return new DefaultDistanceConverter(); + return new EuclideanDistanceConverter(); } metric = metric.toLowerCase(); switch (metric) { - case "cosine": - return new CosineDistanceConverter(); + case "l1": + case "manhattan": + return new ManhattanDistanceConverter(); + case "l2": case "euclidean": return new EuclideanDistanceConverter(); default: - return new DefaultDistanceConverter(); + return new CosineDistanceConverter(); } } } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java 
b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java index 94495801..75b821a3 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/EuclideanDistanceConverter.java @@ -1,13 +1,14 @@ package dev.langchain4j.community.store.embedding.oceanbase.distance; /** - * Converts euclidean distance to similarity score. - * For euclidean distance, similarity = 1 / (1 + distance) + * Euclidean distance to similarity converter. + * Uses exponential decay: similarity = e^(-distance) + * This works well for most distance metrics where smaller distance means higher similarity. */ public class EuclideanDistanceConverter implements DistanceConverter { @Override public double toSimilarity(double distance) { - return 1.0 / (1.0 + distance); + return Math.exp(-distance); } } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/ManhattanDistanceConverter.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/ManhattanDistanceConverter.java new file mode 100644 index 00000000..3d6d4ec8 --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/ManhattanDistanceConverter.java @@ -0,0 +1,14 @@ +package dev.langchain4j.community.store.embedding.oceanbase.distance; + +/** + * Manhattan distance to similarity converter. + * Uses a simple inverse function: similarity = 1 / (1 + distance) + * This works well for Manhattan distance, which is the sum of absolute differences. 
+ */ +public class ManhattanDistanceConverter implements DistanceConverter { + + @Override + public double toSimilarity(double distance) { + return 1.0 / (1.0 + distance); + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java index 025e3786..c150ec33 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java @@ -87,6 +87,7 @@ public static OceanBaseEmbeddingStore newEmbeddingStore() { .embeddingTable(EmbeddingTable.builder(TABLE_NAME) .vectorDimension(VECTOR_DIM) .createOption(CreateOption.CREATE_OR_REPLACE) + .distanceMetric("L2") .build()) .build(); } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java index 3844caa1..4e6c030c 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java @@ -8,6 +8,7 @@ import dev.langchain4j.data.segment.TextSegment; import dev.langchain4j.store.embedding.EmbeddingMatch; import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.filter.Filter; import 
dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; import java.sql.SQLException; import java.util.Arrays; @@ -245,4 +246,114 @@ void should_remove_all_embeddings() { assertThat(matches).isEmpty(); } + + @Test + void should_filter_using_sql_filter_expression() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + // When: filter by simple SQL expression for red color + Filter redColorFilter = MetadataFilterBuilder.metadataKey("color").isEqualTo("red"); + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(6) + .filter(redColorFilter) + .build()) + .matches(); + + // Then + assertThat(matches).hasSize(2); + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("color")) + .containsOnly("red"); + + // When: filter by simple SQL expression for fruits + Filter fruitFilter = MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit"); + matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(6) + .filter(fruitFilter) + .build()) + .matches(); + + // Then + assertThat(matches).hasSize(3); + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("type")) + .containsOnly("fruit"); + } + + @Test + void should_use_sql_filter_factory_for_filtering() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + Filter matchAllFilter = MetadataFilterBuilder.metadataKey("type").isIn("fruit", "vegetable"); + + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .filter(matchAllFilter) + 
.build()) + .matches(); + + // Then: should return all embeddings + assertThat(matches).hasSize(6); // 3 fruits + 3 vegetables + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("type")) + .containsExactlyInAnyOrder("fruit", "fruit", "fruit", "vegetable", "vegetable", "vegetable"); + } + + @Test + void should_use_sql_filter_factory_for_empty_results() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + Filter nonExistentFilter = MetadataFilterBuilder.metadataKey("type").isEqualTo("non-existent-type"); + + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .filter(nonExistentFilter) + .build()) + .matches(); + + // Then: should return no embeddings + assertThat(matches).isEmpty(); + } + + @Test + void should_remove_embeddings_by_filter_with_sql_filter_factory() { + // Given + Embedding[] embeddings = TestData.sampleEmbeddings(); + TextSegment[] segments = TestData.sampleTextSegments(); + embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); + + // When - remove all fruits using Filter + Filter fruitFilter = MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit"); + embeddingStore.removeAll(fruitFilter); + + // Then - only vegetables should remain + List> matches = embeddingStore + .search(EmbeddingSearchRequest.builder() + .queryEmbedding(TestData.queryEmbedding()) + .maxResults(10) + .build()) + .matches(); + + assertThat(matches).hasSize(3); + assertThat(matches) + .extracting(match -> match.embedded().metadata().getString("type")) + .containsOnly("vegetable"); + } } From 8885bd0b3de2c1679dd24aca3ea9b427dc281245 Mon Sep 17 00:00:00 2001 From: SimonChou <2484152300@qq.com> Date: Thu, 26 Jun 2025 00:58:57 +0800 Subject: [PATCH 4/6] [FEATURE] Adjust the default distance 
converter of OceanBaseEmbeddingStore to cosine --- .../embedding/oceanbase/distance/DistanceConverterFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java index d5c59dbe..62490e76 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/distance/DistanceConverterFactory.java @@ -14,7 +14,7 @@ public class DistanceConverterFactory { */ public static DistanceConverter getConverter(String metric) { if (metric == null) { - return new EuclideanDistanceConverter(); + return new CosineDistanceConverter(); } metric = metric.toLowerCase(); From 871a0338a4a6f4ab1de618481658b223bf945e8f Mon Sep 17 00:00:00 2001 From: SimonChou <2484152300@qq.com> Date: Mon, 30 Jun 2025 21:11:47 +0800 Subject: [PATCH 5/6] [FEATURE] Dealing with issues related to license compliance checks --- .../langchain4j-community-oceanbase/pom.xml | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/embedding-stores/langchain4j-community-oceanbase/pom.xml b/embedding-stores/langchain4j-community-oceanbase/pom.xml index c42504ba..88fc00bb 100644 --- a/embedding-stores/langchain4j-community-oceanbase/pom.xml +++ b/embedding-stores/langchain4j-community-oceanbase/pom.xml @@ -12,6 +12,10 @@ LangChain4j :: Community :: Integration :: OceanBase Community Integration with OceanBase Vector Storage + + 2.4.12 + + dev.langchain4j @@ -23,7 +27,7 @@ com.oceanbase oceanbase-client - 2.4.12 + ${oceanbase-jdbc.version} @@ -83,4 
+87,19 @@ test + + + + + org.honton.chas + license-maven-plugin + + + true + + + + From e67f10c6b6735ed50fa3087bfc96efe3055eafc7 Mon Sep 17 00:00:00 2001 From: SimonChou <2484152300@qq.com> Date: Mon, 8 Sep 2025 13:43:41 +0800 Subject: [PATCH 6/6] [FEATURE] Adjust test cases and use EmbeddingStoreIT --- .../langchain4j-community-oceanbase/pom.xml | 6 + .../oceanbase/OceanBaseEmbeddingStore.java | 35 +- .../embedding/oceanbase/ValidationUtils.java | 50 --- .../oceanbase/CommonTestOperations.java | 9 +- .../oceanbase/OceanBaseEmbeddingStoreIT.java | 357 ++---------------- .../oceanbase/OceanBaseWithRemovalIT.java | 51 +++ .../store/embedding/oceanbase/TestData.java | 80 ---- 7 files changed, 114 insertions(+), 474 deletions(-) delete mode 100644 embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java create mode 100644 embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseWithRemovalIT.java delete mode 100644 embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java diff --git a/embedding-stores/langchain4j-community-oceanbase/pom.xml b/embedding-stores/langchain4j-community-oceanbase/pom.xml index 88fc00bb..6e9f8aa4 100644 --- a/embedding-stores/langchain4j-community-oceanbase/pom.xml +++ b/embedding-stores/langchain4j-community-oceanbase/pom.xml @@ -51,6 +51,12 @@ test + + dev.langchain4j + langchain4j-embeddings-all-minilm-l6-v2-q + test + + org.junit.jupiter junit-jupiter diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java index 0771352f..4368c83a 100644 --- 
a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStore.java @@ -1,8 +1,8 @@ package dev.langchain4j.community.store.embedding.oceanbase; -import static dev.langchain4j.community.store.embedding.oceanbase.ValidationUtils.ensureIndexNotNull; -import static dev.langchain4j.community.store.embedding.oceanbase.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.Utils.randomUUID; +import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank; +import static dev.langchain4j.internal.ValidationUtils.ensureNotEmpty; import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; import com.fasterxml.jackson.core.JsonProcessingException; @@ -88,7 +88,7 @@ public String add(Embedding embedding) { @Override public void add(String id, Embedding embedding) { - ensureNotEmpty(id, "id"); + ensureNotBlank(id, "id"); ensureNotNull(embedding, "embedding"); String sql = "INSERT INTO " + table.name() + " (" + table.idColumn() @@ -232,7 +232,7 @@ public EmbeddingSearchResult search(EmbeddingSearchRequest request) @Override public void remove(String id) { - ensureNotNull(id, "id"); + ensureNotBlank(id, "id"); String sql = "DELETE FROM " + table.name() + " WHERE " + table.idColumn() + " = ?"; @@ -247,11 +247,7 @@ public void remove(String id) { @Override public void removeAll(Collection ids) { - ensureNotNull(ids, "ids"); - - if (ids.isEmpty()) { - return; - } + ensureNotEmpty(ids, "ids"); // Use placeholders for each ID in the IN clause String placeholders = String.join(", ", Collections.nCopies(ids.size(), "?")); @@ -349,6 +345,27 @@ private String metadataToJson(Metadata metadata) { } } + /** + * Returns a null-safe item from a List. + * @param list List to get an item from. + * @param index Index of the item to get. 
+ * @param name Name of the parameter, for error messages. + * @param Type of the list items. + * @return The item at the given index. + * @throws IllegalArgumentException If the item is null. + */ + public static T ensureIndexNotNull(List list, int index, String name) { + if (index < 0 || index >= list.size()) { + throw new IllegalArgumentException( + String.format("Index %d is out of bounds for %s list of size %d", index, name, list.size())); + } + T item = list.get(index); + if (item == null) { + throw new IllegalArgumentException(String.format("%s[%d] is null", name, index)); + } + return item; + } + /** * Returns a builder for configuring an OceanBaseEmbeddingStore. * diff --git a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java b/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java deleted file mode 100644 index 477f753d..00000000 --- a/embedding-stores/langchain4j-community-oceanbase/src/main/java/dev/langchain4j/community/store/embedding/oceanbase/ValidationUtils.java +++ /dev/null @@ -1,50 +0,0 @@ -package dev.langchain4j.community.store.embedding.oceanbase; - -import static dev.langchain4j.internal.ValidationUtils.ensureNotNull; - -import java.util.List; - -/** - * Utility class for validating method arguments. - */ -public final class ValidationUtils { - - private ValidationUtils() {} - - /** - * Returns a null-safe item from a List. - * @param list List to get an item from. - * @param index Index of the item to get. - * @param name Name of the parameter, for error messages. - * @param Type of the list items. - * @return The item at the given index. - * @throws IllegalArgumentException If the item is null. 
- */ - public static T ensureIndexNotNull(List list, int index, String name) { - if (index < 0 || index >= list.size()) { - throw new IllegalArgumentException( - String.format("Index %d is out of bounds for %s list of size %d", index, name, list.size())); - } - T item = list.get(index); - if (item == null) { - throw new IllegalArgumentException(String.format("%s[%d] is null", name, index)); - } - return item; - } - - /** - * Ensures that the given string is not null and not empty. - * - * @param value The string to check. - * @param name The name of the string to be used in the exception message. - * @return The string if it is not null and not empty. - * @throws IllegalArgumentException if the string is null or empty. - */ - public static String ensureNotEmpty(String value, String name) { - ensureNotNull(value, name); - if (value.trim().isEmpty()) { - throw new IllegalArgumentException(name + " cannot be empty"); - } - return value; - } -} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java index c150ec33..48107d83 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/CommonTestOperations.java @@ -7,6 +7,8 @@ import java.sql.Statement; import java.util.UUID; import javax.sql.DataSource; + +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.testcontainers.containers.GenericContainer; @@ -21,13 +23,12 @@ public class CommonTestOperations { private static final String TABLE_NAME = 
"langchain4j_oceanbase_test_" + UUID.randomUUID().toString().replace("-", ""); - private static final int VECTOR_DIM = 3; // OceanBase container configuration private static final int OCEANBASE_PORT = 2881; // Default OceanBase port private static GenericContainer oceanBaseContainer; private static DataSource dataSource; - private static EmbeddingModel embeddingModel; + private static EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); static { startOceanBaseContainer(); @@ -85,9 +86,9 @@ private static void startOceanBaseContainer() { public static OceanBaseEmbeddingStore newEmbeddingStore() { return OceanBaseEmbeddingStore.builder(getDataSource()) .embeddingTable(EmbeddingTable.builder(TABLE_NAME) - .vectorDimension(VECTOR_DIM) + .vectorDimension(embeddingModel.dimension()) .createOption(CreateOption.CREATE_OR_REPLACE) - .distanceMetric("L2") + .distanceMetric("cosine") .build()) .build(); } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java index 4e6c030c..b17bb048 100644 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseEmbeddingStoreIT.java @@ -1,39 +1,22 @@ package dev.langchain4j.community.store.embedding.oceanbase; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.BDDAssertions.within; - -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; import dev.langchain4j.data.segment.TextSegment; -import dev.langchain4j.store.embedding.EmbeddingMatch; -import 
dev.langchain4j.store.embedding.EmbeddingSearchRequest; -import dev.langchain4j.store.embedding.filter.Filter; -import dev.langchain4j.store.embedding.filter.MetadataFilterBuilder; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithFilteringIT; import java.sql.SQLException; -import java.util.Arrays; -import java.util.List; import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.AfterEach; /** - * Integration test for OceanBaseEmbeddingStore. + * Integration tests for {@link OceanBaseEmbeddingStore}. */ -public class OceanBaseEmbeddingStoreIT { - - private OceanBaseEmbeddingStore embeddingStore; +public class OceanBaseEmbeddingStoreIT extends EmbeddingStoreWithFilteringIT { - @BeforeEach - void setUp() { - embeddingStore = CommonTestOperations.newEmbeddingStore(); - } + private OceanBaseEmbeddingStore embeddingStore = CommonTestOperations.newEmbeddingStore(); - @BeforeAll - static void setUpAll() { - // The container is started automatically in CommonTestOperations static block - } + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); @AfterAll static void cleanUp() throws SQLException { @@ -43,317 +26,29 @@ static void cleanUp() throws SQLException { CommonTestOperations.stopContainer(); } } - - @Test - void should_add_embedding_and_find_it_by_similarity() { - // Given - Embedding embedding = TestData.randomEmbedding(); - - // When - String id = embeddingStore.add(embedding); - - // Then - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding) - .maxResults(1) - .build()) - .matches(); - - assertThat(matches).hasSize(1); - 
assertThat(matches.get(0).embeddingId()).isEqualTo(id); - assertThat(matches.get(0).embedding().vector()) - .usingComparatorWithPrecision(0.0001f) - .containsExactly(embedding.vector()); - assertThat(matches.get(0).score()).isCloseTo(1.0, within(0.01)); - } - - @Test - void should_add_embedding_with_text_segment_and_find_it_by_similarity() { - // Given - Embedding embedding = TestData.randomEmbedding(); - TextSegment segment = TextSegment.from("Test text", Metadata.from("key", "value")); - - // When - String id = embeddingStore.add(embedding, segment); - - // Then - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding) - .maxResults(1) - .build()) - .matches(); - - assertThat(matches).hasSize(1); - assertThat(matches.get(0).embeddingId()).isEqualTo(id); - assertThat(matches.get(0).embedding().vector()) - .usingComparatorWithPrecision(0.0001f) - .containsExactly(embedding.vector()); - assertThat(matches.get(0).embedded().text()).isEqualTo("Test text"); - assertThat(matches.get(0).embedded().metadata().getString("key")).isEqualTo("value"); - } - - @Test - void should_add_multiple_embeddings_and_find_them_by_similarity() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - - // When - List ids = embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - // Then - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(3) - .build()) - .matches(); - - assertThat(matches).hasSize(3); - - // The query vector [0.9, 1.0, 0.9] should be closest to fruits - // Order should be: orange, banana, apple - assertThat(matches.get(0).embedded().text()).isEqualTo("橙子"); - assertThat(matches.get(1).embedded().text()).isEqualTo("香蕉"); - assertThat(matches.get(2).embedded().text()).isEqualTo("苹果"); - } - - @Test - void should_remove_embeddings_by_collection() { - // 
Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - List ids = embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - // When - remove first 3 embeddings (fruits) - embeddingStore.removeAll(ids.subList(0, 3)); - - // Then - only vegetables should remain - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .build()) - .matches(); - - assertThat(matches).hasSize(3); - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("type")) - .containsOnly("vegetable"); + + @AfterEach + void tearDown() { + clearStore(); } - - @Test - void should_remove_embeddings_by_filter() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - // When - remove all fruits - embeddingStore.removeAll(MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit")); - - // Then - only vegetables should remain - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .build()) - .matches(); - - assertThat(matches).hasSize(3); - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("type")) - .containsOnly("vegetable"); - } - - @Test - void should_filter_by_metadata() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - // When: filter by fruits only - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(6) - .filter(MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit")) - .build()) - .matches(); - - 
// Then - assertThat(matches).hasSize(3); - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("type")) - .containsOnly("fruit"); - - // When: filter by red color - matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(6) - .filter(MetadataFilterBuilder.metadataKey("color").isEqualTo("red")) - .build()) - .matches(); - - // Then - assertThat(matches).hasSize(2); - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("color")) - .containsOnly("red"); + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; } - @Test - void should_remove_embedding() { - // Given - Embedding embedding = TestData.randomEmbedding(); - String id = embeddingStore.add(embedding); - - // When - embeddingStore.remove(id); - - // Then - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(embedding) - .maxResults(1) - .build()) - .matches(); - - assertThat(matches).isEmpty(); + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; } - @Test - void should_remove_all_embeddings() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - // When + @Override + protected void clearStore() { embeddingStore.removeAll(); - - // Then - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .build()) - .matches(); - - assertThat(matches).isEmpty(); - } - - @Test - void should_filter_using_sql_filter_expression() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - // When: filter by simple SQL 
expression for red color - Filter redColorFilter = MetadataFilterBuilder.metadataKey("color").isEqualTo("red"); - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(6) - .filter(redColorFilter) - .build()) - .matches(); - - // Then - assertThat(matches).hasSize(2); - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("color")) - .containsOnly("red"); - - // When: filter by simple SQL expression for fruits - Filter fruitFilter = MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit"); - matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(6) - .filter(fruitFilter) - .build()) - .matches(); - - // Then - assertThat(matches).hasSize(3); - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("type")) - .containsOnly("fruit"); - } - - @Test - void should_use_sql_filter_factory_for_filtering() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - Filter matchAllFilter = MetadataFilterBuilder.metadataKey("type").isIn("fruit", "vegetable"); - - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .filter(matchAllFilter) - .build()) - .matches(); - - // Then: should return all embeddings - assertThat(matches).hasSize(6); // 3 fruits + 3 vegetables - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("type")) - .containsExactlyInAnyOrder("fruit", "fruit", "fruit", "vegetable", "vegetable", "vegetable"); } - - @Test - void should_use_sql_filter_factory_for_empty_results() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); 
- embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - Filter nonExistentFilter = MetadataFilterBuilder.metadataKey("type").isEqualTo("non-existent-type"); - - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .filter(nonExistentFilter) - .build()) - .matches(); - - // Then: should return no embeddings - assertThat(matches).isEmpty(); - } - - @Test - void should_remove_embeddings_by_filter_with_sql_filter_factory() { - // Given - Embedding[] embeddings = TestData.sampleEmbeddings(); - TextSegment[] segments = TestData.sampleTextSegments(); - embeddingStore.addAll(Arrays.asList(embeddings), Arrays.asList(segments)); - - // When - remove all fruits using Filter - Filter fruitFilter = MetadataFilterBuilder.metadataKey("type").isEqualTo("fruit"); - embeddingStore.removeAll(fruitFilter); - - // Then - only vegetables should remain - List> matches = embeddingStore - .search(EmbeddingSearchRequest.builder() - .queryEmbedding(TestData.queryEmbedding()) - .maxResults(10) - .build()) - .matches(); - - assertThat(matches).hasSize(3); - assertThat(matches) - .extracting(match -> match.embedded().metadata().getString("type")) - .containsOnly("vegetable"); + + @Override + protected boolean supportsContains() { + return true; } } diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseWithRemovalIT.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseWithRemovalIT.java new file mode 100644 index 00000000..9772d79d --- /dev/null +++ b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/OceanBaseWithRemovalIT.java @@ -0,0 +1,51 @@ +package dev.langchain4j.community.store.embedding.oceanbase; + +import dev.langchain4j.data.segment.TextSegment; +import 
dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.EmbeddingStoreWithRemovalIT; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import java.sql.SQLException; + +/** + * Tests for {@link OceanBaseEmbeddingStore} with removal. + */ +public class OceanBaseWithRemovalIT extends EmbeddingStoreWithRemovalIT { + + private OceanBaseEmbeddingStore embeddingStore; + + EmbeddingModel embeddingModel = new AllMiniLmL6V2QuantizedEmbeddingModel(); + + @BeforeEach + void setUp() { + embeddingStore = CommonTestOperations.newEmbeddingStore(); + } + + @AfterAll + static void cleanUp() throws SQLException { + try { + CommonTestOperations.dropTable(); + } finally { + CommonTestOperations.stopContainer(); + } + } + + @AfterEach + void tearDown() { + embeddingStore.removeAll(); + } + + @Override + protected EmbeddingStore embeddingStore() { + return embeddingStore; + } + + @Override + protected EmbeddingModel embeddingModel() { + return embeddingModel; + } +} diff --git a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java b/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java deleted file mode 100644 index 14e5c3c6..00000000 --- a/embedding-stores/langchain4j-community-oceanbase/src/test/java/dev/langchain4j/community/store/embedding/oceanbase/TestData.java +++ /dev/null @@ -1,80 +0,0 @@ -package dev.langchain4j.community.store.embedding.oceanbase; - -import dev.langchain4j.data.document.Metadata; -import dev.langchain4j.data.embedding.Embedding; -import dev.langchain4j.data.segment.TextSegment; -import java.util.Map; -import java.util.Random; - -/** - * Utility class for 
generating test data. - */ -public class TestData { - - private static final Random random = new Random(666); - - /** - * Creates a random embedding with specified dimension. - * - * @param dimension Vector dimension. - * @return A random embedding. - */ - public static Embedding randomEmbedding(int dimension) { - float[] vector = new float[dimension]; - for (int i = 0; i < dimension; i++) { - vector[i] = random.nextFloat(); - } - return new Embedding(vector); - } - - /** - * Creates a random embedding with default dimension (3). - * - * @return A random embedding with default dimension. - */ - public static Embedding randomEmbedding() { - return randomEmbedding(3); - } - - /** - * Creates an array of sample embeddings and their corresponding text segments. - * The sample data mimics fruits and vegetables with their vector representation. - * - * @return An array of sample embeddings. - */ - public static Embedding[] sampleEmbeddings() { - return new Embedding[] { - new Embedding(new float[] {1.2f, 0.7f, 1.1f}), - new Embedding(new float[] {0.6f, 1.2f, 0.8f}), - new Embedding(new float[] {1.1f, 1.1f, 0.9f}), - new Embedding(new float[] {5.3f, 4.8f, 5.4f}), - new Embedding(new float[] {4.9f, 5.3f, 4.8f}), - new Embedding(new float[] {5.2f, 4.9f, 5.1f}) - }; - } - - /** - * Creates an array of sample text segments. - * - * @return An array of sample text segments. 
- */ - public static TextSegment[] sampleTextSegments() { - return new TextSegment[] { - TextSegment.from("苹果", Metadata.from(Map.of("type", "fruit", "color", "red"))), - TextSegment.from("香蕉", Metadata.from(Map.of("type", "fruit", "color", "yellow"))), - TextSegment.from("橙子", Metadata.from(Map.of("type", "fruit", "color", "orange"))), - TextSegment.from("胡萝卜", Metadata.from(Map.of("type", "vegetable", "color", "orange"))), - TextSegment.from("菠菜", Metadata.from(Map.of("type", "vegetable", "color", "green"))), - TextSegment.from("西红柿", Metadata.from(Map.of("type", "vegetable", "color", "red"))) - }; - } - - /** - * Creates a query embedding for testing similarity search. - * - * @return A query embedding. - */ - public static Embedding queryEmbedding() { - return new Embedding(new float[] {0.9f, 1.0f, 0.9f}); - } -}