|
14 | 14 |
|
15 | 15 | from __future__ import annotations
|
16 | 16 |
|
| 17 | +import datetime |
| 18 | +import decimal |
17 | 19 | import os
|
18 | 20 | import textwrap
|
19 | 21 | from typing import Optional
|
20 | 22 | from unittest import mock
|
21 | 23 |
|
| 24 | +import dateutil |
| 25 | +import dateutil.relativedelta |
22 | 26 | from google.adk.tools import BaseTool
|
23 | 27 | from google.adk.tools.bigquery import BigQueryCredentialsConfig
|
24 | 28 | from google.adk.tools.bigquery import BigQueryToolset
|
@@ -829,3 +833,143 @@ def test_execute_sql_no_default_auth(
|
829 | 833 | result = execute_sql(project, query, credentials, tool_config, tool_context)
|
830 | 834 | assert result == {"status": "SUCCESS", "rows": query_result}
|
831 | 835 | mock_default_auth.assert_not_called()
|
| 836 | + |
| 837 | + |
| 838 | +@pytest.mark.parametrize( |
| 839 | + ("query", "query_result", "tool_result_rows"), |
| 840 | + [ |
| 841 | + pytest.param( |
| 842 | + "SELECT [1,2,3] AS x", |
| 843 | + [{"x": [1, 2, 3]}], |
| 844 | + [{"x": [1, 2, 3]}], |
| 845 | + id="ARRAY", |
| 846 | + ), |
| 847 | + pytest.param( |
| 848 | + "SELECT TRUE AS x", [{"x": True}], [{"x": True}], id="BOOL" |
| 849 | + ), |
| 850 | + pytest.param( |
| 851 | + "SELECT b'Hello World!' AS x", |
| 852 | + [{"x": b"Hello World!"}], |
| 853 | + [{"x": "b'Hello World!'"}], |
| 854 | + id="BYTES", |
| 855 | + ), |
| 856 | + pytest.param( |
| 857 | + "SELECT DATE '2025-07-21' AS x", |
| 858 | + [{"x": datetime.date(2025, 7, 21)}], |
| 859 | + [{"x": "2025-07-21"}], |
| 860 | + id="DATE", |
| 861 | + ), |
| 862 | + pytest.param( |
| 863 | + "SELECT DATETIME '2025-07-21 14:30:45' AS x", |
| 864 | + [{"x": datetime.datetime(2025, 7, 21, 14, 30, 45)}], |
| 865 | + [{"x": "2025-07-21 14:30:45"}], |
| 866 | + id="DATETIME", |
| 867 | + ), |
| 868 | + pytest.param( |
| 869 | + "SELECT ST_GEOGFROMTEXT('POINT(-122.21 47.48)') as x", |
| 870 | + [{"x": "POINT(-122.21 47.48)"}], |
| 871 | + [{"x": "POINT(-122.21 47.48)"}], |
| 872 | + id="GEOGRAPHY", |
| 873 | + ), |
| 874 | + pytest.param( |
| 875 | + "SELECT INTERVAL 10 DAY as x", |
| 876 | + [{"x": dateutil.relativedelta.relativedelta(days=10)}], |
| 877 | + [{"x": "relativedelta(days=+10)"}], |
| 878 | + id="INTERVAL", |
| 879 | + ), |
| 880 | + pytest.param( |
| 881 | + "SELECT JSON_OBJECT('name', 'Alice', 'age', 30) AS x", |
| 882 | + [{"x": {"age": 30, "name": "Alice"}}], |
| 883 | + [{"x": {"age": 30, "name": "Alice"}}], |
| 884 | + id="JSON", |
| 885 | + ), |
| 886 | + pytest.param("SELECT 1 AS x", [{"x": 1}], [{"x": 1}], id="INT64"), |
| 887 | + pytest.param( |
| 888 | + "SELECT CAST(1.2 AS NUMERIC) AS x", |
| 889 | + [{"x": decimal.Decimal("1.2")}], |
| 890 | + [{"x": "1.2"}], |
| 891 | + id="NUMERIC", |
| 892 | + ), |
| 893 | + pytest.param( |
| 894 | + "SELECT CAST(1.2 AS BIGNUMERIC) AS x", |
| 895 | + [{"x": decimal.Decimal("1.2")}], |
| 896 | + [{"x": "1.2"}], |
| 897 | + id="BIGNUMERIC", |
| 898 | + ), |
| 899 | + pytest.param( |
| 900 | + "SELECT 1.23 AS x", [{"x": 1.23}], [{"x": 1.23}], id="FLOAT64" |
| 901 | + ), |
| 902 | + pytest.param( |
| 903 | + "SELECT RANGE(DATE '2023-01-01', DATE '2023-01-31') as x", |
| 904 | + [{ |
| 905 | + "x": { |
| 906 | + "start": datetime.date(2023, 1, 1), |
| 907 | + "end": datetime.date(2023, 1, 31), |
| 908 | + } |
| 909 | + }], |
| 910 | + [{ |
| 911 | + "x": ( |
| 912 | + "{'start': datetime.date(2023, 1, 1), 'end':" |
| 913 | + " datetime.date(2023, 1, 31)}" |
| 914 | + ) |
| 915 | + }], |
| 916 | + id="RANGE", |
| 917 | + ), |
| 918 | + pytest.param( |
| 919 | + "SELECT 'abc' AS x", [{"x": "abc"}], [{"x": "abc"}], id="STRING" |
| 920 | + ), |
| 921 | + pytest.param( |
| 922 | + "SELECT STRUCT('Alice' AS name, 30 AS age) as x", |
| 923 | + [{"x": {"name": "Alice", "age": 30}}], |
| 924 | + [{"x": {"name": "Alice", "age": 30}}], |
| 925 | + id="STRUCT", |
| 926 | + ), |
| 927 | + pytest.param( |
| 928 | + "SELECT TIME '10:30:45' as x", |
| 929 | + [{"x": datetime.time(10, 30, 45)}], |
| 930 | + [{"x": "10:30:45"}], |
| 931 | + id="TIME", |
| 932 | + ), |
| 933 | + pytest.param( |
| 934 | + "SELECT TIMESTAMP '2025-07-21 10:30:45-07:00' as x", |
| 935 | + [{ |
| 936 | + "x": datetime.datetime( |
| 937 | + 2025, 7, 21, 17, 30, 45, tzinfo=datetime.timezone.utc |
| 938 | + ) |
| 939 | + }], |
| 940 | + [{"x": "2025-07-21 17:30:45+00:00"}], |
| 941 | + id="TIMESTAMP", |
| 942 | + ), |
| 943 | + pytest.param( |
| 944 | + "SELECT NULL AS x", [{"x": None}], [{"x": None}], id="NULL" |
| 945 | + ), |
| 946 | + ], |
| 947 | +) |
| 948 | +@mock.patch.dict(os.environ, {}, clear=True) |
| 949 | +@mock.patch("google.cloud.bigquery.Client.query_and_wait", autospec=True) |
| 950 | +@mock.patch("google.cloud.bigquery.Client.query", autospec=True) |
| 951 | +def test_execute_sql_result_dtype( |
| 952 | + mock_query, mock_query_and_wait, query, query_result, tool_result_rows |
| 953 | +): |
| 954 | + """Test execute_sql tool invocation for various BigQuery data types. |
| 955 | +
|
| 956 | + See all the supported BigQuery data types at |
| 957 | + https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#data_type_list. |
| 958 | + """ |
| 959 | + project = "my_project" |
| 960 | + statement_type = "SELECT" |
| 961 | + credentials = mock.create_autospec(Credentials, instance=True) |
| 962 | + tool_config = BigQueryToolConfig() |
| 963 | + tool_context = mock.create_autospec(ToolContext, instance=True) |
| 964 | + |
| 965 | + # Simulate the result of query API |
| 966 | + query_job = mock.create_autospec(bigquery.QueryJob) |
| 967 | + query_job.statement_type = statement_type |
| 968 | + mock_query.return_value = query_job |
| 969 | + |
| 970 | + # Simulate the result of query_and_wait API |
| 971 | + mock_query_and_wait.return_value = query_result |
| 972 | + |
| 973 | + # Test the tool worked without invoking default auth |
| 974 | + result = execute_sql(project, query, credentials, tool_config, tool_context) |
| 975 | + assert result == {"status": "SUCCESS", "rows": tool_result_rows} |
0 commit comments