 from __future__ import annotations

-import os
 import threading
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass, field


 _pool: ThreadPoolExecutor | None = None
+_pool_size: int = 0
 _pool_lock = threading.Lock()


-def _get_pool() -> ThreadPoolExecutor:
-    """Get or create the module-level thread pool for codec compute."""
-    global _pool
-    if _pool is None:
+def _resolve_max_workers() -> int:
+    """Resolve ``codec_pipeline.max_workers`` config to an effective worker count.
+
+    ``None`` means "auto" → ``os.cpu_count()`` (or 1 if unavailable).
+    Values < 1 are clamped to 1 (sequential).
+    """
+    import os as _os
+
+    cfg = config.get("codec_pipeline.max_workers", default=None)
+    if cfg is None:
+        return _os.cpu_count() or 1
+    return max(1, int(cfg))
+
+
+def _get_pool(max_workers: int) -> ThreadPoolExecutor:
+    """Get or create the module-level thread pool, sized to ``max_workers``.
+
+    The pool grows on demand — if a request arrives for more workers than
+    the current pool has, the existing pool is shut down and replaced.
+    Shrinking requests reuse the existing larger pool (it just leaves
+    workers idle).
+
+    Callers that want sequential execution should not call this — they
+    should run the task list inline. ``max_workers`` must be >= 1.
+    """
+    global _pool, _pool_size
+    if max_workers < 1:
+        raise ValueError(f"max_workers must be >= 1, got {max_workers}")
+    if _pool is None or _pool_size < max_workers:
         with _pool_lock:
-            if _pool is None:
-                max_workers = os.cpu_count() or 4
+            if _pool is None or _pool_size < max_workers:
+                if _pool is not None:
+                    _pool.shutdown(wait=False)
                 _pool = ThreadPoolExecutor(max_workers=max_workers)
+                _pool_size = max_workers
     return _pool

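A quick sketch of how the config resolution above is intended to behave (illustrative only, not part of the diff): a plain dict stands in for zarr's donfig-backed `config` object so the snippet runs standalone.

```python
# Standalone sketch of the _resolve_max_workers rules: None means "auto"
# (one worker per CPU), anything below 1 is clamped to sequential.
import os

settings = {"codec_pipeline.max_workers": None}  # stand-in for the real config

def resolve_max_workers() -> int:
    cfg = settings["codec_pipeline.max_workers"]
    if cfg is None:
        return os.cpu_count() or 1   # auto
    return max(1, int(cfg))          # clamp < 1 to 1 (sequential)

print(resolve_max_workers())         # e.g. 8 on an 8-core machine

settings["codec_pipeline.max_workers"] = 0
print(resolve_max_workers())         # 1 -> run chunks inline, no pool
```

Because `_get_pool` only ever grows the executor, later calls with the same or a smaller `max_workers` reuse it rather than churning threads.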
@@ -897,15 +924,19 @@ def read_sync(
         batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         out: NDBuffer,
         drop_axes: tuple[int, ...] = (),
-        n_workers: int = 0,
+        max_workers: int = 1,
     ) -> tuple[GetResult, ...]:
         """Synchronous read: fetch -> decode -> scatter, per chunk.

-        When ``n_workers > 0`` and there are multiple chunks, each
+        When ``max_workers > 1`` and there are multiple chunks, each
         chunk's full lifecycle (fetch + decode + scatter) runs as one
-        task on the module-level thread pool — overlapping IO of one
-        chunk with decode/scatter of another. Scatter is thread-safe
-        because the chunks have non-overlapping output selections.
+        task on a thread pool sized to ``max_workers`` — overlapping IO
+        of one chunk with decode/scatter of another. Scatter is
+        thread-safe because the chunks have non-overlapping output
+        selections.
+
+        ``max_workers=1`` runs everything sequentially in the calling
+        thread (no pool involvement).

         Mirrors ``BatchedCodecPipeline.read_batch``: when the AB codec
         supports partial decoding (e.g. sharding), the codec handles its
@@ -943,8 +974,8 @@ def _read_one_partial(
             out[out_selection] = decoded
             return GetResult(status="present")

-        if n_workers > 0 and len(batch) > 1:
-            pool = _get_pool()
+        if max_workers > 1 and len(batch) > 1:
+            pool = _get_pool(max_workers)
             return tuple(pool.map(_read_one_partial, batch))
         return tuple(_read_one_partial(item) for item in batch)

@@ -964,8 +995,8 @@ def _read_one(
             out[out_selection] = selected
             return GetResult(status="present")

-        if n_workers > 0 and len(batch) > 1:
-            pool = _get_pool()
+        if max_workers > 1 and len(batch) > 1:
+            pool = _get_pool(max_workers)
             return tuple(pool.map(_read_one, batch))
         return tuple(_read_one(item) for item in batch)

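For readers skimming the dispatch logic above: each pool task owns one chunk end to end, and the concurrent scatters are safe because the chunks target disjoint slices of `out`. A toy illustration of that model (not zarr code; the byte "decoding" here is made up):

```python
# Two "chunks" scatter into non-overlapping slices of a shared output array,
# so running them on a thread pool needs no locking around the writes.
from concurrent.futures import ThreadPoolExecutor
import numpy as np

out = np.zeros(8, dtype=np.int64)
batch = [(slice(0, 4), bytes([1, 2, 3, 4])), (slice(4, 8), bytes([5, 6, 7, 8]))]

def read_one(item):
    out_selection, raw = item
    decoded = np.frombuffer(raw, dtype=np.uint8).astype(np.int64)  # "fetch + decode"
    out[out_selection] = decoded                                   # "scatter"

with ThreadPoolExecutor(max_workers=2) as pool:
    list(pool.map(read_one, batch))

print(out)  # [1 2 3 4 5 6 7 8]
```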
@@ -974,14 +1005,17 @@ def write_sync(
         batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
         value: NDBuffer,
         drop_axes: tuple[int, ...] = (),
-        n_workers: int = 0,
+        max_workers: int = 1,
     ) -> None:
         """Synchronous write: fetch existing -> merge+encode -> store.

-        When ``n_workers > 0`` and there are multiple chunks, each
+        When ``max_workers > 1`` and there are multiple chunks, each
         chunk's full lifecycle (get-existing + merge + encode + set/delete)
-        runs as one task on the module-level thread pool — overlapping
-        IO of one chunk with compute of another.
+        runs as one task on a thread pool sized to ``max_workers`` —
+        overlapping IO of one chunk with compute of another.
+
+        ``max_workers=1`` runs everything sequentially in the calling
+        thread (no pool involvement).

         When the codec pipeline supports partial encoding (e.g. a
         sharding codec with no outer AA/BB codecs), the AB codec handles
@@ -1010,8 +1044,8 @@ def _encode_one_partial(
             chunk_value = value if scalar else value[out_selection]
             codec._encode_partial_sync(bs, chunk_value, chunk_selection, chunk_spec)

-        if n_workers > 0 and len(batch) > 1:
-            pool = _get_pool()
+        if max_workers > 1 and len(batch) > 1:
+            pool = _get_pool(max_workers)
             # consume the iterator to surface exceptions
             list(pool.map(_encode_one_partial, batch))
         else:
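The `list(pool.map(...))` pattern above exists because `Executor.map` is lazy about results: a worker's exception only propagates to the caller when the corresponding result is consumed. A standalone demonstration (not zarr code):

```python
from concurrent.futures import ThreadPoolExecutor

def encode_one(_):
    raise RuntimeError("encode failed")

with ThreadPoolExecutor(max_workers=2) as pool:
    results = pool.map(encode_one, range(3))  # tasks run, but nothing is raised yet
    try:
        list(results)                          # consuming the iterator surfaces the error
    except RuntimeError as exc:
        print("caught:", exc)                  # caught: encode failed
```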
@@ -1054,8 +1088,8 @@ def _write_one(
             else:
                 bs.set_sync(encoded)

-        if n_workers > 0 and len(batch) > 1:
-            pool = _get_pool()
+        if max_workers > 1 and len(batch) > 1:
+            pool = _get_pool(max_workers)
             list(pool.map(_write_one, batch))
         else:
             for item in batch:
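And a toy version of the per-chunk write lifecycle that each `_write_one` task owns: get the existing chunk, merge the new selection into it, re-encode, and store it back. A sketch with a dict standing in for the store; the names are illustrative, not zarr's.

```python
# Read-modify-write for one chunk: decode existing bytes, merge, re-encode, set.
import numpy as np

store = {"c/0": np.arange(4, dtype=np.int64).tobytes()}  # existing encoded chunk

def write_one(key, chunk_selection, new_values):
    existing = np.frombuffer(store[key], dtype=np.int64).copy()  # "get + decode"
    existing[chunk_selection] = new_values                       # "merge"
    store[key] = existing.tobytes()                              # "encode + set"

write_one("c/0", slice(1, 3), np.array([10, 11], dtype=np.int64))
print(np.frombuffer(store["c/0"], dtype=np.int64))  # [ 0 10 11  3]
```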
@@ -1083,9 +1117,7 @@ async def read(
             and isinstance(first_bg, StorePath)
             and isinstance(first_bg.store, SupportsGetSync)
         ):
-            return self.read_sync(
-                batch, out, drop_axes, n_workers=int(config.get("async.concurrency") or 0)
-            )
+            return self.read_sync(batch, out, drop_axes, max_workers=_resolve_max_workers())

         # Async fallback: fetch all chunks, decode via async codec API, scatter
         chunk_bytes_batch = await concurrent_map(
@@ -1134,9 +1166,7 @@ async def write(
             and isinstance(first_bs, StorePath)
             and isinstance(first_bs.store, SupportsSetSync)
         ):
-            self.write_sync(
-                batch, value, drop_axes, n_workers=int(config.get("async.concurrency") or 0)
-            )
+            self.write_sync(batch, value, drop_axes, max_workers=_resolve_max_workers())
             return

         # Async fallback: same pattern as BatchedCodecPipeline.write_batch
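End to end, the knob that drives all of this is the new `codec_pipeline.max_workers` config key. Assuming it is exposed through zarr's usual donfig-based `zarr.config` like the other `codec_pipeline.*` settings, usage would look roughly like:

```python
import zarr

# Hypothetical usage of the new setting; the exact entry point depends on
# how the config key is registered, but donfig accepts plain dicts like this.
zarr.config.set({"codec_pipeline.max_workers": 4})     # cap the chunk pool at 4 threads
zarr.config.set({"codec_pipeline.max_workers": 0})     # clamped to 1 -> sequential
zarr.config.set({"codec_pipeline.max_workers": None})  # auto: os.cpu_count()
```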