diff --git a/pandas_gbq/gbq.py b/pandas_gbq/gbq.py index 8db1d4ea..71ed878d 100644 --- a/pandas_gbq/gbq.py +++ b/pandas_gbq/gbq.py @@ -116,6 +116,7 @@ def read_gbq( *, col_order=None, bigquery_client=None, + dry_run=False, ): r"""Read data from Google BigQuery to a pandas DataFrame. @@ -266,6 +267,8 @@ def read_gbq( bigquery_client : google.cloud.bigquery.Client, optional A Google Cloud BigQuery Python Client instance. If provided, it will be used for reading data, while the project and credentials parameters will be ignored. + dry_run : bool, default False + If true, executes the query in dry run mode and returns a statistic report as a Pandas series. Returns ------- @@ -318,6 +321,19 @@ def read_gbq( bigquery_client=bigquery_client, ) + if dry_run: + if not _is_query(query_or_table): + # If the input is a table reference, turn it to a query + query_or_table = f"SELECT * FROM `{query_or_table}`;" + return connector.run_query( + query_or_table, + configuration=configuration, + max_results=max_results, + progress_bar_type=progress_bar_type, + dtypes=dtypes, + dry_run=True, + ) + if _is_query(query_or_table): final_df = connector.run_query( query_or_table, diff --git a/pandas_gbq/gbq_connector.py b/pandas_gbq/gbq_connector.py index 97a22db4..0e24fcca 100644 --- a/pandas_gbq/gbq_connector.py +++ b/pandas_gbq/gbq_connector.py @@ -3,6 +3,7 @@ # license that can be found in the LICENSE file. 
+import copy import logging import time import typing @@ -198,7 +199,9 @@ def download_table( user_dtypes=dtypes, ) - def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs): + def run_query( + self, query, max_results=None, progress_bar_type=None, dry_run=False, **kwargs + ): from google.cloud import bigquery job_config_dict = { @@ -235,6 +238,13 @@ def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs): self._start_timer() job_config = bigquery.QueryJobConfig.from_api_repr(job_config_dict) + if dry_run: + job_config.dry_run = True + + return self._report_dry_run_stats( + self.client.query(query, job_config), + ) + if FEATURES.bigquery_has_query_and_wait: rows_iter = pandas_gbq.query.query_and_wait_via_client_library( self, @@ -266,6 +276,47 @@ def run_query(self, query, max_results=None, progress_bar_type=None, **kwargs): user_dtypes=dtypes, ) + def _report_dry_run_stats( + self, + query_job, + ) -> "pandas.Series": + job_api_repr = copy.deepcopy(query_job._properties) + + index = [] + values = [] + + fields = job_api_repr["statistics"]["query"]["schema"]["fields"] + index.append("fieldCount") + values.append(len(fields)) + index.append("fields") + values.append(fields) + + query_config = job_api_repr["configuration"]["query"] + for key in ("destinationTable", "useLegacySql"): + index.append(key) + values.append(query_config.get(key)) + + query_stats = job_api_repr["statistics"]["query"] + for key in ( + "referencedTables", + "totalBytesProcessed", + "cacheHit", + "statementType", + ): + index.append(key) + values.append(query_stats.get(key)) + + import pandas + + index.append("creationTime") + values.append( + pandas.Timestamp( + job_api_repr["statistics"]["creationTime"], unit="ms", tz="UTC" + ) + ) + + return pandas.Series(values, index=index) + def _download_results( self, rows_iter, diff --git a/tests/system/test_read_gbq.py b/tests/system/test_read_gbq.py index 946da668..73643828 100644 --- 
# Index labels expected in the Series returned by ``read_gbq(..., dry_run=True)``.
# Shared by both dry-run tests so the contract is stated exactly once.
_DRY_RUN_STATS_INDEX = pandas.Index(
    [
        "fieldCount",
        "fields",
        "destinationTable",
        "useLegacySql",
        "referencedTables",
        "totalBytesProcessed",
        "cacheHit",
        "statementType",
        "creationTime",
    ]
)


def test_read_gbq_table_dry_run(read_gbq, writable_table):
    """A dry run over a plain table reference returns a statistics Series."""
    result = read_gbq(writable_table, dry_run=True)

    assert isinstance(result, pandas.Series)
    pandas.testing.assert_index_equal(result.index, _DRY_RUN_STATS_INDEX)


def test_read_gbq_query_dry_run(read_gbq, writable_table):
    """A dry run over a SQL query string returns a statistics Series."""
    query = f"SELECT * FROM {writable_table} LIMIT 10"
    result = read_gbq(query, dry_run=True)

    assert isinstance(result, pandas.Series)
    pandas.testing.assert_index_equal(result.index, _DRY_RUN_STATS_INDEX)
"referencedTables": [ + { + "projectId": "projectId", + "datasetId": "datasetId", + "tableId": "tableId", + } + ], + "schema": { + "fields": [ + {"name": "title", "type": "STRING", "mode": "NULLABLE"}, + ] + }, + "statementType": "SELECT", + "totalBytesProcessedAccuracy": "PRECISE", + }, + "reservation_id": "reservation_id", + "edition": "edition", + }, + "status": {"state": "DONE"}, + } + mock_bigquery_client.query.return_value = mock_query_job + + return mock_bigquery_client + + @pytest.mark.parametrize( ("type_", "expected"), [ @@ -937,3 +985,41 @@ def test_run_query_with_dml_query(mock_bigquery_client, mock_query_job): type(mock_query_job).destination = mock.PropertyMock(return_value=None) connector.run_query("UPDATE tablename SET value = '';") mock_bigquery_client.list_rows.assert_not_called() + + +@pytest.mark.parametrize( + "query_or_table", + [ + pytest.param("my-project.my_dataset.my_table", id="table_ref"), + pytest.param("SELECT * FROM my-project.my_dataset.my_table", id="query"), + ], +) +def test_read_gbq_dry_run( + query_or_table, dryrun_bigquery_client, mock_service_account_credentials +): + mock_service_account_credentials.project_id = "service_account_project_id" + result = gbq.read_gbq( + query_or_table, + credentials=mock_service_account_credentials, + project_id="param-project", + dry_run=True, + ) + + assert isinstance(result, pandas.Series) + dryrun_bigquery_client.query.assert_called_once() + pandas.testing.assert_index_equal( + result.index, + pandas.Index( + [ + "fieldCount", + "fields", + "destinationTable", + "useLegacySql", + "referencedTables", + "totalBytesProcessed", + "cacheHit", + "statementType", + "creationTime", + ] + ), + )