From c8f38a0a51c318a5065438067f85f31be5088af1 Mon Sep 17 00:00:00 2001
From: Ray Xu
Date: Mon, 3 Nov 2025 14:54:15 -0800
Subject: [PATCH] feat: Add reservation affinity support to preview
 BatchPredictionJob

PiperOrigin-RevId: 827661239
---
 google/cloud/aiplatform/jobs.py               |  27 +-
 google/cloud/aiplatform/preview/jobs.py       | 555 +++++++++++++++++-
 .../test_batch_prediction_job_preview.py      | 301 ++++++++++
 3 files changed, 880 insertions(+), 3 deletions(-)
 create mode 100644 tests/unit/aiplatform/test_batch_prediction_job_preview.py

diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py
index e854faa3e6..472c1661d5 100644
--- a/google/cloud/aiplatform/jobs.py
+++ b/google/cloud/aiplatform/jobs.py
@@ -319,7 +319,7 @@ def cancel(self) -> None:
         getattr(self.api_client, self._cancel_method)(name=self.resource_name)
 
 
-class BatchPredictionJob(_Job):
+class BatchPredictionJob(_Job, base.PreviewMixin):
 
     _resource_noun = "batchPredictionJobs"
     _getter_method = "get_batch_prediction_job"
@@ -329,6 +329,9 @@ class BatchPredictionJob(_Job):
     _job_type = "batch-predictions"
     _parse_resource_name_method = "parse_batch_prediction_job_path"
     _format_resource_name_method = "batch_prediction_job_path"
+    _preview_class = (
+        "google.cloud.aiplatform.preview.jobs.BatchPredictionJob"
+    )
 
     def __init__(
         self,
@@ -949,6 +952,9 @@ def _submit_impl(
         ] = None,
         analysis_instance_schema_uri: Optional[str] = None,
         service_account: Optional[str] = None,
+        reservation_affinity_type: Optional[str] = None,
+        reservation_affinity_key: Optional[str] = None,
+        reservation_affinity_values: Optional[Sequence[str]] = None,
         wait_for_completion: bool = False,
     ) -> "BatchPredictionJob":
         """Create a batch prediction job.
@@ -1136,6 +1142,18 @@ def _submit_impl(
             service_account (str):
                 Optional. Specifies the service account for workload run-as account.
                 Users submitting jobs must have act-as permission on this run-as account.
+            reservation_affinity_type (str):
+                Optional. The type of reservation affinity.
+                One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION,
+                SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION.
+            reservation_affinity_key (str):
+                Optional. Corresponds to the label key of a reservation resource.
+                To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key
+                and specify the name of your reservation as its value.
+            reservation_affinity_values (List[str]):
+                Optional. Corresponds to the label values of a reservation resource.
+                This must be the full resource name of the reservation.
+                Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}'
             wait_for_completion (bool):
                 Whether to wait for the job completion.
 
         Returns:
@@ -1268,6 +1286,13 @@ def _submit_impl(
             machine_spec.accelerator_type = accelerator_type
             machine_spec.accelerator_count = accelerator_count
 
+            if reservation_affinity_type:
+                machine_spec.reservation_affinity = utils.get_reservation_affinity(
+                    reservation_affinity_type,
+                    reservation_affinity_key,
+                    reservation_affinity_values,
+                )
+
             dedicated_resources = gca_machine_resources_compat.BatchDedicatedResources()
 
             dedicated_resources.machine_spec = machine_spec
diff --git a/google/cloud/aiplatform/preview/jobs.py b/google/cloud/aiplatform/preview/jobs.py
index b8ed5519e7..645800e408 100644
--- a/google/cloud/aiplatform/preview/jobs.py
+++ b/google/cloud/aiplatform/preview/jobs.py
@@ -15,9 +15,8 @@
 # limitations under the License.
 #
 
-from typing import Dict, List, Optional, Union
-
 import copy
+from typing import Dict, List, Optional, Sequence, Union
 import uuid
 
 from google.api_core import retry
@@ -68,6 +67,12 @@
     gca_job_state_v1beta1.JobState.JOB_STATE_CANCELLED,
 )
 
+# _block_until_complete wait times
+_JOB_WAIT_TIME = 5  # start at five seconds
+_LOG_WAIT_TIME = 5
+_MAX_WAIT_TIME = 60 * 5  # 5 minute wait
+_WAIT_TIME_MULTIPLIER = 2  # scale wait by 2 every iteration
+
 
 class CustomJob(jobs.CustomJob):
     """Deprecated. Vertex AI Custom Job (preview)."""
@@ -867,3 +872,549 @@ def _run(
         )
 
         self._block_until_complete()
+
+
+class BatchPredictionJob(jobs.BatchPredictionJob):
+    """Vertex AI Batch Prediction Job (preview)."""
+
+    @classmethod
+    def create(
+        cls,
+        # TODO(b/223262536): Make the job_display_name parameter optional in the next major release
+        job_display_name: str,
+        model_name: Union[str, "aiplatform.Model"],
+        instances_format: str = "jsonl",
+        predictions_format: str = "jsonl",
+        gcs_source: Optional[Union[str, Sequence[str]]] = None,
+        bigquery_source: Optional[str] = None,
+        gcs_destination_prefix: Optional[str] = None,
+        bigquery_destination_prefix: Optional[str] = None,
+        model_parameters: Optional[Dict] = None,
+        machine_type: Optional[str] = None,
+        accelerator_type: Optional[str] = None,
+        accelerator_count: Optional[int] = None,
+        starting_replica_count: Optional[int] = None,
+        max_replica_count: Optional[int] = None,
+        generate_explanation: Optional[bool] = False,
+        explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None,
+        explanation_parameters: Optional[
+            "aiplatform.explain.ExplanationParameters"
+        ] = None,
+        labels: Optional[Dict[str, str]] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        encryption_spec_key_name: Optional[str] = None,
+        sync: bool = True,
+        create_request_timeout: Optional[float] = None,
+        batch_size: Optional[int] = None,
+        model_monitoring_objective_config: Optional[
+            "aiplatform.model_monitoring.ObjectiveConfig"
+        ] = None,
+        model_monitoring_alert_config: Optional[
+            "aiplatform.model_monitoring.AlertConfig"
+        ] = None,
+        analysis_instance_schema_uri: Optional[str] = None,
+        service_account: Optional[str] = None,
+        reservation_affinity_type: Optional[str] = None,
+        reservation_affinity_key: Optional[str] = None,
+        reservation_affinity_values: Optional[List[str]] = None,
+    ) -> "BatchPredictionJob":
+        """Create a batch prediction job.
+
+        Args:
+            job_display_name (str):
+                Required. The user-defined name of the BatchPredictionJob.
+                The name can be up to 128 characters long and can consist
+                of any UTF-8 characters.
+            model_name (Union[str, aiplatform.Model]):
+                Required. A fully-qualified model resource name or model ID.
+                Example: "projects/123/locations/us-central1/models/456" or
+                "456" when project and location are initialized or passed.
+                May optionally contain a version ID or alias in
+                {model_name}@{version} form.
+
+                Or an instance of aiplatform.Model.
+            instances_format (str):
+                Required. The format in which instances are provided. Must be one
+                of the formats listed in `Model.supported_input_storage_formats`.
+                Default is "jsonl" when using `gcs_source`. If a `bigquery_source`
+                is provided, this is overridden to "bigquery".
+            predictions_format (str):
+                Required. The format in which Vertex AI outputs the
+                predictions, must be one of the formats specified in
+                `Model.supported_output_storage_formats`.
+                Default is "jsonl" when using `gcs_destination_prefix`. If a
+                `bigquery_destination_prefix` is provided, this is overridden to
+                "bigquery".
+            gcs_source (Optional[Sequence[str]]):
+                Google Cloud Storage URI(-s) to your instances to run
+                batch prediction on. They must match `instances_format`.
+
+            bigquery_source (Optional[str]):
+                BigQuery URI to a table, up to 2000 characters long. For example:
+                `bq://projectId.bqDatasetId.bqTableId`
+            gcs_destination_prefix (Optional[str]):
+                The Google Cloud Storage location of the directory where the
+                output is to be written to. In the given directory a new
+                directory is created. Its name is
+                ``prediction-<model-display-name>-<job-create-time>``, where
+                timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format.
+                Inside of it files ``predictions_0001.<extension>``,
+                ``predictions_0002.<extension>``, ...,
+                ``predictions_N.<extension>`` are created where
+                ``<extension>`` depends on chosen ``predictions_format``,
+                and N may equal 0001 and depends on the total number of
+                successfully predicted instances. If the Model has both
+                ``instance`` and ``prediction`` schemata defined then each such
+                file contains predictions as per the ``predictions_format``.
+                If prediction for any instance failed (partially or
+                completely), then an additional ``errors_0001.<extension>``,
+                ``errors_0002.<extension>``,..., ``errors_N.<extension>``
+                files are created (N depends on total number of failed
+                predictions). These files contain the failed instances, as
+                per their schema, followed by an additional ``error`` field
+                which as value has ``google.rpc.Status``
+                containing only ``code`` and ``message`` fields.
+            bigquery_destination_prefix (Optional[str]):
+                The BigQuery project or dataset location where the output is
+                to be written to. If project is provided, a new dataset is
+                created with name
+                ``prediction_<model-display-name>_<job-create-time>`` where
+                <model-display-name> is made BigQuery-dataset-name compatible (for
+                example, most special characters become underscores), and
+                timestamp is in YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601"
+                format. In the dataset two tables will be created,
+                ``predictions``, and ``errors``. If the Model has both
+                [instance][google.cloud.aiplatform.v1.PredictSchemata.instance_schema_uri]
+                and
+                [prediction][google.cloud.aiplatform.v1.PredictSchemata.parameters_schema_uri]
+                schemata defined then the tables have columns as follows:
+                The ``predictions`` table contains instances for which the
+                prediction succeeded, it has columns as per a concatenation
+                of the Model's instance and prediction schemata. The
+                ``errors`` table contains rows for which the prediction has
+                failed, it has instance columns, as per the instance schema,
+                followed by a single "errors" column, which as values has
+                [google.rpc.Status][google.rpc.Status] represented as a
+                STRUCT, and containing only ``code`` and ``message``.
+            model_parameters (Optional[Dict]):
+                The parameters that govern the predictions. The schema of
+                the parameters may be specified via the Model's `parameters_schema_uri`.
+            machine_type (Optional[str]):
+                The type of machine for running batch prediction on
+                dedicated resources. Not specifying machine type will result in
+                batch prediction job being run with automatic resources.
+            accelerator_type (Optional[str]):
+                The type of accelerator(s) that may be attached
+                to the machine as per `accelerator_count`. Only used if
+                `machine_type` is set.
+            accelerator_count (Optional[int]):
+                The number of accelerators to attach to the
+                `machine_type`. Only used if `machine_type` is set.
+            starting_replica_count (Optional[int]):
+                The number of machine replicas used at the start of the batch
+                operation. If not set, Vertex AI decides starting number, not
+                greater than `max_replica_count`. Only used if `machine_type` is
+                set.
+            max_replica_count (Optional[int]):
+                The maximum number of machine replicas the batch operation may
+                be scaled to. Only used if `machine_type` is set.
+                Default is 10.
+            generate_explanation (bool):
+                Optional. Generate explanation along with the batch prediction
+                results. This will cause the batch prediction output to include
+                explanations based on the `prediction_format`:
+                - `bigquery`: output includes a column named `explanation`. The
+                  value is a struct that conforms to the
+                  [aiplatform.gapic.Explanation] object.
+                - `jsonl`: The JSON objects on each line include an additional
+                  entry keyed `explanation`. The value of the entry is a JSON
+                  object that conforms to the [aiplatform.gapic.Explanation]
+                  object.
+                - `csv`: Generating explanations for CSV format is not supported.
+            explanation_metadata (aiplatform.explain.ExplanationMetadata):
+                Optional. Explanation metadata configuration for this BatchPredictionJob.
+                Can be specified only if `generate_explanation` is set to `True`.
+
+                This value overrides the value of `Model.explanation_metadata`.
+                All fields of `explanation_metadata` are optional in the request. If
+                a field of the `explanation_metadata` object is not populated, the
+                corresponding field of the `Model.explanation_metadata` object is inherited.
+                For more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
+            explanation_parameters (aiplatform.explain.ExplanationParameters):
+                Optional. Parameters to configure explaining for Model's predictions.
+                Can be specified only if `generate_explanation` is set to `True`.
+
+                This value overrides the value of `Model.explanation_parameters`.
+                All fields of `explanation_parameters` are optional in the request. If
+                a field of the `explanation_parameters` object is not populated, the
+                corresponding field of the `Model.explanation_parameters` object is inherited.
+                For more details, see `Ref docs <http://tinyurl.com/1an4zake>`
+            labels (Dict[str, str]):
+                Optional. The labels with user-defined metadata to organize your
+                BatchPredictionJobs. Label keys and values can be no longer than
+                64 characters (Unicode codepoints), can only contain lowercase
+                letters, numeric characters, underscores and dashes.
+                International characters are allowed. See https://goo.gl/xmQnxf
+                for more information and examples of labels.
+            credentials (Optional[auth_credentials.Credentials]):
+                Custom credentials to use to create this batch prediction
+                job. Overrides credentials set in aiplatform.init.
+            encryption_spec_key_name (Optional[str]):
+                Optional. The Cloud KMS resource identifier of the customer
+                managed encryption key used to protect the job. Has the
+                form:
+                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+                The key needs to be in the same region as where the compute
+                resource is created.
+
+                If this is set, then all
+                resources created by the BatchPredictionJob will
+                be encrypted with the provided encryption key.
+
+                Overrides encryption_spec_key_name set in aiplatform.init.
+            sync (bool):
+                Whether to execute this method synchronously. If False, this method
+                will be executed in concurrent Future and any downstream object will
+                be immediately returned and synced when the Future has completed.
+            create_request_timeout (float):
+                Optional. The timeout for the create request in seconds.
+            batch_size (int):
+                Optional.
+                The number of records (e.g. instances) of the operation given in
+                each batch to a machine replica. Machine type and size of a single
+                record should be considered when setting this parameter; a higher
+                value speeds up the batch operation's execution, but a value that
+                is too high will result in a whole batch not fitting in a
+                machine's memory, and the whole operation will fail.
+                The default value is 64.
+            model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
+                Optional. The objective config for model monitoring. Passing this parameter enables
+                monitoring on the model associated with this batch prediction job.
+            model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
+                Optional. Configures how model monitoring alerts are sent to the user. Right now
+                only email alert is supported.
+            analysis_instance_schema_uri (str):
+                Optional. Only applicable if model_monitoring_objective_config is also passed.
+                This parameter specifies the YAML schema file uri describing the format of a single
+                instance that you want Tensorflow Data Validation (TFDV) to
+                analyze. If this field is empty, all the feature data types are
+                inferred from predict_instance_schema_uri, meaning that TFDV
+                will use the data in the exact format as prediction request/response.
+                If there are any data type differences between predict instance
+                and TFDV instance, this field can be used to override the schema.
+                For models trained with Vertex AI, this field must be set as all the
+                fields in predict instance formatted as string.
+            service_account (str):
+                Optional. Specifies the service account for workload run-as account.
+                Users submitting jobs must have act-as permission on this run-as account.
+            reservation_affinity_type (str):
+                Optional. The type of reservation affinity.
+                One of NO_RESERVATION, ANY_RESERVATION, SPECIFIC_RESERVATION,
+                SPECIFIC_THEN_ANY_RESERVATION, SPECIFIC_THEN_NO_RESERVATION.
+            reservation_affinity_key (str):
+                Optional. Corresponds to the label key of a reservation resource.
+                To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key
+                and specify the name of your reservation as its value.
+            reservation_affinity_values (List[str]):
+                Optional. Corresponds to the label values of a reservation resource.
+                This must be the full resource name of the reservation.
+                Format: 'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}'
+
+        Returns:
+            (jobs.BatchPredictionJob):
+                Instantiated representation of the created batch prediction job.
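+
+        Example usage (a minimal sketch; the project, bucket, model, and
+        reservation names below are placeholders):
+
+            from google.cloud.aiplatform.preview import jobs
+
+            batch_prediction_job = jobs.BatchPredictionJob.create(
+                job_display_name="my-batch-prediction-job",
+                model_name="projects/123/locations/us-central1/models/456",
+                gcs_source="gs://my-bucket/instances.jsonl",
+                gcs_destination_prefix="gs://my-bucket/output",
+                # Reservation affinity applies to dedicated resources, so a
+                # machine_type must be set.
+                machine_type="n1-standard-4",
+                reservation_affinity_type="SPECIFIC_RESERVATION",
+                reservation_affinity_key="compute.googleapis.com/reservation-name",
+                reservation_affinity_values=[
+                    "projects/123/zones/us-central1-a/reservations/my-reservation"
+                ],
+            )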
+ """ + return cls._submit_impl( + job_display_name=job_display_name, + model_name=model_name, + instances_format=instances_format, + predictions_format=predictions_format, + gcs_source=gcs_source, + bigquery_source=bigquery_source, + gcs_destination_prefix=gcs_destination_prefix, + bigquery_destination_prefix=bigquery_destination_prefix, + model_parameters=model_parameters, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + starting_replica_count=starting_replica_count, + max_replica_count=max_replica_count, + generate_explanation=generate_explanation, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + labels=labels, + project=project, + location=location, + credentials=credentials, + encryption_spec_key_name=encryption_spec_key_name, + sync=sync, + create_request_timeout=create_request_timeout, + batch_size=batch_size, + model_monitoring_objective_config=model_monitoring_objective_config, + model_monitoring_alert_config=model_monitoring_alert_config, + analysis_instance_schema_uri=analysis_instance_schema_uri, + service_account=service_account, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, + # Main distinction of `create` vs `submit`: + wait_for_completion=True, + ) + + @classmethod + def submit( + cls, + *, + job_display_name: Optional[str] = None, + model_name: Union[str, "aiplatform.Model"], + instances_format: str = "jsonl", + predictions_format: str = "jsonl", + gcs_source: Optional[Union[str, Sequence[str]]] = None, + bigquery_source: Optional[str] = None, + gcs_destination_prefix: Optional[str] = None, + bigquery_destination_prefix: Optional[str] = None, + model_parameters: Optional[Dict] = None, + machine_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + accelerator_count: Optional[int] = None, + starting_replica_count: Optional[int] = None, + max_replica_count: Optional[int] = None, + generate_explanation: Optional[bool] = False, + explanation_metadata: Optional["aiplatform.explain.ExplanationMetadata"] = None, + explanation_parameters: Optional[ + "aiplatform.explain.ExplanationParameters" + ] = None, + labels: Optional[Dict[str, str]] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + encryption_spec_key_name: Optional[str] = None, + create_request_timeout: Optional[float] = None, + batch_size: Optional[int] = None, + model_monitoring_objective_config: Optional[ + "aiplatform.model_monitoring.ObjectiveConfig" + ] = None, + model_monitoring_alert_config: Optional[ + "aiplatform.model_monitoring.AlertConfig" + ] = None, + analysis_instance_schema_uri: Optional[str] = None, + service_account: Optional[str] = None, + reservation_affinity_type: Optional[str] = None, + reservation_affinity_key: Optional[str] = None, + reservation_affinity_values: Optional[List[str]] = None, + ) -> "BatchPredictionJob": + """Sumbit a batch prediction job (not waiting for completion). + + Args: + job_display_name (str): Required. The user-defined name of the + BatchPredictionJob. The name can be up to 128 characters long and + can be consist of any UTF-8 characters. + model_name (Union[str, aiplatform.Model]): Required. A fully-qualified + model resource name or model ID. 
+                Example: "projects/123/locations/us-central1/models/456" or "456"
+                when project and location are initialized or passed. May
+                optionally contain a version ID or alias in
+                {model_name}@{version} form. Or an instance of
+                aiplatform.Model.
+            instances_format (str): Required. The format in which instances are
+                provided. Must be one of the formats listed in
+                `Model.supported_input_storage_formats`. Default is "jsonl" when
+                using `gcs_source`. If a `bigquery_source` is provided, this is
+                overridden to "bigquery".
+            predictions_format (str): Required. The format in which Vertex AI
+                outputs the predictions, must be one of the formats specified in
+                `Model.supported_output_storage_formats`. Default is "jsonl" when
+                using `gcs_destination_prefix`. If a `bigquery_destination_prefix`
+                is provided, this is overridden to "bigquery".
+            gcs_source (Optional[Sequence[str]]): Google Cloud Storage URI(-s) to
+                your instances to run batch prediction on. They must match
+                `instances_format`.
+            bigquery_source (Optional[str]): BigQuery URI to a table, up to 2000
+                characters long. For example: `bq://projectId.bqDatasetId.bqTableId`
+            gcs_destination_prefix (Optional[str]): The Google Cloud Storage
+                location of the directory where the output is to be written to. In
+                the given directory a new directory is created. Its name is
+                ``prediction-<model-display-name>-<job-create-time>``, where
+                timestamp is in YYYY-MM-DDThh:mm:ss.sssZ ISO-8601 format. Inside of
+                it files ``predictions_0001.<extension>``,
+                ``predictions_0002.<extension>``, ...,
+                ``predictions_N.<extension>`` are created where ``<extension>``
+                depends on chosen ``predictions_format``, and N may equal 0001 and
+                depends on the total number of successfully predicted instances.
+                If the Model has both ``instance`` and ``prediction`` schemata
+                defined then each such file contains predictions as per the
+                ``predictions_format``. If prediction for any instance failed
+                (partially or completely), then an additional
+                ``errors_0001.<extension>``, ``errors_0002.<extension>``,...,
+                ``errors_N.<extension>`` files are created (N depends on total
+                number of failed predictions). These files contain the failed
+                instances, as per their schema, followed by an additional
+                ``error`` field which as value has ``google.rpc.Status``
+                containing only ``code`` and ``message`` fields.
+            bigquery_destination_prefix (Optional[str]): The BigQuery project or
+                dataset location where the output is to be written to. If project is
+                provided, a new dataset is created with name
+                ``prediction_<model-display-name>_<job-create-time>`` where
+                <model-display-name> is made BigQuery-dataset-name compatible (for
+                example, most special characters become underscores), and timestamp
+                is in YYYY_MM_DDThh_mm_ss_sssZ "based on ISO-8601" format. In the
+                dataset two tables will be created, ``predictions``, and ``errors``.
+                If the Model has both
+                [instance][google.cloud.aiplatform.v1.PredictSchemata.instance_schema_uri]
+                and
+                [prediction][google.cloud.aiplatform.v1.PredictSchemata.parameters_schema_uri]
+                schemata defined then the tables have columns as follows: The
+                ``predictions`` table contains instances for which the prediction
+                succeeded, it has columns as per a concatenation of the Model's
+                instance and prediction schemata. The ``errors`` table contains rows
+                for which the prediction has failed, it has instance columns, as per
+                the instance schema, followed by a single "errors" column, which as
+                values has [google.rpc.Status][google.rpc.Status] represented as a
+                STRUCT, and containing only ``code`` and ``message``.
+            model_parameters (Optional[Dict]): The parameters that govern the
+                predictions.
+                The schema of the parameters may be specified via the
+                Model's `parameters_schema_uri`.
+            machine_type (Optional[str]): The type of machine for running batch
+                prediction on dedicated resources. Not specifying machine type will
+                result in batch prediction job being run with automatic resources.
+            accelerator_type (Optional[str]): The type of accelerator(s) that may
+                be attached to the machine as per `accelerator_count`. Only used if
+                `machine_type` is set.
+            accelerator_count (Optional[int]): The number of accelerators to
+                attach to the `machine_type`. Only used if `machine_type` is set.
+            starting_replica_count (Optional[int]): The number of machine replicas
+                used at the start of the batch operation. If not set, Vertex AI
+                decides starting number, not greater than `max_replica_count`. Only
+                used if `machine_type` is set.
+            max_replica_count (Optional[int]): The maximum number of machine
+                replicas the batch operation may be scaled to. Only used if
+                `machine_type` is set. Default is 10.
+            generate_explanation (bool): Optional. Generate explanation along with
+                the batch prediction results. This will cause the batch prediction
+                output to include explanations based on the `prediction_format`:
+                - `bigquery`: output includes a column named `explanation`. The
+                  value is a struct that conforms to the
+                  [aiplatform.gapic.Explanation] object.
+                - `jsonl`: The JSON objects on each line include an additional
+                  entry keyed `explanation`. The value of the entry is a JSON
+                  object that conforms to the [aiplatform.gapic.Explanation]
+                  object.
+                - `csv`: Generating explanations for CSV format is not supported.
+            explanation_metadata (aiplatform.explain.ExplanationMetadata):
+                Optional. Explanation metadata configuration for this
+                BatchPredictionJob. Can be specified only if `generate_explanation`
+                is set to `True`. This value overrides the value of
+                `Model.explanation_metadata`. All fields of `explanation_metadata`
+                are optional in the request. If a field of the
+                `explanation_metadata` object is not populated, the corresponding
+                field of the `Model.explanation_metadata` object is inherited. For
+                more details, see `Ref docs <http://tinyurl.com/1igh60kt>`
+            explanation_parameters (aiplatform.explain.ExplanationParameters):
+                Optional. Parameters to configure explaining for Model's
+                predictions. Can be specified only if `generate_explanation` is set
+                to `True`. This value overrides the value of
+                `Model.explanation_parameters`. All fields of
+                `explanation_parameters` are optional in the request. If a field of
+                the `explanation_parameters` object is not populated, the
+                corresponding field of the `Model.explanation_parameters` object is
+                inherited. For more details, see `Ref docs
+                <http://tinyurl.com/1an4zake>`
+            labels (Dict[str, str]): Optional. The labels with user-defined
+                metadata to organize your BatchPredictionJobs. Label keys and values
+                can be no longer than 64 characters (Unicode codepoints), can only
+                contain lowercase letters, numeric characters, underscores and
+                dashes. International characters are allowed. See
+                https://goo.gl/xmQnxf for more information and examples of labels.
+            credentials (Optional[auth_credentials.Credentials]): Custom
+                credentials to use to create this batch prediction job. Overrides
+                credentials set in aiplatform.init.
+            encryption_spec_key_name (Optional[str]): Optional. The Cloud KMS
+                resource identifier of the customer managed encryption key used to
+                protect the job. Has the form:
+                ``projects/my-project/locations/my-region/keyRings/my-kr/cryptoKeys/my-key``.
+                The key needs to be in the same region as where the compute
+                resource is created. If this is set, then all resources created
+                by the BatchPredictionJob will be encrypted with the provided
+                encryption key. Overrides encryption_spec_key_name set in
+                aiplatform.init.
+            create_request_timeout (float): Optional. The timeout for the create
+                request in seconds.
+            batch_size (int): Optional. The number of records (e.g. instances) of
+                the operation given in each batch to a machine replica. Machine
+                type and size of a single record should be considered when setting
+                this parameter; a higher value speeds up the batch operation's
+                execution, but a value that is too high will result in a whole
+                batch not fitting in a machine's memory, and the whole operation
+                will fail. The default value is 64.
+            model_monitoring_objective_config (aiplatform.model_monitoring.ObjectiveConfig):
+                Optional. The objective config for model monitoring. Passing this
+                parameter enables monitoring on the model associated with this
+                batch prediction job.
+            model_monitoring_alert_config (aiplatform.model_monitoring.EmailAlertConfig):
+                Optional. Configures how model monitoring alerts are sent to the
+                user. Right now only email alert is supported.
+            analysis_instance_schema_uri (str): Optional. Only applicable if
+                model_monitoring_objective_config is also passed. This parameter
+                specifies the YAML schema file uri describing the format of a single
+                instance that you want Tensorflow Data Validation (TFDV) to analyze.
+                If this field is empty, all the feature data types are inferred from
+                predict_instance_schema_uri, meaning that TFDV will use the data in
+                the exact format as prediction request/response. If there are any
+                data type differences between predict instance and TFDV instance,
+                this field can be used to override the schema. For models trained
+                with Vertex AI, this field must be set as all the fields in predict
+                instance formatted as string.
+            service_account (str): Optional. Specifies the service account for
+                workload run-as account. Users submitting jobs must have act-as
+                permission on this run-as account.
+            reservation_affinity_type (str): Optional. The type of reservation
+                affinity. One of NO_RESERVATION, ANY_RESERVATION,
+                SPECIFIC_RESERVATION, SPECIFIC_THEN_ANY_RESERVATION,
+                SPECIFIC_THEN_NO_RESERVATION.
+            reservation_affinity_key (str): Optional. Corresponds to the label key
+                of a reservation resource. To target a SPECIFIC_RESERVATION by name,
+                use `compute.googleapis.com/reservation-name` as the key and specify
+                the name of your reservation as its value.
+            reservation_affinity_values (List[str]): Optional. Corresponds to the
+                label values of a reservation resource. This must be the full
+                resource name of the reservation.
+                Format:
+                'projects/{project_id_or_number}/zones/{zone}/reservations/{reservation_name}'
+
+        Returns:
+            (jobs.BatchPredictionJob):
+                Instantiated representation of the created batch prediction job.
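+
+        Example usage (a minimal sketch; resource names below are
+        placeholders):
+
+            from google.cloud.aiplatform.preview import jobs
+
+            batch_prediction_job = jobs.BatchPredictionJob.submit(
+                job_display_name="my-batch-prediction-job",
+                model_name="projects/123/locations/us-central1/models/456",
+                gcs_source="gs://my-bucket/instances.jsonl",
+                gcs_destination_prefix="gs://my-bucket/output",
+            )
+            # `submit` returns as soon as the job resource is created; block
+            # explicitly when the results are needed.
+            batch_prediction_job.wait_for_completion()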
+ """ + return cls._submit_impl( + job_display_name=job_display_name, + model_name=model_name, + instances_format=instances_format, + predictions_format=predictions_format, + gcs_source=gcs_source, + bigquery_source=bigquery_source, + gcs_destination_prefix=gcs_destination_prefix, + bigquery_destination_prefix=bigquery_destination_prefix, + model_parameters=model_parameters, + machine_type=machine_type, + accelerator_type=accelerator_type, + accelerator_count=accelerator_count, + starting_replica_count=starting_replica_count, + max_replica_count=max_replica_count, + generate_explanation=generate_explanation, + explanation_metadata=explanation_metadata, + explanation_parameters=explanation_parameters, + labels=labels, + project=project, + location=location, + credentials=credentials, + encryption_spec_key_name=encryption_spec_key_name, + create_request_timeout=create_request_timeout, + batch_size=batch_size, + model_monitoring_objective_config=model_monitoring_objective_config, + model_monitoring_alert_config=model_monitoring_alert_config, + analysis_instance_schema_uri=analysis_instance_schema_uri, + service_account=service_account, + reservation_affinity_type=reservation_affinity_type, + reservation_affinity_key=reservation_affinity_key, + reservation_affinity_values=reservation_affinity_values, + # Main distinction of `create` vs `submit`: + wait_for_completion=False, + sync=True, + ) diff --git a/tests/unit/aiplatform/test_batch_prediction_job_preview.py b/tests/unit/aiplatform/test_batch_prediction_job_preview.py new file mode 100644 index 0000000000..9e8c28902f --- /dev/null +++ b/tests/unit/aiplatform/test_batch_prediction_job_preview.py @@ -0,0 +1,301 @@ +# -*- coding: utf-8 -*- + +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+
+from importlib import reload
+from unittest import mock
+from unittest.mock import patch
+
+from google.cloud import aiplatform
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import initializer
+from google.cloud.aiplatform.compat.services import (
+    job_service_client,
+)
+from google.cloud.aiplatform.compat.types import (
+    batch_prediction_job as gca_batch_prediction_job_compat,
+    io as gca_io_compat,
+    job_state as gca_job_state_compat,
+    machine_resources as gca_machine_resources_compat,
+    manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat,
+    reservation_affinity_v1 as gca_reservation_affinity_compat,
+)
+from google.cloud.aiplatform.preview import jobs as preview_jobs
+import constants as test_constants
+import pytest
+
+# TODO(b/242108750): remove temporary logic once model monitoring for batch prediction is GA
+_TEST_API_CLIENT = job_service_client.JobServiceClient
+
+_TEST_PROJECT = test_constants.ProjectConstants._TEST_PROJECT
+_TEST_LOCATION = test_constants.ProjectConstants._TEST_LOCATION
+_TEST_ID = test_constants.TrainingJobConstants._TEST_ID
+_TEST_ALT_ID = "8834795523125638878"
+_TEST_DISPLAY_NAME = test_constants.TrainingJobConstants._TEST_DISPLAY_NAME
+_TEST_SERVICE_ACCOUNT = test_constants.ProjectConstants._TEST_SERVICE_ACCOUNT
+
+_TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4)
+_TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3)
+_TEST_JOB_STATE_PENDING = gca_job_state_compat.JobState(2)
+
+_TEST_PARENT = test_constants.ProjectConstants._TEST_PARENT
+
+_TEST_MODEL_NAME = (
+    f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/{_TEST_ALT_ID}"
+)
+
+_TEST_BATCH_PREDICTION_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/batchPredictionJobs/{_TEST_ID}"
+_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME = "test-batch-prediction-job"
+
+_TEST_BATCH_PREDICTION_GCS_SOURCE = "gs://example-bucket/folder/instance.jsonl"
+
+_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX = "gs://example-bucket/folder/output"
+
+_TEST_MACHINE_TYPE = "n1-standard-4"
+_TEST_ACCELERATOR_TYPE = "NVIDIA_TESLA_P100"
+_TEST_ACCELERATOR_COUNT = 2
+_TEST_RESERVATION_AFFINITY_TYPE = "SPECIFIC_RESERVATION"
+_TEST_RESERVATION_AFFINITY_KEY = "compute.googleapis.com/reservation-name"
+_TEST_RESERVATION_AFFINITY_VALUES = [
+    "projects/fake-project-id/zones/fake-zone/reservations/fake-reservation-name"
+]
+
+
+@pytest.fixture
+def get_batch_prediction_job_mock():
+    with patch.object(
+        _TEST_API_CLIENT, "get_batch_prediction_job"
+    ) as get_batch_prediction_job_mock:
+        get_batch_prediction_job_mock.side_effect = [
+            gca_batch_prediction_job_compat.BatchPredictionJob(
+                name=_TEST_BATCH_PREDICTION_JOB_NAME,
+                display_name=_TEST_DISPLAY_NAME,
+                state=_TEST_JOB_STATE_PENDING,
+            ),
+            gca_batch_prediction_job_compat.BatchPredictionJob(
+                name=_TEST_BATCH_PREDICTION_JOB_NAME,
+                display_name=_TEST_DISPLAY_NAME,
+                state=_TEST_JOB_STATE_RUNNING,
+            ),
+            gca_batch_prediction_job_compat.BatchPredictionJob(
+                name=_TEST_BATCH_PREDICTION_JOB_NAME,
+                display_name=_TEST_DISPLAY_NAME,
+                state=_TEST_JOB_STATE_SUCCESS,
+            ),
+            gca_batch_prediction_job_compat.BatchPredictionJob(
+                name=_TEST_BATCH_PREDICTION_JOB_NAME,
+                display_name=_TEST_DISPLAY_NAME,
+                state=_TEST_JOB_STATE_SUCCESS,
+            ),
+        ]
+        yield get_batch_prediction_job_mock
+
+
+@pytest.fixture
+def create_batch_prediction_job_mock():
+    with mock.patch.object(
+        _TEST_API_CLIENT, "create_batch_prediction_job"
+    ) as create_batch_prediction_job_mock:
+        create_batch_prediction_job_mock.return_value = (
+            gca_batch_prediction_job_compat.BatchPredictionJob(
+                name=_TEST_BATCH_PREDICTION_JOB_NAME,
+                display_name=_TEST_DISPLAY_NAME,
+                state=_TEST_JOB_STATE_SUCCESS,
+            )
+        )
+        yield create_batch_prediction_job_mock
+
+
+@pytest.mark.usefixtures("google_auth_mock")
+class TestBatchPredictionJobPreview:
+
+    def setup_method(self):
+        reload(initializer)
+        reload(aiplatform)
+
+    def teardown_method(self):
+        initializer.global_pool.shutdown(wait=True)
+
+    def test_init_batch_prediction_job(self, get_batch_prediction_job_mock):
+        preview_jobs.BatchPredictionJob(
+            batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
+        )
+        get_batch_prediction_job_mock.assert_called_once_with(
+            name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+    def test_batch_prediction_job_status(self, get_batch_prediction_job_mock):
+        bp = preview_jobs.BatchPredictionJob(
+            batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
+        )
+
+        # get_batch_prediction_job() is called again here
+        bp_job_state = bp.state
+
+        assert get_batch_prediction_job_mock.call_count == 2
+        assert bp_job_state == _TEST_JOB_STATE_RUNNING
+
+        get_batch_prediction_job_mock.assert_called_with(
+            name=_TEST_BATCH_PREDICTION_JOB_NAME, retry=base._DEFAULT_RETRY
+        )
+
+    def test_batch_prediction_job_done_get(self, get_batch_prediction_job_mock):
+        bp = preview_jobs.BatchPredictionJob(
+            batch_prediction_job_name=_TEST_BATCH_PREDICTION_JOB_NAME
+        )
+
+        assert bp.done() is False
+        assert get_batch_prediction_job_mock.call_count == 2
+
+    @mock.patch.object(preview_jobs, "_JOB_WAIT_TIME", 1)
+    @mock.patch.object(preview_jobs, "_LOG_WAIT_TIME", 1)
+    @pytest.mark.parametrize("sync", [True, False])
+    @pytest.mark.usefixtures("get_batch_prediction_job_mock")
+    def test_batch_predict_create_with_reservation(
+        self, create_batch_prediction_job_mock, sync
+    ):
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        # Make SDK batch_predict method call
+        batch_prediction_job = preview_jobs.BatchPredictionJob.create(
+            model_name=_TEST_MODEL_NAME,
+            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
+            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
+            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
+            sync=sync,
+            create_request_timeout=None,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            machine_type=_TEST_MACHINE_TYPE,
+            accelerator_type=_TEST_ACCELERATOR_TYPE,
+            accelerator_count=_TEST_ACCELERATOR_COUNT,
+            reservation_affinity_type=_TEST_RESERVATION_AFFINITY_TYPE,
+            reservation_affinity_key=_TEST_RESERVATION_AFFINITY_KEY,
+            reservation_affinity_values=_TEST_RESERVATION_AFFINITY_VALUES,
+        )
+
+        batch_prediction_job.wait_for_resource_creation()
+
+        batch_prediction_job.wait()
+
+        # Construct expected request
+        expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
+            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
+            model=_TEST_MODEL_NAME,
+            input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
+                instances_format="jsonl",
+                gcs_source=gca_io_compat.GcsSource(
+                    uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
+                ),
+            ),
+            output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
+                gcs_destination=gca_io_compat.GcsDestination(
+                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
+                ),
+                predictions_format="jsonl",
+            ),
+            service_account=_TEST_SERVICE_ACCOUNT,
+            dedicated_resources=gca_machine_resources_compat.BatchDedicatedResources(
+                machine_spec=gca_machine_resources_compat.MachineSpec(
+                    machine_type=_TEST_MACHINE_TYPE,
+                    accelerator_type=_TEST_ACCELERATOR_TYPE,
+                    accelerator_count=_TEST_ACCELERATOR_COUNT,
+                    reservation_affinity=gca_reservation_affinity_compat.ReservationAffinity(
+                        reservation_affinity_type=_TEST_RESERVATION_AFFINITY_TYPE,
+                        key=_TEST_RESERVATION_AFFINITY_KEY,
+                        values=_TEST_RESERVATION_AFFINITY_VALUES,
+                    ),
+                ),
+            ),
+            manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters(),
+        )
+
+        create_batch_prediction_job_mock.assert_called_once_with(
+            parent=_TEST_PARENT,
+            batch_prediction_job=expected_gapic_batch_prediction_job,
+            timeout=None,
+        )
+
+    @mock.patch.object(preview_jobs, "_JOB_WAIT_TIME", 1)
+    @mock.patch.object(preview_jobs, "_LOG_WAIT_TIME", 1)
+    @pytest.mark.usefixtures("get_batch_prediction_job_mock")
+    def test_batch_predict_job_submit(self, create_batch_prediction_job_mock):
+        aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+
+        # Make SDK batch_predict method call
+        batch_prediction_job = preview_jobs.BatchPredictionJob.submit(
+            model_name=_TEST_MODEL_NAME,
+            job_display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
+            gcs_source=_TEST_BATCH_PREDICTION_GCS_SOURCE,
+            gcs_destination_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX,
+            service_account=_TEST_SERVICE_ACCOUNT,
+            machine_type=_TEST_MACHINE_TYPE,
+            accelerator_type=_TEST_ACCELERATOR_TYPE,
+            accelerator_count=_TEST_ACCELERATOR_COUNT,
+            reservation_affinity_type=_TEST_RESERVATION_AFFINITY_TYPE,
+            reservation_affinity_key=_TEST_RESERVATION_AFFINITY_KEY,
+            reservation_affinity_values=_TEST_RESERVATION_AFFINITY_VALUES,
+        )
+
+        batch_prediction_job.wait_for_resource_creation()
+        assert batch_prediction_job.done() is False
+        assert (
+            batch_prediction_job.state
+            != preview_jobs.gca_job_state.JobState.JOB_STATE_SUCCEEDED
+        )
+
+        batch_prediction_job.wait_for_completion()
+        assert (
+            batch_prediction_job.state
+            == preview_jobs.gca_job_state.JobState.JOB_STATE_SUCCEEDED
+        )
+
+        # Construct expected request
+        expected_gapic_batch_prediction_job = gca_batch_prediction_job_compat.BatchPredictionJob(
+            display_name=_TEST_BATCH_PREDICTION_JOB_DISPLAY_NAME,
+            model=_TEST_MODEL_NAME,
+            input_config=gca_batch_prediction_job_compat.BatchPredictionJob.InputConfig(
+                instances_format="jsonl",
+                gcs_source=gca_io_compat.GcsSource(
+                    uris=[_TEST_BATCH_PREDICTION_GCS_SOURCE]
+                ),
+            ),
+            output_config=gca_batch_prediction_job_compat.BatchPredictionJob.OutputConfig(
+                gcs_destination=gca_io_compat.GcsDestination(
+                    output_uri_prefix=_TEST_BATCH_PREDICTION_GCS_DEST_PREFIX
+                ),
+                predictions_format="jsonl",
+            ),
+            service_account=_TEST_SERVICE_ACCOUNT,
+            dedicated_resources=gca_machine_resources_compat.BatchDedicatedResources(
+                machine_spec=gca_machine_resources_compat.MachineSpec(
+                    machine_type=_TEST_MACHINE_TYPE,
+                    accelerator_type=_TEST_ACCELERATOR_TYPE,
+                    accelerator_count=_TEST_ACCELERATOR_COUNT,
+                    reservation_affinity=gca_reservation_affinity_compat.ReservationAffinity(
+                        reservation_affinity_type=_TEST_RESERVATION_AFFINITY_TYPE,
+                        key=_TEST_RESERVATION_AFFINITY_KEY,
+                        values=_TEST_RESERVATION_AFFINITY_VALUES,
+                    ),
+                ),
+            ),
+            manual_batch_tuning_parameters=gca_manual_batch_tuning_parameters_compat.ManualBatchTuningParameters(),
+        )
+
+        create_batch_prediction_job_mock.assert_called_once_with(
+            parent=_TEST_PARENT,
+            batch_prediction_job=expected_gapic_batch_prediction_job,
+            timeout=None,
+        )
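
Usage sketch (illustrative, not part of the patch; project, bucket, and
reservation names below are placeholders). Reservation affinity only applies
when the job runs on dedicated resources, so `machine_type` must be set for
the new parameters to take effect, mirroring the wiring in `_submit_impl`:

    from google.cloud import aiplatform
    from google.cloud.aiplatform.preview import jobs as preview_jobs

    aiplatform.init(project="my-project", location="us-central1")

    # SPECIFIC_RESERVATION pins the job to one named Compute Engine
    # reservation, addressed by its full resource name.
    job = preview_jobs.BatchPredictionJob.create(
        job_display_name="reservation-backed-batch-prediction",
        model_name="456",  # model ID; project/location come from init()
        gcs_source="gs://my-bucket/instances.jsonl",
        gcs_destination_prefix="gs://my-bucket/output",
        machine_type="n1-standard-4",
        accelerator_type="NVIDIA_TESLA_P100",
        accelerator_count=2,
        reservation_affinity_type="SPECIFIC_RESERVATION",
        reservation_affinity_key="compute.googleapis.com/reservation-name",
        reservation_affinity_values=[
            "projects/my-project/zones/us-central1-a/reservations/my-reservation"
        ],
    )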