Skip to content

Commit 769c80b

Browse files
committed
Split timeout in multi-request methods
If a method makes multiple requests and is given a timeout, that timeout should represent the total allowed time for all requests combined.
1 parent a5188cb commit 769c80b

5 files changed

Lines changed: 199 additions & 23 deletions

File tree

bigquery/google/cloud/bigquery/client.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
except ImportError: # Python 2.7
2323
import collections as collections_abc
2424

25+
import concurrent.futures
2526
import copy
2627
import functools
2728
import gzip
@@ -47,6 +48,7 @@
4748
import google.api_core.client_options
4849
import google.api_core.exceptions
4950
from google.api_core import page_iterator
51+
from google.auth.transport.requests import TimeoutGuard
5052
import google.cloud._helpers
5153
from google.cloud import exceptions
5254
from google.cloud.client import ClientWithProject
@@ -2557,21 +2559,27 @@ def list_partitions(self, table, retry=DEFAULT_RETRY, timeout=None):
25572559
timeout (Optional[float]):
25582560
The number of seconds to wait for the underlying HTTP transport
25592561
before using ``retry``.
2562+
If multiple requests are made under the hood, ``timeout`` is
2563+
interpreted as the approximate total time of **all** requests.
25602564
25612565
Returns:
25622566
List[str]:
25632567
A list of the partition ids present in the partitioned table
25642568
"""
2565-
# TODO: split timeout between all API calls in the method
25662569
table = _table_arg_to_table_ref(table, default_project=self.project)
2567-
meta_table = self.get_table(
2568-
TableReference(
2569-
self.dataset(table.dataset_id, project=table.project),
2570-
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
2571-
),
2572-
retry=retry,
2573-
timeout=timeout,
2574-
)
2570+
2571+
with TimeoutGuard(
2572+
timeout, timeout_error_type=concurrent.futures.TimeoutError
2573+
) as guard:
2574+
meta_table = self.get_table(
2575+
TableReference(
2576+
self.dataset(table.dataset_id, project=table.project),
2577+
"%s$__PARTITIONS_SUMMARY__" % table.table_id,
2578+
),
2579+
retry=retry,
2580+
timeout=timeout,
2581+
)
2582+
timeout = guard.remaining_timeout
25752583

25762584
subset = [col for col in meta_table.schema if col.name == "partition_id"]
25772585
return [
@@ -2638,6 +2646,8 @@ def list_rows(
26382646
timeout (Optional[float]):
26392647
The number of seconds to wait for the underlying HTTP transport
26402648
before using ``retry``.
2649+
If multiple requests are made under the hood, ``timeout`` is
2650+
interpreted as the approximate total time of **all** requests.
26412651
26422652
Returns:
26432653
google.cloud.bigquery.table.RowIterator:
@@ -2648,7 +2658,6 @@ def list_rows(
26482658
(this is distinct from the total number of rows in the
26492659
current page: ``iterator.page.num_items``).
26502660
"""
2651-
# TODO: split timeout between all internal API calls
26522661
table = _table_arg_to_table(table, default_project=self.project)
26532662

26542663
if not isinstance(table, Table):
@@ -2663,7 +2672,11 @@ def list_rows(
26632672
# No schema, but no selected_fields. Assume the developer wants all
26642673
# columns, so get the table resource for them rather than failing.
26652674
elif len(schema) == 0:
2666-
table = self.get_table(table.reference, retry=retry, timeout=timeout)
2675+
with TimeoutGuard(
2676+
timeout, timeout_error_type=concurrent.futures.TimeoutError
2677+
) as guard:
2678+
table = self.get_table(table.reference, retry=retry, timeout=timeout)
2679+
timeout = guard.remaining_timeout
26672680
schema = table.schema
26682681

26692682
params = {}

bigquery/google/cloud/bigquery/job.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from six.moves import http_client
2727

2828
import google.api_core.future.polling
29+
from google.auth.transport.requests import TimeoutGuard
2930
from google.cloud import exceptions
3031
from google.cloud.exceptions import NotFound
3132
from google.cloud.bigquery.dataset import Dataset
@@ -793,6 +794,8 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
793794
timeout (Optional[float]):
794795
The number of seconds to wait for the underlying HTTP transport
795796
before using ``retry``.
797+
If multiple requests are made under the hood, ``timeout`` is
798+
interpreted as the approximate total time of **all** requests.
796799
797800
Returns:
798801
_AsyncJob: This instance.
@@ -803,10 +806,12 @@ def result(self, retry=DEFAULT_RETRY, timeout=None):
803806
concurrent.futures.TimeoutError:
804807
if the job did not complete in the given timeout.
805808
"""
806-
# TODO: combine _begin timeout with super().result() timeout!
807-
# borrow timeout guard from google auth lib
808809
if self.state is None:
809-
self._begin(retry=retry, timeout=timeout)
810+
with TimeoutGuard(
811+
timeout, timeout_error_type=concurrent.futures.TimeoutError
812+
) as guard:
813+
self._begin(retry=retry, timeout=timeout)
814+
timeout = guard.remaining_timeout
810815
# TODO: modify PollingFuture so it can pass a retry argument to done().
811816
return super(_AsyncJob, self).result(timeout=timeout)
812817

@@ -3163,6 +3168,8 @@ def result(
31633168
timeout (Optional[float]):
31643169
The number of seconds to wait for the underlying HTTP transport
31653170
before using ``retry``.
3171+
If multiple requests are made under the hood, ``timeout`` is
3172+
interpreted as the approximate total time of **all** requests.
31663173
31673174
Returns:
31683175
google.cloud.bigquery.table.RowIterator:
@@ -3180,16 +3187,27 @@ def result(
31803187
If the job did not complete in the given timeout.
31813188
"""
31823189
try:
3183-
# TODO: combine timeout with timeouts passed to super().result()
3184-
# and _get_query_results (total timeout shared by both)
3185-
# borrow timeout guard from google auth lib
3186-
super(QueryJob, self).result(timeout=timeout)
3190+
guard = TimeoutGuard(
3191+
timeout, timeout_error_type=concurrent.futures.TimeoutError
3192+
)
3193+
with guard:
3194+
super(QueryJob, self).result(retry=retry, timeout=timeout)
3195+
timeout = guard.remaining_timeout
31873196

31883197
# Return an iterator instead of returning the job.
31893198
if not self._query_results:
3190-
self._query_results = self._client._get_query_results(
3191-
self.job_id, retry, project=self.project, location=self.location
3199+
guard = TimeoutGuard(
3200+
timeout, timeout_error_type=concurrent.futures.TimeoutError
31923201
)
3202+
with guard:
3203+
self._query_results = self._client._get_query_results(
3204+
self.job_id,
3205+
retry,
3206+
project=self.project,
3207+
location=self.location,
3208+
timeout=timeout,
3209+
)
3210+
timeout = guard.remaining_timeout
31933211
except exceptions.GoogleCloudError as exc:
31943212
exc.message += self._format_for_exception(self.query, self.job_id)
31953213
exc.query_job = self
@@ -3209,7 +3227,11 @@ def result(
32093227
dest_table = Table(dest_table_ref, schema=schema)
32103228
dest_table._properties["numRows"] = self._query_results.total_rows
32113229
rows = self._client.list_rows(
3212-
dest_table, page_size=page_size, retry=retry, max_results=max_results
3230+
dest_table,
3231+
page_size=page_size,
3232+
max_results=max_results,
3233+
retry=retry,
3234+
timeout=timeout,
32133235
)
32143236
rows._preserve_order = _contains_order_by(self.query)
32153237
return rows

bigquery/noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def default(session):
3434
run the tests.
3535
"""
3636
# Install all test dependencies, then install local packages in-place.
37-
session.install("mock", "pytest", "pytest-cov")
37+
session.install("mock", "pytest", "pytest-cov", "freezegun")
3838
for local_dep in LOCAL_DEPS:
3939
session.install("-e", local_dep)
4040

bigquery/tests/unit/test_client.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import unittest
2525
import warnings
2626

27+
import freezegun
2728
import mock
2829
import requests
2930
import six
@@ -5367,6 +5368,43 @@ def test_list_partitions_with_string_id(self):
53675368

53685369
self.assertEqual(len(partition_list), 0)
53695370

5371+
def test_list_partitions_splitting_timout_between_requests(self):
5372+
from google.cloud.bigquery.table import Table
5373+
5374+
row_count = 2
5375+
meta_info = _make_list_partitons_meta_info(
5376+
self.PROJECT, self.DS_ID, self.TABLE_ID, row_count
5377+
)
5378+
5379+
data = {
5380+
"totalRows": str(row_count),
5381+
"rows": [{"f": [{"v": "20180101"}]}, {"f": [{"v": "20180102"}]},],
5382+
}
5383+
creds = _make_credentials()
5384+
http = object()
5385+
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
5386+
client._connection = make_connection(meta_info, data)
5387+
table = Table(self.TABLE_REF)
5388+
5389+
with freezegun.freeze_time("2019-01-01 00:00:00", tick=False) as frozen_time:
5390+
5391+
def delayed_get_table(*args, **kwargs):
5392+
frozen_time.tick(delta=1.4)
5393+
return orig_get_table(*args, **kwargs)
5394+
5395+
orig_get_table = client.get_table
5396+
client.get_table = mock.Mock(side_effect=delayed_get_table)
5397+
5398+
client.list_partitions(table, timeout=5.0)
5399+
5400+
client.get_table.assert_called_once()
5401+
_, kwargs = client.get_table.call_args
5402+
self.assertEqual(kwargs.get("timeout"), 5.0)
5403+
5404+
client._connection.api_request.assert_called()
5405+
_, kwargs = client._connection.api_request.call_args
5406+
self.assertAlmostEqual(kwargs.get("timeout"), 3.6, places=5)
5407+
53705408
def test_list_rows(self):
53715409
import datetime
53725410
from google.cloud._helpers import UTC
@@ -5687,6 +5725,46 @@ def test_list_rows_with_missing_schema(self):
56875725
self.assertEqual(rows[1].age, 31, msg=repr(table))
56885726
self.assertIsNone(rows[2].age, msg=repr(table))
56895727

5728+
def test_list_rows_splitting_timout_between_requests(self):
5729+
from google.cloud.bigquery.schema import SchemaField
5730+
from google.cloud.bigquery.table import Table
5731+
5732+
response = {"totalRows": "0", "rows": []}
5733+
creds = _make_credentials()
5734+
http = object()
5735+
client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)
5736+
client._connection = make_connection(response, response)
5737+
5738+
table = Table(
5739+
self.TABLE_REF, schema=[SchemaField("field_x", "INTEGER", mode="NULLABLE")]
5740+
)
5741+
5742+
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen_time:
5743+
5744+
def delayed_get_table(*args, **kwargs):
5745+
frozen_time.tick(delta=1.4)
5746+
return table
5747+
5748+
client.get_table = mock.Mock(side_effect=delayed_get_table)
5749+
5750+
rows_iter = client.list_rows(
5751+
"{}.{}.{}".format(
5752+
self.TABLE_REF.project,
5753+
self.TABLE_REF.dataset_id,
5754+
self.TABLE_REF.table_id,
5755+
),
5756+
timeout=5.0,
5757+
)
5758+
six.next(rows_iter.pages)
5759+
5760+
client.get_table.assert_called_once()
5761+
_, kwargs = client.get_table.call_args
5762+
self.assertEqual(kwargs.get("timeout"), 5.0)
5763+
5764+
client._connection.api_request.assert_called_once()
5765+
_, kwargs = client._connection.api_request.call_args
5766+
self.assertAlmostEqual(kwargs.get("timeout"), 3.6)
5767+
56905768
def test_list_rows_error(self):
56915769
creds = _make_credentials()
56925770
http = object()

bigquery/tests/unit/test_job.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import textwrap
1919
import unittest
2020

21+
import freezegun
2122
import mock
2223
import pytest
2324
import requests
@@ -956,6 +957,24 @@ def test_result_explicit_w_state(self, result):
956957
begin.assert_not_called()
957958
result.assert_called_once_with(timeout=timeout)
958959

960+
@mock.patch("google.api_core.future.polling.PollingFuture.result")
961+
def test_result_splitting_timout_between_requests(self, result):
962+
client = _make_client(project=self.PROJECT)
963+
job = self._make_one(self.JOB_ID, client)
964+
begin = job._begin = mock.Mock()
965+
retry = mock.Mock()
966+
967+
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen_time:
968+
969+
def delayed_begin(*args, **kwargs):
970+
frozen_time.tick(delta=0.3)
971+
972+
begin.side_effect = delayed_begin
973+
job.result(retry=retry, timeout=1.0)
974+
975+
begin.assert_called_once_with(retry=retry, timeout=1.0)
976+
result.assert_called_once_with(timeout=0.7)
977+
959978
def test_cancelled_wo_error_result(self):
960979
client = _make_client(project=self.PROJECT)
961980
job = self._make_one(self.JOB_ID, client)
@@ -4551,7 +4570,8 @@ def test_result_w_timeout(self):
45514570
client = _make_client(project=self.PROJECT, connection=connection)
45524571
job = self._make_one(self.JOB_ID, self.QUERY, client)
45534572

4554-
job.result(timeout=1.0)
4573+
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False):
4574+
job.result(timeout=1.0)
45554575

45564576
self.assertEqual(len(connection.api_request.call_args_list), 3)
45574577
begin_request = connection.api_request.call_args_list[0]
@@ -4566,6 +4586,49 @@ def test_result_w_timeout(self):
45664586
self.assertEqual(query_request[1]["query_params"]["timeoutMs"], 900)
45674587
self.assertEqual(reload_request[1]["method"], "GET")
45684588

4589+
@mock.patch("google.api_core.future.polling.PollingFuture.result")
4590+
def test_result_splitting_timout_between_requests(self, polling_result):
4591+
begun_resource = self._make_resource()
4592+
query_resource = {
4593+
"jobComplete": True,
4594+
"jobReference": {"projectId": self.PROJECT, "jobId": self.JOB_ID},
4595+
"schema": {"fields": [{"name": "col1", "type": "STRING"}]},
4596+
"totalRows": "5",
4597+
}
4598+
done_resource = copy.deepcopy(begun_resource)
4599+
done_resource["status"] = {"state": "DONE"}
4600+
4601+
connection = _make_connection(begun_resource, query_resource, done_resource)
4602+
client = _make_client(project=self.PROJECT, connection=connection)
4603+
job = self._make_one(self.JOB_ID, self.QUERY, client)
4604+
4605+
client.list_rows = mock.Mock()
4606+
4607+
with freezegun.freeze_time("1970-01-01 00:00:00", tick=False) as frozen_time:
4608+
4609+
def delayed_result(*args, **kwargs):
4610+
frozen_time.tick(delta=0.8)
4611+
4612+
polling_result.side_effect = delayed_result
4613+
4614+
def delayed_get_results(*args, **kwargs):
4615+
frozen_time.tick(delta=0.5)
4616+
return orig_get_results(*args, **kwargs)
4617+
4618+
orig_get_results = client._get_query_results
4619+
client._get_query_results = mock.Mock(side_effect=delayed_get_results)
4620+
job.result(timeout=2.0)
4621+
4622+
polling_result.assert_called_once_with(timeout=2.0)
4623+
4624+
client._get_query_results.assert_called_once()
4625+
_, kwargs = client._get_query_results.call_args
4626+
self.assertAlmostEqual(kwargs.get("timeout"), 1.2)
4627+
4628+
client.list_rows.assert_called_once()
4629+
_, kwargs = client.list_rows.call_args
4630+
self.assertAlmostEqual(kwargs.get("timeout"), 0.7)
4631+
45694632
def test_result_w_page_size(self):
45704633
# Arrange
45714634
query_results_resource = {

0 commit comments

Comments
 (0)