diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py index 9bdc0be4a..4a423788d 100644 --- a/ads/aqua/common/enums.py +++ b/ads/aqua/common/enums.py @@ -52,6 +52,7 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta): AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving" AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving" AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving" + AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving" class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta): @@ -80,3 +81,11 @@ class RqsAdditionalDetails(str, metaclass=ExtendedEnumMeta): MODEL_VERSION_SET_NAME = "modelVersionSetName" PROJECT_ID = "projectId" VERSION_LABEL = "versionLabel" + + +class TextEmbeddingInferenceContainerParams(str, metaclass=ExtendedEnumMeta): + """Contains a subset of params that are required for enabling model deployment in OCI Data Science. More options + are available at https://huggingface.co/docs/text-embeddings-inference/en/cli_arguments""" + + MODEL_ID = "model-id" + PORT = "port" diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index b47aab6c9..7a0d9d46b 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -35,6 +35,7 @@ InferenceContainerParamType, InferenceContainerType, RqsAdditionalDetails, + TextEmbeddingInferenceContainerParams, ) from ads.aqua.common.errors import ( AquaFileNotFoundError, @@ -51,6 +52,7 @@ MODEL_BY_REFERENCE_OSS_PATH_KEY, SERVICE_MANAGED_CONTAINER_URI_SCHEME, SUPPORTED_FILE_FORMATS, + TEI_CONTAINER_DEFAULT_HOST, TGI_INFERENCE_RESTRICTED_PARAMS, UNKNOWN, UNKNOWN_JSON_STR, @@ -63,7 +65,12 @@ from ads.common.object_storage_details import ObjectStorageDetails from ads.common.oci_resource import SEARCH_TYPE, OCIResource from ads.common.utils import copy_file, get_console_link, upload_to_os -from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID +from ads.config import ( + AQUA_MODEL_DEPLOYMENT_FOLDER, + AQUA_SERVICE_MODELS_BUCKET, + 
CONDA_BUCKET_NS, + TENANCY_OCID, +) from ads.model import DataScienceModel, ModelVersionSet logger = logging.getLogger("ads.aqua") @@ -569,15 +576,13 @@ def get_container_image( A dict of allowed configs. """ + container_image = UNKNOWN config = config_file_name or get_container_config() config_file_name = service_config_path() if container_type not in config: - raise AquaValueError( - f"{config_file_name} does not have config details for model: {container_type}" - ) + return UNKNOWN - container_image = None mapping = config[container_type] versions = [obj["version"] for obj in mapping] # assumes numbered versions, update if `latest` is used @@ -1078,3 +1083,76 @@ def list_hf_models(query: str) -> List[str]: return [model.id for model in models if model.disabled is None] except HfHubHTTPError as err: raise format_hf_custom_error_message(err) from err + + +def generate_tei_cmd_var(os_path: str) -> List[str]: + """This utility function generates CMD params for Text Embedding Inference container. Only the + essential parameters for OCI model deployment are added, defaults are used for the rest. + Parameters + ---------- + os_path: str + OCI bucket path where the model artifacts are uploaded - oci://bucket@namespace/prefix + + Returns + ------- + cmd_var: + List of command line arguments + """ + + cmd_prefix = "--" + cmd_var = [ + f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.MODEL_ID}", + f"{AQUA_MODEL_DEPLOYMENT_FOLDER}{ObjectStorageDetails.from_path(os_path.rstrip('/')).filepath}/", + f"{cmd_prefix}{TextEmbeddingInferenceContainerParams.PORT}", + TEI_CONTAINER_DEFAULT_HOST, + ] + + return cmd_var + + +def parse_cmd_var(cmd_list: List[str]) -> dict: + """Helper function that parses a list into a key-value dictionary. The list contains keys separated by the prefix + '--' and the value of the key is the subsequent element. 
+ """ + parsed_cmd = {} + + for i, cmd in enumerate(cmd_list): + if cmd.startswith("--"): + if i + 1 < len(cmd_list) and not cmd_list[i + 1].startswith("--"): + parsed_cmd[cmd] = cmd_list[i + 1] + i += 1 + else: + parsed_cmd[cmd] = None + return parsed_cmd + + +def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]: + """This function accepts two lists of parameters and combines them. If the second list shares the common parameter + names/keys, then it raises an error. + Parameters + ---------- + cmd_var: List[str] + Default list of parameters + overrides: List[str] + List of parameters to override + Returns + ------- + List[str] of combined parameters + """ + cmd_var = [str(x) for x in cmd_var] + if not overrides: + return cmd_var + overrides = [str(x) for x in overrides] + + cmd_dict = parse_cmd_var(cmd_var) + overrides_dict = parse_cmd_var(overrides) + + # check for conflicts + common_keys = set(cmd_dict.keys()) & set(overrides_dict.keys()) + if common_keys: + raise AquaValueError( + f"The following CMD input cannot be overridden for model deployment: {', '.join(common_keys)}" + ) + + combined_cmd_var = cmd_var + overrides + return combined_cmd_var diff --git a/ads/aqua/constants.py b/ads/aqua/constants.py index c6a03a801..76406d0d7 100644 --- a/ads/aqua/constants.py +++ b/ads/aqua/constants.py @@ -80,3 +80,4 @@ "--port", "--host", } +TEI_CONTAINER_DEFAULT_HOST = "8080" diff --git a/ads/aqua/model/constants.py b/ads/aqua/model/constants.py index 88b35ca30..0a07152e4 100644 --- a/ads/aqua/model/constants.py +++ b/ads/aqua/model/constants.py @@ -17,6 +17,7 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta): DEPLOYMENT_CONTAINER = "deployment-container" EVALUATION_CONTAINER = "evaluation-container" FINETUNE_CONTAINER = "finetune-container" + DEPLOYMENT_CONTAINER_URI = "deployment-container-uri" class ModelTask(str, metaclass=ExtendedEnumMeta): diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index 
31125644f..3ba884da9 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -98,6 +98,7 @@ class AquaModel(AquaModelSummary, DataClassSerializable): model_card: str = None inference_container: str = None + inference_container_uri: str = None finetuning_container: str = None evaluation_container: str = None artifact_location: str = None @@ -287,6 +288,7 @@ class ImportModelDetails(CLIBuilderMixin): compartment_id: Optional[str] = None project_id: Optional[str] = None model_file: Optional[str] = None + inference_container_uri: Optional[str] = None def __post_init__(self): self._command = "model register" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index b27d28c49..374e20ada 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -21,6 +21,7 @@ _build_resource_identifier, copy_model_config, create_word_icon, + generate_tei_cmd_var, get_artifact_path, get_hf_model_info, list_os_files_with_extension, @@ -67,7 +68,9 @@ from ads.common.oci_resource import SEARCH_TYPE, OCIResource from ads.common.utils import get_console_link from ads.config import ( + AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME, AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, + AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME, AQUA_EVALUATION_CONTAINER_METADATA_NAME, AQUA_FINETUNING_CONTAINER_METADATA_NAME, COMPARTMENT_OCID, @@ -229,6 +232,12 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, ModelCustomMetadataItem(key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER), ).value + inference_container_uri = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI, + ModelCustomMetadataItem( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI + ), + ).value evaluation_container = ds_model.custom_metadata_list.get( ModelCustomMetadataFields.EVALUATION_CONTAINER, ModelCustomMetadataItem(key=ModelCustomMetadataFields.EVALUATION_CONTAINER), @@ -247,6 
+256,7 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod project_id=ds_model.project_id, model_card=model_card, inference_container=inference_container, + inference_container_uri=inference_container_uri, finetuning_container=finetuning_container, evaluation_container=evaluation_container, artifact_location=artifact_location, @@ -629,6 +639,7 @@ def _create_model_catalog_entry( validation_result: ModelValidationResult, compartment_id: Optional[str], project_id: Optional[str], + inference_container_uri: Optional[str], ) -> DataScienceModel: """Create model by reference from the object storage path @@ -640,6 +651,7 @@ def _create_model_catalog_entry( verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service verified model compartment_id (Optional[str]): Compartment Id of the compartment where the model has to be created project_id (Optional[str]): Project id of the project where the model has to be created + inference_container_uri (Optional[str]): Inference container uri for BYOC Returns: DataScienceModel: Returns Datascience model instance. @@ -685,6 +697,40 @@ def _create_model_catalog_entry( raise AquaRuntimeError( f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container." 
) + metadata.add( + key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, + value=inference_container, + description=f"Inference container mapping for {model_name}", + category="Other", + ) + if inference_container_uri: + metadata.add( + key=AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME, + value=inference_container_uri, + description=f"Inference container URI for {model_name}", + category="Other", + ) + + inference_containers = ( + AquaContainerConfig.from_container_index_json().inference + ) + smc_container_set = { + container.family for container in inference_containers.values() + } + # only add cmd vars if inference container is not an SMC + if ( + inference_container not in smc_container_set + and inference_container + == InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY + ): + cmd_vars = generate_tei_cmd_var(os_path) + metadata.add( + key=AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME, + value=" ".join(cmd_vars), + description=f"Inference container cmd vars for {model_name}", + category="Other", + ) + if finetuning_container: tags[Tags.READY_TO_FINE_TUNE] = "true" metadata.add( @@ -706,12 +752,6 @@ def _create_model_catalog_entry( category="Other", ) - metadata.add( - key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, - value=inference_container, - description=f"Inference container mapping for {model_name}", - category="Other", - ) metadata.add( key=AQUA_EVALUATION_CONTAINER_METADATA_NAME, value="odsc-llm-evaluate", @@ -935,14 +975,15 @@ def _validate_model( # gguf extension exist. 
if {ModelFormat.SAFETENSORS, ModelFormat.GGUF}.issubset(set(model_formats)): if ( - import_model_details.inference_container.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY + import_model_details.inference_container.lower() + == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY ): self._validate_gguf_format( import_model_details=import_model_details, verified_model=verified_model, gguf_model_files=gguf_model_files, validation_result=validation_result, - model_name=model_name + model_name=model_name, ) else: self._validate_safetensor_format( @@ -950,7 +991,7 @@ def _validate_model( verified_model=verified_model, validation_result=validation_result, hf_download_config_present=hf_download_config_present, - model_name=model_name + model_name=model_name, ) elif ModelFormat.SAFETENSORS in model_formats: self._validate_safetensor_format( @@ -958,7 +999,7 @@ def _validate_model( verified_model=verified_model, validation_result=validation_result, hf_download_config_present=hf_download_config_present, - model_name=model_name + model_name=model_name, ) elif ModelFormat.GGUF in model_formats: self._validate_gguf_format( @@ -966,7 +1007,7 @@ def _validate_model( verified_model=verified_model, gguf_model_files=gguf_model_files, validation_result=validation_result, - model_name=model_name + model_name=model_name, ) return validation_result @@ -977,7 +1018,7 @@ def _validate_safetensor_format( verified_model: DataScienceModel = None, validation_result: ModelValidationResult = None, hf_download_config_present: bool = None, - model_name: str = None + model_name: str = None, ): if import_model_details.download_from_hf: # validates config.json exists for safetensors model from hugginface @@ -1004,20 +1045,13 @@ def _validate_safetensor_format( ) from ex else: try: - metadata_model_type = ( - verified_model.custom_metadata_list.get( - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE - ).value - ) + metadata_model_type = verified_model.custom_metadata_list.get( + 
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE + ).value if metadata_model_type: - if ( - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE - in model_config - ): + if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config: if ( - model_config[ - AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE - ] + model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE] != metadata_model_type ): raise AquaRuntimeError( @@ -1035,9 +1069,7 @@ def _validate_safetensor_format( except Exception: pass if verified_model: - validation_result.telemetry_model_name = ( - verified_model.display_name - ) + validation_result.telemetry_model_name = verified_model.display_name elif ( model_config is not None and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config @@ -1049,9 +1081,7 @@ def _validate_safetensor_format( ): validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" else: - validation_result.telemetry_model_name = ( - AQUA_MODEL_TYPE_CUSTOM - ) + validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM @staticmethod def _validate_gguf_format( @@ -1240,21 +1270,28 @@ def register( validation_result=validation_result, compartment_id=import_model_details.compartment_id, project_id=import_model_details.project_id, + inference_container_uri=import_model_details.inference_container_uri, ) # registered model will always have inference and evaluation container, but # fine-tuning container may be not set inference_container = ds_model.custom_metadata_list.get( - ModelCustomMetadataFields.DEPLOYMENT_CONTAINER + ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, + ModelCustomMetadataItem(key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER), + ).value + inference_container_uri = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI, + ModelCustomMetadataItem( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI + ), ).value evaluation_container = ds_model.custom_metadata_list.get( 
ModelCustomMetadataFields.EVALUATION_CONTAINER, + ModelCustomMetadataItem(key=ModelCustomMetadataFields.EVALUATION_CONTAINER), + ).value + finetuning_container: str = ds_model.custom_metadata_list.get( + ModelCustomMetadataFields.FINETUNE_CONTAINER, + ModelCustomMetadataItem(key=ModelCustomMetadataFields.FINETUNE_CONTAINER), ).value - try: - finetuning_container = ds_model.custom_metadata_list.get( - ModelCustomMetadataFields.FINETUNE_CONTAINER, - ).value - except Exception: - finetuning_container = None aqua_model_attributes = dict( **self._process_model(ds_model, self.region), @@ -1266,6 +1303,7 @@ def register( ) ), inference_container=inference_container, + inference_container_uri=inference_container_uri, finetuning_container=finetuning_container, evaluation_container=evaluation_container, artifact_location=artifact_path, diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 1f32a1ff9..d7ba06abc 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -3,6 +3,7 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import logging +import shlex from typing import Dict, List, Optional, Union from ads.aqua.app import AquaApp, logger @@ -23,6 +24,7 @@ get_params_list, get_resource_name, get_restricted_params_by_container, + validate_cmd_var, ) from ads.aqua.constants import ( AQUA_MODEL_ARTIFACT_FILE, @@ -43,7 +45,9 @@ from ads.common.object_storage_details import ObjectStorageDetails from ads.common.utils import get_log_links from ads.config import ( + AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME, AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, + AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME, AQUA_MODEL_DEPLOYMENT_CONFIG, COMPARTMENT_OCID, ) @@ -104,6 +108,8 @@ def create( ocpus: Optional[float] = None, model_file: Optional[str] = None, private_endpoint_id: Optional[str] = None, + container_image_uri: Optional[None] = None, + 
cmd_var: List[str] = None, ) -> "AquaDeployment": """ Creates a new Aqua deployment @@ -152,7 +158,11 @@ def create( The file used for model deployment. private_endpoint_id: str The private endpoint id of model deployment. - + container_image_uri: str + The image of model deployment container runtime, ignored for service managed containers. + Required parameter for BYOC based deployments if this parameter was not set during model registration. + cmd_var: List[str] + The cmd of model deployment container runtime. Returns ------- AquaDeployment @@ -197,9 +207,11 @@ def create( f"from custom metadata for the model {config_source_id}" ) from err - # set up env vars + # set up env and cmd var if not env_var: env_var = {} + if not cmd_var: + cmd_var = [] try: model_path_prefix = aqua_model.custom_metadata_list.get( @@ -235,13 +247,42 @@ def create( model=aqua_model, container_family=container_family ) - # fetch image name from config - container_image = get_container_image(container_type=container_type_key) - + container_image_uri = container_image_uri or get_container_image( + container_type=container_type_key + ) + if not container_image_uri: + try: + container_image_uri = aqua_model.custom_metadata_list.get( + AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME + ).value + except ValueError as err: + raise AquaValueError( + f"{AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME} key is not available in the custom metadata " + f"field. Either re-register the model with custom container URI, or set container_image_uri " + f"parameter when creating this deployment." 
+ ) from err logging.info( - f"Aqua Image used for deploying {aqua_model.id} : {container_image}" + f"Aqua Image used for deploying {aqua_model.id} : {container_image_uri}" ) + try: + cmd_var_string = aqua_model.custom_metadata_list.get( + AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME + ).value + default_cmd_var = shlex.split(cmd_var_string) + if default_cmd_var: + cmd_var = validate_cmd_var(default_cmd_var, cmd_var) + logging.info(f"CMD used for deploying {aqua_model.id} :{cmd_var}") + except ValueError: + logging.debug( + f"CMD will be ignored for this deployment as {AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME} " + f"key is not available in the custom metadata field for this model." + ) + except Exception as e: + logging.error( + f"There was an issue processing CMD arguments. Error: {str(e)}" + ) + model_formats_str = aqua_model.freeform_tags.get( Tags.MODEL_FORMAT, ModelFormat.SAFETENSORS.value ).upper() @@ -309,7 +350,7 @@ def create( # AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params # to be set as env vars raise AquaValueError( - f"Currently, parameters cannot be overridden for the container: {container_image}. Please proceed " + f"Currently, parameters cannot be overridden for the container: {container_image_uri}. Please proceed " f"with deployment without parameter overrides." 
) @@ -364,7 +405,7 @@ def create( # configure model deployment runtime container_runtime = ( ModelDeploymentContainerRuntime() - .with_image(container_image) + .with_image(container_image_uri) .with_server_port(server_port) .with_health_check_port(health_check_port) .with_env(env_var) @@ -374,6 +415,8 @@ def create( .with_overwrite_existing_artifact(True) .with_remove_existing_artifact(True) ) + if cmd_var: + container_runtime.with_cmd(cmd_var) # configure model deployment and deploy model on container runtime deployment = ( diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index dd2d72e10..4ad4d1a78 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -3,7 +3,7 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from dataclasses import dataclass, field -from typing import Union, Optional +from typing import List, Optional, Union from oci.data_science.models import ( ModelDeployment, @@ -53,6 +53,7 @@ class AquaDeployment(DataClassSerializable): shape_info: Optional[ShapeInfo] = None tags: dict = None environment_variables: dict = None + cmd: List[str] = None @classmethod def from_oci_model_deployment( @@ -81,6 +82,7 @@ def from_oci_model_deployment( ) instance_count = oci_model_deployment.model_deployment_configuration_details.model_configuration_details.scaling_policy.instance_count environment_variables = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.environment_variables + cmd = oci_model_deployment.model_deployment_configuration_details.environment_configuration_details.cmd shape_info = ShapeInfo( instance_shape=instance_configuration.instance_shape_name, instance_count=instance_count, @@ -99,7 +101,9 @@ def from_oci_model_deployment( freeform_tags = oci_model_deployment.freeform_tags or UNKNOWN_DICT aqua_service_model_tag = freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None) 
aqua_model_name = freeform_tags.get(Tags.AQUA_MODEL_NAME_TAG, UNKNOWN) - private_endpoint_id = getattr(instance_configuration, "private_endpoint_id", UNKNOWN) + private_endpoint_id = getattr( + instance_configuration, "private_endpoint_id", UNKNOWN + ) return AquaDeployment( id=oci_model_deployment.id, @@ -123,6 +127,7 @@ def from_oci_model_deployment( ), tags=freeform_tags, environment_variables=environment_variables, + cmd=cmd, ) diff --git a/ads/config.py b/ads/config.py index ec6c91396..cd4ee2b07 100644 --- a/ads/config.py +++ b/ads/config.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8; -*- # Copyright (c) 2020, 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ @@ -60,6 +59,8 @@ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME = "deployment-container" AQUA_FINETUNING_CONTAINER_METADATA_NAME = "finetune-container" AQUA_EVALUATION_CONTAINER_METADATA_NAME = "evaluation-container" +AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME = "container-cmd-var" +AQUA_DEPLOYMENT_CONTAINER_URI_METADATA_NAME = "deployment-container-uri" AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME = "deployment-container-custom" AQUA_FINETUNING_CONTAINER_OVERRIDE_FLAG_METADATA_NAME = "finetune-container-custom" AQUA_MODEL_DEPLOYMENT_FOLDER = "/opt/ds/model/deployed_model/" @@ -206,7 +207,7 @@ def open( frame.f_globals.pop("config", None) # Restores original globals - for key in defined_globals.keys(): + for key in defined_globals: frame.f_globals[key] = defined_globals[key] # Saving config if it necessary diff --git a/tests/unitary/with_extras/aqua/test_data/deployment/aqua_create_embedding_deployment.yaml b/tests/unitary/with_extras/aqua/test_data/deployment/aqua_create_embedding_deployment.yaml new file mode 100644 index 000000000..c55460140 --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_data/deployment/aqua_create_embedding_deployment.yaml @@ -0,0 +1,34 @@ +kind: deployment +spec: 
+ createdBy: ocid1.user.oc1.. + displayName: model-deployment-name + freeformTags: + OCI_AQUA: active + aqua_model_name: model-name + id: "ocid1.datasciencemodeldeployment.oc1.." + infrastructure: + kind: infrastructure + spec: + bandwidthMbps: 10 + compartmentId: ocid1.compartment.oc1.. + deploymentType: SINGLE_MODEL + policyType: FIXED_SIZE + projectId: ocid1.datascienceproject.oc1.iad. + replica: 1 + shapeName: "VM.GPU.A10.1" + type: datascienceModelDeployment + lifecycleState: CREATING + modelDeploymentUrl: "https://modeldeployment.customer-oci.com/ocid1.datasciencemodeldeployment.oc1.." + runtime: + kind: runtime + spec: + env: + BASE_MODEL: service_models/model-name/artifact + MODEL_DEPLOY_PREDICT_ENDPOINT: /v1/embeddings + healthCheckPort: 8080 + image: "dsmc://image-name:1.0.0.0" + modelUri: "ocid1.datasciencemodeldeployment.oc1.." + serverPort: 8080 + type: container + timeCreated: 2024-01-01T00:00:00.000000+00:00 +type: modelDeployment diff --git a/tests/unitary/with_extras/aqua/test_data/deployment/aqua_tei_byoc_embedding_model.yaml b/tests/unitary/with_extras/aqua/test_data/deployment/aqua_tei_byoc_embedding_model.yaml new file mode 100644 index 000000000..350a925bc --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_data/deployment/aqua_tei_byoc_embedding_model.yaml @@ -0,0 +1,88 @@ +kind: datascienceModel +spec: + artifact: oci://service-managed-models@namespace/service_models/model-name/artifact + compartmentId: ocid1.compartment.oc1.. 
+ customMetadataList: + data: + - category: Other + description: artifact location + key: artifact_location + value: service_models/model-name/artifact + - category: Other + description: model by reference flag + key: modelDescription + value: true + - category: Other + description: Deployment container mapping for model model-name + key: deployment-container + value: odsc-tei-serving + - category: Other + description: Inference container URI for model model-name + key: deployment-container-uri + value: region.ocir.io/tenancy/image_name:tag + - category: Other + description: Inference container cmd vars for model-name + key: container-cmd-var + value: --model-id,/opt/ds/model/deployed_model/service_models/model-name/artifact/,--port,8080 + definedTags: {} + description: Mock model description + displayName: model-name + freeformTags: + OCI_AQUA: active + license: License + organization: Organization + ready_to_fine_tune: false + task: text_embedding + id: ocid1.datasciencemodel.oc1.iad. 
+ lifecycleState: ACTIVE + modelDescription: + models: + - bucketName: service-managed-models + namespace: namespace + objects: + - name: service_models/model-name/artifact/README.md + sizeInBytes: 10317 + version: 450a8124-f5ca-4ee6-b4cf-c1dc05b13d46 + - name: service_models/model-name/artifact/config.json + sizeInBytes: 950 + version: 3ace781b-4a48-4e89-88b6-61f0db6d51ad + - name: service_models/model-name/artifact/configuration_RW.py + sizeInBytes: 2607 + version: ba1df5b6-7546-42e5-964e-63cd013e988c + - name: service_models/model-name/artifact/generation_config.json + sizeInBytes: 111 + version: e23a04c8-9725-4f20-8bb1-f455129e2a4e + - name: service_models/model-name/artifact/modelling_RW.py + sizeInBytes: 47560 + version: a584c221-afab-441f-901d-fbe8251dccf6 + - name: service_models/model-name/artifact/pytorch_model-00001-of-00002.bin + sizeInBytes: 9951028193 + version: e919676e-48dd-4bea-af82-14b5f3eb2b9b + - name: service_models/model-name/artifact/pytorch_model-00002-of-00002.bin + sizeInBytes: 4483421659 + version: d6255d3e-bd91-4c05-b3ca-fc1be576ee10 + - name: service_models/model-name/artifact/pytorch_model.bin.index.json + sizeInBytes: 16924 + version: 0419428c-2a7b-45d9-bb78-142fe0630017 + - name: service_models/model-name/artifact/special_tokens_map.json + sizeInBytes: 281 + version: 5569231a-a526-4881-8945-a94a1bb59b2e + - name: service_models/model-name/artifact/tokenizer.json + sizeInBytes: 2734130 + version: d3a8a00a-de79-4d80-aa69-d8f68ee800ec + - name: service_models/model-name/artifact/tokenizer_config.json + sizeInBytes: 220 + version: 84eed6ff-c1ed-4641-8c10-e6a49364d7dd + prefix: service_models/model-name/artifact + type: modelOSSReferenceDescription + version: '1.0' + projectId: ocid1.datascienceproject.oc1.iad. 
+ provenanceMetadata: + artifact_dir: null + git_branch: null + git_commit: 123456 + repository_url: https://model-name-url.com + training_id: null + training_script_path: null + timeCreated: 2024-01-01T00:00:00.000000+00:00 +type: dataScienceModel diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index f50a73721..1a9d69c87 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -171,6 +171,78 @@ class TestDataset: } ] + model_deployment_object_tei_byoc = [ + { + "category_log_details": oci.data_science.models.CategoryLogDetails( + **{ + "access": oci.data_science.models.LogDetails( + **{ + "log_group_id": "ocid1.loggroup.oc1..", + "log_id": "ocid1.log.oc1..", + } + ), + "predict": oci.data_science.models.LogDetails( + **{ + "log_group_id": "ocid1.loggroup.oc1..", + "log_id": "ocid1.log.oc1..", + } + ), + } + ), + "compartment_id": "ocid1.compartment.oc1..", + "created_by": "ocid1.user.oc1..", + "defined_tags": {}, + "description": "Mock description", + "display_name": "model-deployment-name", + "freeform_tags": {"OCI_AQUA": "active", "aqua_model_name": "model-name"}, + "id": "ocid1.datasciencemodeldeployment.oc1..", + "lifecycle_state": "ACTIVE", + "model_deployment_configuration_details": oci.data_science.models.SingleModelDeploymentConfigurationDetails( + **{ + "deployment_type": "SINGLE_MODEL", + "environment_configuration_details": oci.data_science.models.OcirModelDeploymentEnvironmentConfigurationDetails( + **{ + "cmd": [ + "--model-id", + "/opt/ds/model/deployed_model/service_models/model-name/artifact/", + "--port", + "8080", + ], + "entrypoint": [], + "environment_configuration_type": "OCIR_CONTAINER", + "environment_variables": { + "BASE_MODEL": "service_models/model-name/artifact", + "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/embeddings", + }, + "health_check_port": 8080, + "image": "dsmc://image-name:1.0.0.0", + "image_digest": 
"sha256:mock22373c16f2015f6f33c5c8553923cf8520217da0bd9504471c5e53cbc9d", + "server_port": 8080, + } + ), + "model_configuration_details": oci.data_science.models.ModelConfigurationDetails( + **{ + "bandwidth_mbps": 10, + "instance_configuration": oci.data_science.models.InstanceConfiguration( + **{ + "instance_shape_name": DEPLOYMENT_SHAPE_NAME, + "model_deployment_instance_shape_config_details": null, + } + ), + "model_id": "ocid1.datasciencemodel.oc1..", + "scaling_policy": oci.data_science.models.FixedSizeScalingPolicy( + **{"instance_count": 1, "policy_type": "FIXED_SIZE"} + ), + } + ), + } + ), + "model_deployment_url": MODEL_DEPLOYMENT_URL, + "project_id": "ocid1.datascienceproject.oc1..", + "time_created": "2024-01-01T00:00:00.000000+00:00", + } + ] + aqua_deployment_object = { "id": "ocid1.datasciencemodeldeployment.oc1..", "display_name": "model-deployment-name", @@ -188,6 +260,7 @@ class TestDataset: "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions", "PARAMS": "--served-model-name odsc-llm --seed 42", }, + "cmd": [], "console_link": "https://cloud.oracle.com/data-science/model-deployments/ocid1.datasciencemodeldeployment.oc1..?region=region-name", "lifecycle_details": "", "shape_info": { @@ -236,6 +309,25 @@ class TestDataset: "top_k": 10, } + aqua_deployment_tei_byoc_embeddings_env_vars = { + "BASE_MODEL": "service_models/model-name/artifact", + "MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/embeddings", + } + + aqua_deployment_tei_byoc_embeddings_shape_info = { + "instance_shape": DEPLOYMENT_SHAPE_NAME, + "instance_count": 1, + "ocpus": None, + "memory_in_gbs": None, + } + + aqua_deployment_tei_byoc_embeddings_cmd = [ + "--model-id", + "/opt/ds/model/deployed_model/service_models/model-name/artifact/", + "--port", + "8080", + ] + class TestAquaDeployment(unittest.TestCase): def setUp(self): @@ -562,6 +654,84 @@ def test_create_deployment_for_gguf_model( ) assert actual_attributes == expected_result + # 
@patch("ads.aqua.modeldeployment.deployment.get_container_config") + # @patch("ads.aqua.model.AquaModelApp.create") + # @patch("ads.aqua.modeldeployment.deployment.get_container_image") + # @patch("ads.model.deployment.model_deployment.ModelDeployment.deploy") + # def test_create_deployment_for_tei_byoc_embedding_model( + # self, + # mock_deploy, + # mock_get_container_image, + # mock_create, + # mock_get_container_config, + # ): + # """Test to create a deployment for a TEI BYOC embedding model""" + # aqua_model = os.path.join( + # self.curr_dir, "test_data/deployment/aqua_tei_byoc_embedding_model.yaml" + # ) + # datascience_model = DataScienceModel.from_yaml(uri=aqua_model) + # mock_create.return_value = datascience_model + # + # config_json = os.path.join( + # self.curr_dir, "test_data/deployment/deployment_config.json" + # ) + # with open(config_json, "r") as _file: + # config = json.load(_file) + # + # self.app.get_deployment_config = MagicMock(return_value=config) + # + # container_index_json = os.path.join( + # self.curr_dir, "test_data/ui/container_index.json" + # ) + # with open(container_index_json, "r") as _file: + # container_index_config = json.load(_file) + # mock_get_container_config.return_value = container_index_config + # + # mock_get_container_image.return_value = TestDataset.DEPLOYMENT_IMAGE_NAME + # aqua_deployment = os.path.join( + # self.curr_dir, "test_data/deployment/aqua_create_embedding_deployment.yaml" + # ) + # model_deployment_obj = ModelDeployment.from_yaml(uri=aqua_deployment) + # model_deployment_dsc_obj = copy.deepcopy( + # TestDataset.model_deployment_object_tei_byoc[0] + # ) + # model_deployment_dsc_obj["lifecycle_state"] = "CREATING" + # model_deployment_obj.dsc_model_deployment = ( + # oci.data_science.models.ModelDeploymentSummary(**model_deployment_dsc_obj) + # ) + # mock_deploy.return_value = model_deployment_obj + # + # result = self.app.create( + # model_id=TestDataset.MODEL_ID, + # instance_shape=TestDataset.DEPLOYMENT_SHAPE_NAME, + 
# display_name="model-deployment-name", + # log_group_id="ocid1.loggroup.oc1..", + # access_log_id="ocid1.log.oc1..", + # predict_log_id="ocid1.log.oc1..", + # container_family="odsc-tei-serving", + # cmd_var=[], + # ) + # + # mock_create.assert_called_with( + # model_id=TestDataset.MODEL_ID, compartment_id=None, project_id=None + # ) + # mock_get_container_image.assert_called() + # mock_deploy.assert_called() + # + # expected_attributes = set(AquaDeployment.__annotations__.keys()) + # actual_attributes = asdict(result) + # assert set(actual_attributes) == set(expected_attributes), "Attributes mismatch" + # expected_result = copy.deepcopy(TestDataset.aqua_deployment_object) + # expected_result["state"] = "CREATING" + # expected_result["shape_info"] = ( + # TestDataset.aqua_deployment_tei_byoc_embeddings_shape_info + # ) + # expected_result["cmd"] = TestDataset.aqua_deployment_tei_byoc_embeddings_cmd + # expected_result["environment_variables"] = ( + # TestDataset.aqua_deployment_tei_byoc_embeddings_env_vars + # ) + # assert actual_attributes == expected_result + @parameterized.expand( [ ( diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index c0386fc69..f84dd604c 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -24,6 +24,7 @@ AquaModel, ModelValidationResult, ) +from ads.aqua.common.utils import get_hf_model_info import ads.common import ads.common.oci_client import ads.config @@ -64,7 +65,7 @@ def mock_get_container_config(): yield mock_config -@pytest.fixture(autouse=True, scope="class") +@pytest.fixture(autouse=True, scope="function") def mock_get_hf_model_info(): with patch.object(HfApi, "model_info") as mock_get_hf_model_info: test_hf_model_info = ModelInfo( @@ -229,17 +230,17 @@ class TestDataset: class TestAquaModel: """Contains unittests for AquaModelApp.""" - @pytest.fixture(autouse=True, scope="class") - def mock_auth(cls): - with 
patch("ads.common.auth.default_signer") as mock_default_signer: - yield mock_default_signer - - @pytest.fixture(autouse=True, scope="class") - def mock_init_client(cls): - with patch( - "ads.common.oci_datascience.OCIDataScienceMixin.init_client" - ) as mock_client: - yield mock_client + # @pytest.fixture(autouse=True, scope="class") + # def mock_auth(cls): + # with patch("ads.common.auth.default_signer") as mock_default_signer: + # yield mock_default_signer + # + # @pytest.fixture(autouse=True, scope="class") + # def mock_init_client(cls): + # with patch( + # "ads.common.oci_datascience.OCIDataScienceMixin.init_client" + # ) as mock_client: + # yield mock_client def setup_method(self): self.default_signer_patch = patch( @@ -266,6 +267,7 @@ def teardown_method(self): self.create_signer_patch.stop() self.validate_config_patch.stop() self.create_client_patch.stop() + get_hf_model_info.cache_clear() @classmethod def setup_class(cls): @@ -465,6 +467,7 @@ def test_get_foundation_models( "task": f'{ds_model.freeform_tags["task"]}', "time_created": f"{ds_model.time_created}", "inference_container": "odsc-vllm-serving", + "inference_container_uri": None, "finetuning_container": "odsc-llm-fine-tuning", "evaluation_container": "odsc-llm-evaluate", } @@ -643,6 +646,7 @@ def test_get_model_fine_tuned( "time_created": f"{ds_model.time_created}", "validation": {"type": "Automatic split", "value": "test_val_set_size"}, "inference_container": "odsc-vllm-serving", + "inference_container_uri": None, "finetuning_container": "odsc-llm-fine-tuning", "evaluation_container": "odsc-llm-evaluate", } @@ -656,6 +660,9 @@ def test_get_model_fine_tuned( (False, False), ], ) + @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.sync") + @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch.object(AquaModelApp, "_find_matching_aqua_model") @patch("ads.aqua.common.utils.copy_file") 
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") @@ -670,16 +677,15 @@ def test_import_verified_model( mock_list_objects, mock_copy_file, mock__find_matching_aqua_model, + mock_upload_artifact, + mock_sync, + mock_ocidsc_create, artifact_location_set, download_from_hf, mock_get_hf_model_info, + mock_init_client, ): ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) - ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() - DataScienceModel.upload_artifact = MagicMock() - DataScienceModel.sync = MagicMock() - OCIDataScienceModel.create = MagicMock() - # The name attribute cannot be mocked during creation of the mock object, # hence attach it separately to the mocked objects. artifact_path = "service_models/model-name/commit-id/artifact" @@ -778,17 +784,21 @@ def test_import_verified_model( assert model.ready_to_deploy is True assert model.ready_to_finetune is False + @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.sync") + @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch.object(AquaModelApp, "_validate_model") @patch("ads.aqua.common.utils.load_config", return_value={}) def test_import_any_model_no_containers_specified( - self, mock_load_config, mock__validate_model, mock_get_hf_model_info + self, + mock_load_config, + mock__validate_model, + mock_upload_artifact, + mock_sync, + mock_ocidsc_create, + mock_get_hf_model_info, ): ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) - ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() - DataScienceModel.upload_artifact = MagicMock() - DataScienceModel.sync = MagicMock() - OCIDataScienceModel.create = MagicMock() - os_path = "oci://aqua-bkt@aqua-ns/prefix/path" model_name = "oracle/aqua-1t-mega-model" ds_freeform_tags = { @@ -825,6 +835,9 @@ def test_import_any_model_no_containers_specified( 
"download_from_hf", [True, False], ) + @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.sync") + @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch.object(AquaModelApp, "_find_matching_aqua_model") @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") @patch("ads.aqua.common.utils.load_config", return_value={}) @@ -837,14 +850,13 @@ def test_import_model_with_project_compartment_override( mock_load_config, mock_list_objects, mock__find_matching_aqua_model, + mock_upload_artifact, + mock_sync, + mock_ocidsc_create, download_from_hf, mock_get_hf_model_info, ): ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) - ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() - DataScienceModel.upload_artifact = MagicMock() - DataScienceModel.sync = MagicMock() - OCIDataScienceModel.create = MagicMock() mock_list_objects.return_value = MagicMock(objects=[]) ds_model = DataScienceModel() @@ -906,6 +918,8 @@ def test_import_model_with_project_compartment_override( "download_from_hf", [True, False], ) + @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") @patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError) @patch("huggingface_hub.snapshot_download") @@ -916,17 +930,18 @@ def test_import_model_with_missing_config( mock_snapshot_download, mock_load_config, mock_list_objects, + mock_upload_artifact, + mock_ocidsc_create, mock_get_container_config, download_from_hf, mock_get_hf_model_info, + mock_init_client, ): """Test for validating if error is returned when model artifacts are incomplete or not available.""" os_path = "oci://aqua-bkt@aqua-ns/prefix/path" model_name = "oracle/aqua-1t-mega-model" 
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) - ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() - DataScienceModel.upload_artifact = MagicMock() mock_list_objects.return_value = MagicMock(objects=[]) reload(ads.aqua.model.model) app = AquaModelApp() @@ -950,6 +965,9 @@ def test_import_model_with_missing_config( download_from_hf=False, ) + @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.sync") + @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") @patch.object(HfApi, "model_info") @patch("ads.aqua.common.utils.load_config", return_value={}) @@ -957,14 +975,14 @@ def test_import_any_model_smc_container( self, mock_load_config, mock_list_objects, + mock_upload_artifact, + mock_sync, + mock_ocidsc_create, mock_get_hf_model_info, + mock_init_client, ): my_model = "oracle/aqua-1t-mega-model" ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) - ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() - DataScienceModel.upload_artifact = MagicMock() - DataScienceModel.sync = MagicMock() - OCIDataScienceModel.create = MagicMock() os_path = "oci://aqua-bkt@aqua-ns/prefix/path" ds_freeform_tags = { @@ -1012,6 +1030,90 @@ def test_import_any_model_smc_container( assert model.ready_to_deploy is True assert model.ready_to_finetune is True + @pytest.mark.parametrize( + "download_from_hf", + [True, False], + ) + @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.sync") + @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") + @patch.object(AquaModelApp, "_find_matching_aqua_model") + @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") + @patch("ads.aqua.common.utils.load_config", 
return_value={}) + @patch("huggingface_hub.snapshot_download") + @patch("subprocess.check_call") + def test_import_tei_model_byoc( + self, + mock_subprocess, + mock_snapshot_download, + mock_load_config, + mock_list_objects, + mock__find_matching_aqua_model, + mock_upload_artifact, + mock_sync, + mock_ocidsc_create, + download_from_hf, + mock_get_hf_model_info, + mock_init_client, + ): + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) + + artifact_path = "service_models/model-name/commit-id/artifact" + obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150) + obj1.name = f"{artifact_path}/config.json" + objects = [obj1] + mock_list_objects.return_value = MagicMock(objects=objects) + ds_model = DataScienceModel() + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + model_name = "oracle/aqua-1t-mega-model" + ds_freeform_tags = { + "OCI_AQUA": "ACTIVE", + "license": "aqua-license", + "organization": "oracle", + "task": "text_embedding", + } + ds_model = ( + ds_model.with_compartment_id("test_model_compartment_id") + .with_project_id("test_project_id") + .with_display_name(model_name) + .with_description("test_description") + .with_model_version_set_id("test_model_version_set_id") + .with_freeform_tags(**ds_freeform_tags) + .with_version_id("ocid1.version.id") + ) + custom_metadata_list = ModelCustomMetadata() + custom_metadata_list.add( + **{"key": "deployment-container", "value": "odsc-tei-serving"} + ) + ds_model.with_custom_metadata_list(custom_metadata_list) + ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {}) + DataScienceModel.from_id = MagicMock(return_value=ds_model) + mock__find_matching_aqua_model.return_value = None + reload(ads.aqua.model.model) + app = AquaModelApp() + + if download_from_hf: + with tempfile.TemporaryDirectory() as tmpdir: + model: AquaModel = app.register( + model=model_name, + os_path=os_path, + local_dir=str(tmpdir), + download_from_hf=True, + inference_container="odsc-tei-serving", + 
inference_container_uri="region.ocir.io/your_tenancy/your_image", + ) + else: + model: AquaModel = app.register( + model="ocid1.datasciencemodel.xxx.xxxx.", + os_path=os_path, + download_from_hf=False, + inference_container="odsc-tei-serving", + inference_container_uri="region.ocir.io/your_tenancy/your_image", + ) + assert model.inference_container == "odsc-tei-serving" + assert model.ready_to_deploy is True + assert model.ready_to_finetune is False + @pytest.mark.parametrize( "data, expected_output", [ @@ -1047,6 +1149,15 @@ def test_import_any_model_smc_container( }, "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --model_file test_model_file", ), + ( + { + "os_path": "oci://aqua-bkt@aqua-ns/path", + "model": "oracle/oracle-1it", + "inference_container": "odsc-tei-serving", + "inference_container_uri": ".ocir.io//", + }, + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-tei-serving --inference_container_uri .ocir.io//", + ), ], ) def test_import_cli(self, data, expected_output):