Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tests/unit/vertex_rag/test_rag_constants_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
RagFile,
RagManagedDb,
RagManagedDbConfig,
RagManagedVertexVectorSearch,
RagResource,
RagRetrievalConfig,
RagVectorDbConfig,
Expand Down Expand Up @@ -137,6 +138,12 @@
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
)
# Collection name shared by the RagManagedVertexVectorSearch test fixtures
# below, so the SDK-level and GAPIC-level objects carry the same value.
TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_COLLECTION_NAME = (
"test-rag-managed-vertex-vector-search-collection"
)
# SDK-level vector DB config built from the collection name above.
TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_CONFIG = RagManagedVertexVectorSearch(
collection_name=TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_COLLECTION_NAME
)
TEST_GAPIC_RAG_CORPUS.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
"projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format(
TEST_PROJECT, TEST_REGION
Expand Down Expand Up @@ -200,6 +207,17 @@
),
),
)
# GAPIC corpus whose vector DB config has the
# rag_managed_vertex_vector_search field set, carrying the test collection
# name. Used as the result of the mocked create_rag_corpus operation in the
# preview data tests.
TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_VERTEX_VECTOR_SEARCH = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
rag_vector_db_config=GapicRagVectorDbConfig(
rag_managed_vertex_vector_search=GapicRagVectorDbConfig.RagManagedVertexVectorSearch(
collection_name=TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_COLLECTION_NAME,
),
),
)

TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_DB = GapicRagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
Expand Down Expand Up @@ -301,6 +319,13 @@
description=TEST_CORPUS_DISCRIPTION,
vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG,
)
# SDK-level RagCorpus whose vector_db is the RagManagedVertexVectorSearch test
# config; the preview data tests compare the create_corpus result against it.
TEST_RAG_CORPUS_RAG_MANAGED_VERTEX_VECTOR_SEARCH = RagCorpus(
name=TEST_RAG_CORPUS_RESOURCE_NAME,
display_name=TEST_CORPUS_DISPLAY_NAME,
description=TEST_CORPUS_DISCRIPTION,
vector_db=TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_CONFIG,
)

TEST_PAGE_TOKEN = "test-page-token"
# Backend Config
TEST_GAPIC_RAG_CORPUS_BACKEND_CONFIG = GapicRagCorpus(
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/vertex_rag/test_rag_data_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,23 @@ def create_rag_corpus_mock_pinecone():
yield create_rag_corpus_mock_pinecone


@pytest.fixture
def create_rag_corpus_mock_rag_managed_vertex_vector_search():
    """Patch ``create_rag_corpus`` to return a finished operation.

    The patched client method returns a mocked long-running operation that
    reports itself as done and whose result is the GAPIC corpus configured
    with ``rag_managed_vertex_vector_search``.

    Yields:
        The mock that replaces ``VertexRagDataServiceClient.create_rag_corpus``.
    """
    # Build the completed LRO up front; it does not depend on the patch.
    finished_lro = mock.Mock(ga_operation.Operation)
    finished_lro.done.return_value = True
    finished_lro.result.return_value = (
        test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_RAG_MANAGED_VERTEX_VECTOR_SEARCH
    )

    with mock.patch.object(
        VertexRagDataServiceClient,
        "create_rag_corpus",
    ) as patched_create:
        patched_create.return_value = finished_lro
        yield patched_create


@pytest.fixture
def create_rag_corpus_mock_rag_managed_db():
with mock.patch.object(
Expand Down Expand Up @@ -805,6 +822,18 @@ def test_create_corpus_vertex_vector_search_success(self):
rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH
)

@pytest.mark.usefixtures("create_rag_corpus_mock_rag_managed_vertex_vector_search")
def test_create_corpus_rag_managed_vertex_vector_search_success(self):
    """create_corpus with a RagManagedVertexVectorSearch vector_db returns the expected corpus."""
    constants = test_rag_constants_preview

    created_corpus = rag.create_corpus(
        display_name=constants.TEST_CORPUS_DISPLAY_NAME,
        vector_db=constants.TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_CONFIG,
    )

    expected_corpus = constants.TEST_RAG_CORPUS_RAG_MANAGED_VERTEX_VECTOR_SEARCH
    rag_corpus_eq(created_corpus, expected_corpus)

@pytest.mark.usefixtures("create_rag_corpus_mock_vertex_vector_search_backend")
def test_create_corpus_vertex_vector_search_backend_success(self):
rag_corpus = rag.create_corpus(
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
RagFile,
RagManagedDb,
RagManagedDbConfig,
RagManagedVertexVectorSearch,
RagResource,
RagRetrievalConfig,
RagVectorDbConfig,
Expand Down Expand Up @@ -102,6 +103,7 @@
"RagFile",
"RagManagedDb",
"RagManagedDbConfig",
"RagManagedVertexVectorSearch",
"RagResource",
"RagRetrievalConfig",
"RagVectorDbConfig",
Expand Down
2 changes: 2 additions & 0 deletions vertexai/preview/rag/rag_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
RagEngineConfig,
RagFile,
RagManagedDb,
RagManagedVertexVectorSearch,
RagVectorDbConfig,
SharePointSources,
SlackChannelsSource,
Expand All @@ -79,6 +80,7 @@ def create_corpus(
VertexVectorSearch,
Pinecone,
RagManagedDb,
RagManagedVertexVectorSearch,
]
] = None,
vertex_ai_search_config: Optional[VertexAiSearchConfig] = None,
Expand Down
31 changes: 29 additions & 2 deletions vertexai/preview/rag/utils/_gapic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
RagFile,
RagManagedDb,
RagManagedDbConfig,
RagManagedVertexVectorSearch,
RagVectorDbConfig,
Basic,
Enterprise,
Expand Down Expand Up @@ -176,6 +177,15 @@ def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool
return gapic_vector_db.vertex_vector_search.ByteSize() > 0


def _check_rag_managed_vertex_vector_search(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> bool:
    """Report whether ``rag_managed_vertex_vector_search`` is set on the config.

    Args:
        gapic_vector_db: The GAPIC vector DB config to inspect.

    Returns:
        The result of ``gapic_vector_db.__contains__`` for the field name.
        If that lookup or call raises ``AttributeError``, returns whether
        the ``rag_managed_vertex_vector_search`` sub-message has a non-zero
        ``ByteSize()`` instead.
    """
    field_name = "rag_managed_vertex_vector_search"
    try:
        # Explicit dunder call on purpose: a missing method surfaces as
        # AttributeError here, which is what triggers the fallback below.
        is_set = gapic_vector_db.__contains__(field_name)
    except AttributeError:
        # NOTE(review): presumably the raw-protobuf path, where the message
        # has no __contains__ -- confirm against the sibling _check_* helpers.
        is_set = gapic_vector_db.rag_managed_vertex_vector_search.ByteSize() > 0
    return is_set


def _check_rag_embedding_model_config(
gapic_vector_db: GapicRagVectorDbConfig,
) -> bool:
Expand Down Expand Up @@ -240,6 +250,10 @@ def convert_gapic_to_vector_db(
index_name=gapic_vector_db.pinecone.index_name,
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
)
elif _check_rag_managed_vertex_vector_search(gapic_vector_db):
return RagManagedVertexVectorSearch(
collection_name=gapic_vector_db.rag_managed_vertex_vector_search.collection_name,
)
elif _check_vertex_vector_search(gapic_vector_db):
return VertexVectorSearch(
index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
Expand Down Expand Up @@ -299,6 +313,10 @@ def convert_gapic_to_backend_config(
index_name=gapic_vector_db.pinecone.index_name,
api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
)
elif _check_rag_managed_vertex_vector_search(gapic_vector_db):
vector_config.vector_db = RagManagedVertexVectorSearch(
collection_name=gapic_vector_db.rag_managed_vertex_vector_search.collection_name,
)
elif _check_vertex_vector_search(gapic_vector_db):
vector_config.vector_db = VertexVectorSearch(
index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
Expand Down Expand Up @@ -904,9 +922,14 @@ def set_vector_db(
),
),
)
elif isinstance(vector_db, RagManagedVertexVectorSearch):
rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
rag_managed_vertex_vector_search=GapicRagVectorDbConfig.RagManagedVertexVectorSearch(),
)

else:
raise TypeError(
"vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone."
"vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, Pinecone, or RagManagedVertexVectorSearch."
)


Expand Down Expand Up @@ -973,10 +996,14 @@ def set_backend_config(
rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = (
api_key
)
elif isinstance(vector_config, RagManagedVertexVectorSearch):
rag_corpus.vector_db_config.rag_managed_vertex_vector_search.CopyFrom(
GapicRagVectorDbConfig.RagManagedVertexVectorSearch()
)
else:
raise TypeError(
"backend_config must be a VertexFeatureStore,"
"RagManagedDb, or Pinecone."
"RagManagedDb, Pinecone, or RagManagedVertexVectorSearch."
)
if backend_config.rag_embedding_model_config:
set_embedding_model_config(
Expand Down
25 changes: 23 additions & 2 deletions vertexai/preview/rag/utils/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,20 @@ class Pinecone:
api_key: Optional[str] = None


@dataclasses.dataclass
class RagManagedVertexVectorSearch:
    """Vector DB option for a RAG-managed Vertex Vector Search collection.

    Attributes:
        collection_name: Resource name of the Vector Search 2.0 Collection
            that RAG created for the corpus. Only populated after the corpus
            is successfully created. Format:
            ``projects/{project}/locations/{location}/collections/{collection_id}``
    """

    collection_name: Optional[str] = None


@dataclasses.dataclass
class VertexAiSearchConfig:
"""VertexAiSearchConfig.
Expand All @@ -246,12 +260,19 @@ class RagVectorDbConfig:

Attributes:
vector_db: Can be one of the following: Weaviate, VertexFeatureStore,
VertexVectorSearch, Pinecone, RagManagedDb.
VertexVectorSearch, Pinecone, RagManagedDb, RagManagedVertexVectorSearch.
rag_embedding_model_config: The embedding model config of the Vector DB.
"""

vector_db: Optional[
Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]
Union[
Weaviate,
VertexFeatureStore,
VertexVectorSearch,
Pinecone,
RagManagedDb,
RagManagedVertexVectorSearch,
]
] = None
rag_embedding_model_config: Optional[RagEmbeddingModelConfig] = None

Expand Down