Skip to content

Commit faf44a0

Browse files
d-v-b and claude
authored and committed
feat(pipeline): codec_pipeline.max_workers config controls sync pool size
Replace n_workers (sentinel '0 = no parallelism') with max_workers (integer, default 1 = sequential). Passed explicitly into read_sync/ write_sync; the pipeline never reads config itself. Add codec_pipeline.max_workers config key (default None = auto = os.cpu_count()), resolved at the async-wrapper boundary by _resolve_max_workers(). The wrapper passes the resolved value into read_sync/write_sync. _get_pool(max_workers) sizes the pool on demand and grows it (replaces the existing pool) if a larger size is requested. Shrinking requests reuse the larger pool. This replaces the unwired threading.max_workers + async.concurrency mishmash with one explicit knob. async.concurrency continues to control BatchedCodecPipeline IO concurrency. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 45da373 commit faf44a0

3 files changed

Lines changed: 62 additions & 30 deletions

File tree

src/zarr/core/codec_pipeline.py

Lines changed: 60 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import os
43
import threading
54
from concurrent.futures import ThreadPoolExecutor
65
from dataclasses import dataclass, field
@@ -37,17 +36,45 @@
3736

3837

3938
# Module-level thread pool shared by all synchronous codec-pipeline calls.
# ``_pool_size`` records the worker count the current pool was created with;
# ``_pool_lock`` serializes creation and grow-and-replace of the pool.
_pool: ThreadPoolExecutor | None = None
_pool_size: int = 0
_pool_lock = threading.Lock()


def _resolve_max_workers() -> int:
    """Resolve ``codec_pipeline.max_workers`` config to an effective worker count.

    ``None`` means "auto" → ``os.cpu_count()`` (or 1 if unavailable).
    Values < 1 are clamped to 1 (sequential).
    """
    # Local import: the module no longer imports ``os`` at top level.
    import os as _os

    cfg = config.get("codec_pipeline.max_workers", default=None)
    if cfg is None:
        return _os.cpu_count() or 1
    return max(1, int(cfg))


def _get_pool(max_workers: int) -> ThreadPoolExecutor:
    """Get or create the module-level thread pool, sized to ``max_workers``.

    The pool grows on demand — if a request arrives for more workers than
    the current pool has, the existing pool is shut down and replaced.
    Shrinking requests reuse the existing larger pool (it just leaves
    workers idle).

    Callers that want sequential execution should not call this — they
    should run the task list inline.

    Parameters
    ----------
    max_workers : int
        Desired minimum pool size; must be >= 1.

    Returns
    -------
    ThreadPoolExecutor
        A pool with at least ``max_workers`` workers.

    Raises
    ------
    ValueError
        If ``max_workers`` is less than 1.
    """
    global _pool, _pool_size
    if max_workers < 1:
        raise ValueError(f"max_workers must be >= 1, got {max_workers}")
    # Check, replace, and return entirely under the lock. A lock-free
    # fast path that reads ``_pool`` and returns it after the check can
    # hand back a pool that a concurrent grow request has just shut
    # down; an uncontended lock acquire is cheap relative to codec work.
    with _pool_lock:
        if _pool is None or _pool_size < max_workers:
            if _pool is not None:
                # NOTE(review): a caller that already obtained the old
                # pool but has not yet submitted work can still race this
                # shutdown — accepted by the grow-and-replace design.
                _pool.shutdown(wait=False)
            _pool = ThreadPoolExecutor(max_workers=max_workers)
            _pool_size = max_workers
        return _pool
5279

5380

@@ -897,15 +924,19 @@ def read_sync(
897924
batch_info: Iterable[tuple[ByteGetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
898925
out: NDBuffer,
899926
drop_axes: tuple[int, ...] = (),
900-
n_workers: int = 0,
927+
max_workers: int = 1,
901928
) -> tuple[GetResult, ...]:
902929
"""Synchronous read: fetch -> decode -> scatter, per chunk.
903930
904-
When ``n_workers > 0`` and there are multiple chunks, each
931+
When ``max_workers > 1`` and there are multiple chunks, each
905932
chunk's full lifecycle (fetch + decode + scatter) runs as one
906-
task on the module-level thread pool — overlapping IO of one
907-
chunk with decode/scatter of another. Scatter is thread-safe
908-
because the chunks have non-overlapping output selections.
933+
task on a thread pool sized to ``max_workers`` — overlapping IO
934+
of one chunk with decode/scatter of another. Scatter is
935+
thread-safe because the chunks have non-overlapping output
936+
selections.
937+
938+
``max_workers=1`` runs everything sequentially in the calling
939+
thread (no pool involvement).
909940
910941
Mirrors ``BatchedCodecPipeline.read_batch``: when the AB codec
911942
supports partial decoding (e.g. sharding), the codec handles its
@@ -943,8 +974,8 @@ def _read_one_partial(
943974
out[out_selection] = decoded
944975
return GetResult(status="present")
945976

946-
if n_workers > 0 and len(batch) > 1:
947-
pool = _get_pool()
977+
if max_workers > 1 and len(batch) > 1:
978+
pool = _get_pool(max_workers)
948979
return tuple(pool.map(_read_one_partial, batch))
949980
return tuple(_read_one_partial(item) for item in batch)
950981

@@ -964,8 +995,8 @@ def _read_one(
964995
out[out_selection] = selected
965996
return GetResult(status="present")
966997

967-
if n_workers > 0 and len(batch) > 1:
968-
pool = _get_pool()
998+
if max_workers > 1 and len(batch) > 1:
999+
pool = _get_pool(max_workers)
9691000
return tuple(pool.map(_read_one, batch))
9701001
return tuple(_read_one(item) for item in batch)
9711002

@@ -974,14 +1005,17 @@ def write_sync(
9741005
batch_info: Iterable[tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple, bool]],
9751006
value: NDBuffer,
9761007
drop_axes: tuple[int, ...] = (),
977-
n_workers: int = 0,
1008+
max_workers: int = 1,
9781009
) -> None:
9791010
"""Synchronous write: fetch existing -> merge+encode -> store.
9801011
981-
When ``n_workers > 0`` and there are multiple chunks, each
1012+
When ``max_workers > 1`` and there are multiple chunks, each
9821013
chunk's full lifecycle (get-existing + merge + encode + set/delete)
983-
runs as one task on the module-level thread pool — overlapping
984-
IO of one chunk with compute of another.
1014+
runs as one task on a thread pool sized to ``max_workers`` —
1015+
overlapping IO of one chunk with compute of another.
1016+
1017+
``max_workers=1`` runs everything sequentially in the calling
1018+
thread (no pool involvement).
9851019
9861020
When the codec pipeline supports partial encoding (e.g. a
9871021
sharding codec with no outer AA/BB codecs), the AB codec handles
@@ -1010,8 +1044,8 @@ def _encode_one_partial(
10101044
chunk_value = value if scalar else value[out_selection]
10111045
codec._encode_partial_sync(bs, chunk_value, chunk_selection, chunk_spec)
10121046

1013-
if n_workers > 0 and len(batch) > 1:
1014-
pool = _get_pool()
1047+
if max_workers > 1 and len(batch) > 1:
1048+
pool = _get_pool(max_workers)
10151049
# consume the iterator to surface exceptions
10161050
list(pool.map(_encode_one_partial, batch))
10171051
else:
@@ -1054,8 +1088,8 @@ def _write_one(
10541088
else:
10551089
bs.set_sync(encoded)
10561090

1057-
if n_workers > 0 and len(batch) > 1:
1058-
pool = _get_pool()
1091+
if max_workers > 1 and len(batch) > 1:
1092+
pool = _get_pool(max_workers)
10591093
list(pool.map(_write_one, batch))
10601094
else:
10611095
for item in batch:
@@ -1083,9 +1117,7 @@ async def read(
10831117
and isinstance(first_bg, StorePath)
10841118
and isinstance(first_bg.store, SupportsGetSync)
10851119
):
1086-
return self.read_sync(
1087-
batch, out, drop_axes, n_workers=int(config.get("async.concurrency") or 0)
1088-
)
1120+
return self.read_sync(batch, out, drop_axes, max_workers=_resolve_max_workers())
10891121

10901122
# Async fallback: fetch all chunks, decode via async codec API, scatter
10911123
chunk_bytes_batch = await concurrent_map(
@@ -1134,9 +1166,7 @@ async def write(
11341166
and isinstance(first_bs, StorePath)
11351167
and isinstance(first_bs.store, SupportsSetSync)
11361168
):
1137-
self.write_sync(
1138-
batch, value, drop_axes, n_workers=int(config.get("async.concurrency") or 0)
1139-
)
1169+
self.write_sync(batch, value, drop_axes, max_workers=_resolve_max_workers())
11401170
return
11411171

11421172
# Async fallback: same pattern as BatchedCodecPipeline.write_batch

src/zarr/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def enable_gpu(self) -> ConfigSet:
106106
"codec_pipeline": {
107107
"path": "zarr.core.codec_pipeline.BatchedCodecPipeline",
108108
"batch_size": 1,
109+
"max_workers": None,
109110
},
110111
"codecs": {
111112
"blosc": "zarr.codecs.blosc.BloscCodec",

tests/test_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def test_config_defaults_set() -> None:
6363
"codec_pipeline": {
6464
"path": "zarr.core.codec_pipeline.BatchedCodecPipeline",
6565
"batch_size": 1,
66+
"max_workers": None,
6667
},
6768
"codecs": {
6869
"blosc": "zarr.codecs.blosc.BloscCodec",

0 commit comments

Comments
 (0)