diff --git a/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py b/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py
index 1d4adad52f01..cac46f98fc15 100644
--- a/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py
+++ b/packages/google-cloud-bigquery-storage/samples/pyarrow/append_rows_with_arrow.py
@@ -16,14 +16,14 @@
 import datetime
 import decimal
 
-from google.cloud import bigquery
 from google.cloud.bigquery import enums
-from google.cloud.bigquery_storage_v1 import types as gapic_types
-from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
 import pandas as pd
-
 import pyarrow as pa
 
+from google.cloud import bigquery
+from google.cloud.bigquery_storage_v1 import types as gapic_types
+from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
+
 TABLE_LENGTH = 100_000
 
 BQ_SCHEMA = [
@@ -100,7 +100,10 @@ def make_table(project_id, dataset_id, bq_client):
 
 
 def create_stream(bqstorage_write_client, table):
-    stream_name = f"projects/{table.project}/datasets/{table.dataset_id}/tables/{table.table_id}/_default"
+    stream_name = (
+        f"projects/{table.project}/datasets/{table.dataset_id}/"
+        f"tables/{table.table_id}/_default"
+    )
 
     request_template = gapic_types.AppendRowsRequest()
     request_template.write_stream = stream_name
@@ -160,18 +163,64 @@ def generate_pyarrow_table(num_rows=TABLE_LENGTH):
 
 
 def generate_write_requests(pyarrow_table):
-    # Determine max_chunksize of the record batches. Because max size of
-    # AppendRowsRequest is 10 MB, we need to split the table if it's too big.
-    # See: https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1#appendrowsrequest
-    max_request_bytes = 10 * 2**20  # 10 MB
-    chunk_num = int(pyarrow_table.nbytes / max_request_bytes) + 1
-    chunk_size = int(pyarrow_table.num_rows / chunk_num)
-
-    # Construct request(s).
-    for batch in pyarrow_table.to_batches(max_chunksize=chunk_size):
+    # The maximum size of a single AppendRowsRequest is 10 MB.
+    # To be safe, we aim for a soft limit of 7 MB.
+    max_request_bytes = 7 * 1024 * 1024  # 7 MB
+
+    def _create_request(batches):
+        """Helper to create an AppendRowsRequest from a list of batches."""
+        combined_table = pa.Table.from_batches(batches)
         request = gapic_types.AppendRowsRequest()
-        request.arrow_rows.rows.serialized_record_batch = batch.serialize().to_pybytes()
-        yield request
+        request.arrow_rows.rows.serialized_record_batch = (
+            combined_table.combine_chunks().to_batches()[0].serialize().to_pybytes()
+        )
+        return request
+
+    batches = pyarrow_table.to_batches()
+
+    current_batches = []
+    current_size = 0
+
+    while batches:
+        batch = batches.pop()
+        batch_size = batch.nbytes
+
+        if current_size + batch_size > max_request_bytes:
+            if batch.num_rows > 1:
+                # Split the batch into 2 sub-batches of roughly equal size.
+                mid = batch.num_rows // 2
+                batch_left = batch.slice(offset=0, length=mid)
+                batch_right = batch.slice(offset=mid)
+
+                # Push the sub-batches back onto the stack and continue popping.
+                batches.append(batch_right)
+                batches.append(batch_left)
+                continue
+
+            # The batch is a single row and still larger than max_request_bytes.
+            else:
+                # If current_batches is empty, raise an error.
+                if len(current_batches) == 0:
+                    raise ValueError(
+                        f"A single PyArrow batch of one row is larger than the maximum request size "
+                        f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
+                    )
+                # Otherwise, generate a request and reset current_size and current_batches.
+                else:
+                    yield _create_request(current_batches)
+
+                    current_batches = []
+                    current_size = 0
+                    batches.append(batch)
+
+        # Otherwise, add the batch to current_batches.
+        else:
+            current_batches.append(batch)
+            current_size += batch_size
+
+    # Flush the remaining batches.
+    if current_batches:
+        yield _create_request(current_batches)
 
 
 def verify_result(client, table, futures):
@@ -181,14 +230,13 @@ def verify_result(client, table, futures):
     assert bq_table.schema == BQ_SCHEMA
 
     # Verify table size.
-    query = client.query(f"SELECT COUNT(1) FROM `{bq_table}`;")
+    query = client.query(f"SELECT DISTINCT int64_col FROM `{bq_table}`;")
     query_result = query.result().to_dataframe()
 
-    # There might be extra rows due to retries.
-    assert query_result.iloc[0, 0] >= TABLE_LENGTH
+    assert len(query_result) == TABLE_LENGTH
 
     # Verify that table was split into multiple requests.
-    assert len(futures) == 2
+    assert len(futures) == 3
 
 
 def main(project_id, dataset):
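
For reference, below is a minimal standalone sketch of the same pack-or-split strategy introduced in generate_write_requests, without the AppendRowsRequest wrapping, so the chunking behaviour can be inspected with pyarrow alone. The names pack_batches, MAX_CHUNK_BYTES, and demo_table are hypothetical and introduced here only for illustration; the tiny byte budget is chosen so the splitting is easy to observe.

# Illustration only: a simplified version of the batching logic added above.
# All names here are hypothetical; not part of the sample itself.
import pyarrow as pa

MAX_CHUNK_BYTES = 1024  # deliberately tiny so splitting is easy to observe


def pack_batches(table, max_bytes=MAX_CHUNK_BYTES):
    """Yield lists of RecordBatches whose combined size stays under max_bytes."""
    stack = list(table.to_batches())
    current, current_size = [], 0
    while stack:
        batch = stack.pop()
        if current_size + batch.nbytes > max_bytes:
            if batch.num_rows > 1:
                # Too big for the remaining budget: split in half and retry.
                mid = batch.num_rows // 2
                stack.append(batch.slice(offset=mid))
                stack.append(batch.slice(offset=0, length=mid))
                continue
            if not current:
                raise ValueError("A single row exceeds the size budget.")
            # Flush the current group and retry this batch with a fresh budget.
            yield current
            current, current_size = [], 0
            stack.append(batch)
        else:
            current.append(batch)
            current_size += batch.nbytes
    if current:
        yield current


if __name__ == "__main__":
    demo_table = pa.table({"int64_col": list(range(2_000))})
    for i, group in enumerate(pack_batches(demo_table)):
        combined = pa.Table.from_batches(group)
        print(f"chunk {i}: rows={combined.num_rows}, bytes={combined.nbytes}")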