diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py index b5a521d1e1..7922cba7e1 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -596,10 +596,15 @@ def create_tree_ah_index( """ - algorithm_config = matching_engine_index_config.TreeAhConfig( - leaf_node_embedding_count=leaf_node_embedding_count, - leaf_nodes_to_search_percent=leaf_nodes_to_search_percent, - ) + algorithm_config = None + if ( + leaf_node_embedding_count is not None + or leaf_nodes_to_search_percent is not None + ): + algorithm_config = matching_engine_index_config.TreeAhConfig( + leaf_node_embedding_count=leaf_node_embedding_count, + leaf_nodes_to_search_percent=leaf_nodes_to_search_percent, + ) config = matching_engine_index_config.MatchingEngineIndexConfig( dimensions=dimensions, diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py index e388b78a1f..f9140dd1d1 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_config.py @@ -120,7 +120,7 @@ class MatchingEngineIndexConfig: dimensions (int): Required. The number of dimensions of the input vectors. algorithm_config (AlgorithmConfig): - Required. The configuration with regard to the algorithms used for efficient search. + Optional. The configuration with regard to the algorithms used for efficient search. approximate_neighbors_count (int): Optional. The default number of neighbors to find via approximate search before exact reordering is performed. Exact reordering is a procedure where results returned by an @@ -139,7 +139,7 @@ class MatchingEngineIndexConfig: """ dimensions: int - algorithm_config: AlgorithmConfig + algorithm_config: Optional[AlgorithmConfig] = None approximate_neighbors_count: Optional[int] = None distance_measure_type: Optional[DistanceMeasureType] = None feature_norm_type: Optional[FeatureNormType] = None @@ -153,10 +153,13 @@ def as_dict(self) -> Dict[str, Any]: """ res = { "dimensions": self.dimensions, - "algorithmConfig": self.algorithm_config.as_dict(), "approximateNeighborsCount": self.approximate_neighbors_count, "distanceMeasureType": self.distance_measure_type, "featureNormType": self.feature_norm_type, "shardSize": self.shard_size, } + if self.algorithm_config: + res["algorithmConfig"] = self.algorithm_config.as_dict() + else: + res["algorithmConfig"] = None return res diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index 1a66d3be71..c8929366b5 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -618,6 +618,45 @@ def test_create_tree_ah_index_backward_compatibility(self, create_index_mock): timeout=None, ) + @pytest.mark.usefixtures("get_index_mock") + def test_create_tree_ah_index_empty_algorithm_config(self, create_index_mock): + aiplatform.init(project=_TEST_PROJECT) + + aiplatform.MatchingEngineIndex.create_tree_ah_index( + display_name=_TEST_INDEX_DISPLAY_NAME, + contents_delta_uri=_TEST_CONTENTS_DELTA_URI, + dimensions=_TEST_INDEX_CONFIG_DIMENSIONS, + approximate_neighbors_count=_TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT, + distance_measure_type=_TEST_INDEX_DISTANCE_MEASURE_TYPE.value, + feature_norm_type=_TEST_INDEX_FEATURE_NORM_TYPE.value, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + ) + + expected = gca_index.Index( + display_name=_TEST_INDEX_DISPLAY_NAME, + metadata={ + "config": { + "algorithmConfig": None, + "dimensions": _TEST_INDEX_CONFIG_DIMENSIONS, + "approximateNeighborsCount": _TEST_INDEX_APPROXIMATE_NEIGHBORS_COUNT, + "distanceMeasureType": _TEST_INDEX_DISTANCE_MEASURE_TYPE, + "featureNormType": _TEST_INDEX_FEATURE_NORM_TYPE, + "shardSize": None, + }, + "contentsDeltaUri": _TEST_CONTENTS_DELTA_URI, + }, + description=_TEST_INDEX_DESCRIPTION, + labels=_TEST_LABELS, + ) + + create_index_mock.assert_called_once_with( + parent=_TEST_PARENT, + index=expected, + metadata=_TEST_REQUEST_METADATA, + timeout=None, + ) + @pytest.mark.usefixtures("get_index_mock") @pytest.mark.parametrize("sync", [True, False]) @pytest.mark.parametrize(