Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import datetime
import decimal

from google.cloud import bigquery
from google.cloud.bigquery import enums
from google.cloud.bigquery_storage_v1 import types as gapic_types
from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
import pandas as pd

import pyarrow as pa

from google.cloud import bigquery
from google.cloud.bigquery_storage_v1 import types as gapic_types
from google.cloud.bigquery_storage_v1.writer import AppendRowsStream

TABLE_LENGTH = 100_000

BQ_SCHEMA = [
Expand Down Expand Up @@ -100,7 +100,10 @@ def make_table(project_id, dataset_id, bq_client):


def create_stream(bqstorage_write_client, table):
stream_name = f"projects/{table.project}/datasets/{table.dataset_id}/tables/{table.table_id}/_default"
stream_name = (
f"projects/{table.project}/datasets/{table.dataset_id}/"
f"tables/{table.table_id}/_default"
)
request_template = gapic_types.AppendRowsRequest()
request_template.write_stream = stream_name

Expand Down Expand Up @@ -160,18 +163,64 @@ def generate_pyarrow_table(num_rows=TABLE_LENGTH):


def generate_write_requests(pyarrow_table):
# Determine max_chunksize of the record batches. Because max size of
# AppendRowsRequest is 10 MB, we need to split the table if it's too big.
# See: https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1#appendrowsrequest
max_request_bytes = 10 * 2**20 # 10 MB
chunk_num = int(pyarrow_table.nbytes / max_request_bytes) + 1
chunk_size = int(pyarrow_table.num_rows / chunk_num)

# Construct request(s).
for batch in pyarrow_table.to_batches(max_chunksize=chunk_size):
# Maximum size for a single AppendRowsRequest is 10 MB.
# To be safe, we'll aim for a soft limit of 7 MB.
max_request_bytes = 7 * 1024 * 1024 # 7 MB

def _create_request(batches):
"""Helper to create an AppendRowsRequest from a list of batches."""
combined_table = pa.Table.from_batches(batches)
request = gapic_types.AppendRowsRequest()
request.arrow_rows.rows.serialized_record_batch = batch.serialize().to_pybytes()
yield request
request.arrow_rows.rows.serialized_record_batch = (
combined_table.combine_chunks().to_batches()[0].serialize().to_pybytes()
)
return request

batches = pyarrow_table.to_batches()

current_batches = []
current_size = 0

while batches:
batch = batches.pop()
batch_size = batch.nbytes

if current_size + batch_size > max_request_bytes:
if batch.num_rows > 1:
# Split the batch into 2 sub batches with identical chunksizes
mid = batch.num_rows // 2
batch_left = batch.slice(offset=0, length=mid)
batch_right = batch.slice(offset=mid)

# Append the new batches into the stack and continue poping.
batches.append(batch_right)
batches.append(batch_left)
continue

# If the batch is single row and still larger than max_request_size
else:
# If current batches is empty, throw error
if len(current_batches) == 0:
raise ValueError(
f"A single PyArrow batch of one row is larger than the maximum request size "
f"(batch size: {batch_size} > max request size: {max_request_bytes}). Cannot proceed."
)
# Otherwise, generate the request, reset current_size and current_batches
else:
yield _create_request(current_batches)

current_batches = []
current_size = 0
batches.append(batch)

# Otherwise, add the batch into current_batches
else:
current_batches.append(batch)
current_size += batch_size

# Flush remaining batches
if current_batches:
yield _create_request(current_batches)


def verify_result(client, table, futures):
Expand All @@ -181,14 +230,13 @@ def verify_result(client, table, futures):
assert bq_table.schema == BQ_SCHEMA

# Verify table size.
query = client.query(f"SELECT COUNT(1) FROM `{bq_table}`;")
query = client.query(f"SELECT DISTINCT int64_col FROM `{bq_table}`;")
query_result = query.result().to_dataframe()

# There might be extra rows due to retries.
assert query_result.iloc[0, 0] >= TABLE_LENGTH
assert len(query_result) == TABLE_LENGTH

# Verify that table was split into multiple requests.
assert len(futures) == 2
assert len(futures) == 3


def main(project_id, dataset):
Expand Down
Loading