diff --git a/bigframes/session/bq_caching_executor.py b/bigframes/session/bq_caching_executor.py
index cf275154ce..c5d6fe3e5f 100644
--- a/bigframes/session/bq_caching_executor.py
+++ b/bigframes/session/bq_caching_executor.py
@@ -14,6 +14,7 @@
 from __future__ import annotations
 
+import concurrent.futures
 import math
 import threading
 from typing import Literal, Mapping, Optional, Sequence, Tuple
 
@@ -514,13 +515,33 @@ def _substitute_large_local_sources(self, original_root: nodes.BigFrameNode):
         Replace large local sources with the uploaded version of those datasources.
         """
         # Step 1: Upload all previously un-uploaded data
+        needs_upload = []
         for leaf in original_root.unique_nodes():
             if isinstance(leaf, nodes.ReadLocalNode):
                 if (
                     leaf.local_data_source.metadata.total_bytes
                     > bigframes.constants.MAX_INLINE_BYTES
+                    # Skip data already uploaded by a previous execution
+                    # (this dedup check previously lived in _upload_local_data).
+                    and self.cache.get_uploaded_local_data(leaf.local_data_source)
+                    is None
                 ):
-                    self._upload_local_data(leaf.local_data_source)
+                    needs_upload.append(leaf.local_data_source)
+
+        futures: dict[concurrent.futures.Future, local_data.ManagedArrowTable] = dict()
+        for local_source in needs_upload:
+            future = self.loader.read_data_async(
+                local_source, bigframes.core.guid.generate_guid()
+            )
+            futures[future] = local_source
+        try:
+            for future in concurrent.futures.as_completed(futures.keys()):
+                self.cache.cache_remote_replacement(futures[future], future.result())
+        except Exception as e:
+            # cancel all futures
+            for future in futures:
+                future.cancel()
+            raise e
 
         # Step 2: Replace local scans with remote scans
         def map_local_scans(node: nodes.BigFrameNode):
@@ -550,18 +571,6 @@ def map_local_scans(node: nodes.BigFrameNode):
 
         return original_root.bottom_up(map_local_scans)
 
-    def _upload_local_data(self, local_table: local_data.ManagedArrowTable):
-        if self.cache.get_uploaded_local_data(local_table) is not None:
-            return
-        # Lock prevents concurrent repeated work, but slows things down.
-        # Might be better as a queue and a worker thread
-        with self._upload_lock:
-            if self.cache.get_uploaded_local_data(local_table) is None:
-                uploaded = self.loader.load_data_or_write_data(
-                    local_table, bigframes.core.guid.generate_guid()
-                )
-                self.cache.cache_remote_replacement(local_table, uploaded)
-
     def _execute_plan_gbq(
         self,
         plan: nodes.BigFrameNode,
diff --git a/bigframes/session/loader.py b/bigframes/session/loader.py
index 0944c0dab6..7b5d1bcaf1 100644
--- a/bigframes/session/loader.py
+++ b/bigframes/session/loader.py
@@ -300,6 +300,17 @@ def __init__(
         self._session = session
         self._clock = session_time.BigQuerySyncedClock(bqclient)
         self._clock.sync()
+        self._threadpool = concurrent.futures.ThreadPoolExecutor(
+            max_workers=1, thread_name_prefix="bigframes-loader"
+        )
+
+    def read_data_async(
+        self, local_data: local_data.ManagedArrowTable, offsets_col: str
+    ) -> concurrent.futures.Future[bq_data.BigqueryDataSource]:
+        future = self._threadpool.submit(
+            self._load_data_or_write_data, local_data, offsets_col
+        )
+        return future
 
     def read_pandas(
         self,
@@ -350,7 +361,7 @@ def read_managed_data(
             session=self._session,
         )
 
-    def load_data_or_write_data(
+    def _load_data_or_write_data(
         self,
         data: local_data.ManagedArrowTable,
         offsets_col: str,