Skip to content

Commit a31da2b

Browse files
Jammy2211claude
authored andcommitted
perf: defer matplotlib.pyplot via wrapper functions in plot/utils
Replace all module-level `import matplotlib.pyplot as plt` across plot files with `subplots()` and `get_cmap()` wrappers from plot/utils that import matplotlib lazily on first call. This prevents matplotlib from loading during `import autoarray`, deferring ~0.3s of import cost to first plot use. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d0124a9 commit a31da2b

File tree

11 files changed

+63
-46
lines changed

11 files changed

+63
-46
lines changed

autoarray/dataset/plot/imaging_plots.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from typing import Optional
22

3-
import matplotlib.pyplot as plt
43

5-
from autoarray.plot.utils import subplot_save, conf_subplot_figsize, tight_layout
4+
from autoarray.plot.utils import subplots, subplot_save, conf_subplot_figsize, tight_layout
65

76

87
def subplot_imaging_dataset(
@@ -51,7 +50,7 @@ def subplot_imaging_dataset(
5150

5251
from autoarray.plot.array import plot_array
5352

54-
fig, axes = plt.subplots(3, 3, figsize=conf_subplot_figsize(3, 3))
53+
fig, axes = subplots(3, 3, figsize=conf_subplot_figsize(3, 3))
5554
axes = axes.flatten()
5655

5756
plot_array(
@@ -172,7 +171,7 @@ def subplot_imaging_dataset_list(
172171
from autoarray.plot.array import plot_array
173172

174173
n = len(dataset_list)
175-
fig, axes = plt.subplots(n, 3, figsize=conf_subplot_figsize(n, 3))
174+
fig, axes = subplots(n, 3, figsize=conf_subplot_figsize(n, 3))
176175
if n == 1:
177176
axes = [axes]
178177
for i, dataset in enumerate(dataset_list):

autoarray/dataset/plot/interferometer_plots.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import numpy as np
22
from typing import Optional
33

4-
import matplotlib.pyplot as plt
54

65
from autoarray.plot.array import plot_array
76
from autoarray.plot.grid import plot_grid
87
from autoarray.plot.yx import plot_yx
9-
from autoarray.plot.utils import subplot_save, hide_unused_axes, conf_subplot_figsize, tight_layout
8+
from autoarray.plot.utils import subplots, subplot_save, hide_unused_axes, conf_subplot_figsize, tight_layout
109
from autoarray.structures.grids.irregular_2d import Grid2DIrregular
1110

1211

@@ -39,7 +38,7 @@ def subplot_interferometer_dataset(
3938
use_log10
4039
Apply log10 normalisation to image panels.
4140
"""
42-
fig, axes = plt.subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
41+
fig, axes = subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
4342
axes = axes.flatten()
4443

4544
plot_grid(dataset.data.in_grid, ax=axes[0], title="Visibilities", xlabel="", ylabel="")
@@ -117,7 +116,7 @@ def subplot_interferometer_dirty_images(
117116
use_log10
118117
Apply log10 normalisation.
119118
"""
120-
fig, axes = plt.subplots(1, 3, figsize=conf_subplot_figsize(1, 3))
119+
fig, axes = subplots(1, 3, figsize=conf_subplot_figsize(1, 3))
121120

122121
plot_array(
123122
dataset.dirty_image,

autoarray/fit/plot/fit_imaging_plots.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from typing import Optional
22

3-
import matplotlib.pyplot as plt
43

54
from autoarray.plot.array import plot_array
6-
from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax, hide_unused_axes, conf_subplot_figsize, tight_layout
5+
from autoarray.plot.utils import subplots, subplot_save, symmetric_vmin_vmax, hide_unused_axes, conf_subplot_figsize, tight_layout
76

87

98
def subplot_fit_imaging(
@@ -43,7 +42,7 @@ def subplot_fit_imaging(
4342
grid, positions, lines
4443
Optional overlays forwarded to every panel.
4544
"""
46-
fig, axes = plt.subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
45+
fig, axes = subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
4746
axes = axes.flatten()
4847

4948
plot_array(

autoarray/fit/plot/fit_interferometer_plots.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import numpy as np
22
from typing import Optional
33

4-
import matplotlib.pyplot as plt
54

65
from autoarray.plot.array import plot_array
76
from autoarray.plot.yx import plot_yx
8-
from autoarray.plot.utils import subplot_save, symmetric_vmin_vmax, hide_unused_axes, conf_subplot_figsize, tight_layout
7+
from autoarray.plot.utils import subplots, subplot_save, symmetric_vmin_vmax, hide_unused_axes, conf_subplot_figsize, tight_layout
98

109

1110
def subplot_fit_interferometer(
@@ -40,7 +39,7 @@ def subplot_fit_interferometer(
4039
Not used here (UV-plane residuals are scatter plots); kept for API
4140
consistency.
4241
"""
43-
fig, axes = plt.subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
42+
fig, axes = subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
4443
axes = axes.flatten()
4544

4645
uv = fit.dataset.uv_distances / 10**3.0
@@ -135,7 +134,7 @@ def subplot_fit_interferometer_dirty_images(
135134
residuals_symmetric_cmap
136135
Centre residual colour scale symmetrically around zero.
137136
"""
138-
fig, axes = plt.subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
137+
fig, axes = subplots(2, 3, figsize=conf_subplot_figsize(2, 3))
139138
axes = axes.flatten()
140139

141140
plot_array(

autoarray/inversion/plot/inversion_plots.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
from pathlib import Path
55
from typing import Optional, Union
66

7-
import matplotlib.pyplot as plt
87
from autoconf import conf
98

109
from autoarray.inversion.mappers.abstract import Mapper
1110
from autoarray.plot.array import plot_array
12-
from autoarray.plot.utils import numpy_grid, numpy_lines, numpy_positions, subplot_save, hide_unused_axes, conf_subplot_figsize, tight_layout
11+
from autoarray.plot.utils import subplots, numpy_grid, numpy_lines, numpy_positions, subplot_save, hide_unused_axes, conf_subplot_figsize, tight_layout
1312
from autoarray.inversion.plot.mapper_plots import plot_mapper
1413
from autoarray.structures.arrays.uniform_2d import Array2D
1514

@@ -53,7 +52,7 @@ def subplot_of_mapper(
5352
"""
5453
mapper = inversion.cls_list_from(cls=Mapper)[mapper_index]
5554

56-
fig, axes = plt.subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
55+
fig, axes = subplots(3, 4, figsize=conf_subplot_figsize(3, 4))
5756
axes = axes.flatten()
5857

5958
# panel 0: data subtracted
@@ -279,7 +278,7 @@ def subplot_mappings(
279278
)
280279
mapper.slim_indexes_for_pix_indexes(pix_indexes=pix_indexes)
281280

282-
fig, axes = plt.subplots(2, 2, figsize=conf_subplot_figsize(2, 2))
281+
fig, axes = subplots(2, 2, figsize=conf_subplot_figsize(2, 2))
283282
axes = axes.flatten()
284283

285284
# panel 0: data subtracted

autoarray/inversion/plot/mapper_plots.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import logging
22
from typing import Optional
33

4-
import matplotlib.pyplot as plt
54

65
from autoarray.plot.array import plot_array
76
from autoarray.plot.inversion import plot_inversion_reconstruction
8-
from autoarray.plot.utils import numpy_grid, numpy_lines, subplot_save, conf_subplot_figsize, tight_layout
7+
from autoarray.plot.utils import subplots, numpy_grid, numpy_lines, subplot_save, conf_subplot_figsize, tight_layout
98

109
logger = logging.getLogger(__name__)
1110

@@ -114,7 +113,7 @@ def subplot_image_and_mapper(
114113
lines
115114
Lines to overlay on both panels.
116115
"""
117-
fig, axes = plt.subplots(1, 2, figsize=conf_subplot_figsize(1, 2))
116+
fig, axes = subplots(1, 2, figsize=conf_subplot_figsize(1, 2))
118117

119118
plot_array(
120119
image,

autoarray/plot/array.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55
import os
66
from typing import List, Optional, Tuple
77

8-
import matplotlib.pyplot as plt
98
import numpy as np
10-
from matplotlib.colors import LogNorm, Normalize
11-
129
from autoarray.plot.utils import (
10+
subplots,
1311
apply_extent,
1412
apply_labels,
1513
conf_figsize,
@@ -29,7 +27,7 @@
2927

3028
def plot_array(
3129
array,
32-
ax: Optional[plt.Axes] = None,
30+
ax=None,
3331
# --- spatial metadata -------------------------------------------------------
3432
extent: Optional[Tuple[float, float, float, float]] = None,
3533
# --- overlays ---------------------------------------------------------------
@@ -158,7 +156,7 @@ def plot_array(
158156
owns_figure = ax is None
159157
if owns_figure:
160158
figsize = figsize or conf_figsize("figures")
161-
fig, ax = plt.subplots(1, 1, figsize=figsize)
159+
fig, ax = subplots(1, 1, figsize=figsize)
162160
else:
163161
fig = ax.get_figure()
164162

@@ -181,8 +179,10 @@ def plot_array(
181179
vmax_log = np.nanmax(clipped)
182180
if not np.isfinite(vmax_log) or vmax_log <= vmin_log:
183181
vmax_log = vmin_log * 10.0
182+
from matplotlib.colors import LogNorm
184183
norm = LogNorm(vmin=vmin_log, vmax=vmax_log)
185184
elif vmin is not None or vmax is not None:
185+
from matplotlib.colors import Normalize
186186
norm = Normalize(vmin=vmin, vmax=vmax)
187187
else:
188188
norm = None

autoarray/plot/grid.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
from typing import Iterable, List, Optional, Tuple
88

9-
import matplotlib.pyplot as plt
109
import numpy as np
1110

1211
from autoarray.plot.utils import (
12+
subplots,
13+
get_cmap,
1314
apply_extent,
1415
apply_labels,
1516
conf_figsize,
@@ -20,7 +21,7 @@
2021

2122
def plot_grid(
2223
grid,
23-
ax: Optional[plt.Axes] = None,
24+
ax=None,
2425
# --- errors -----------------------------------------------------------------
2526
y_errors: Optional[np.ndarray] = None,
2627
x_errors: Optional[np.ndarray] = None,
@@ -108,13 +109,13 @@ def plot_grid(
108109
owns_figure = ax is None
109110
if owns_figure:
110111
figsize = figsize or conf_figsize("figures")
111-
fig, ax = plt.subplots(1, 1, figsize=figsize)
112+
fig, ax = subplots(1, 1, figsize=figsize)
112113
else:
113114
fig = ax.get_figure()
114115

115116
# --- scatter / errorbar ----------------------------------------------------
116117
if color_array is not None:
117-
cmap = plt.get_cmap(colormap)
118+
cmap = get_cmap(colormap)
118119
colors = cmap((color_array - color_array.min()) / (np.ptp(color_array) or 1))
119120

120121
if y_errors is None and x_errors is None:

autoarray/plot/inversion.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,16 @@
66

77
from typing import List, Optional, Tuple
88

9-
import matplotlib.pyplot as plt
109
import numpy as np
1110
from matplotlib.colors import LogNorm, Normalize
1211

13-
from autoarray.plot.utils import apply_extent, apply_labels, conf_figsize, save_figure, _conf_imshow_origin
12+
from autoarray.plot.utils import subplots, apply_extent, apply_labels, conf_figsize, save_figure, _conf_imshow_origin
1413

1514

1615
def plot_inversion_reconstruction(
1716
pixel_values: np.ndarray,
1817
mapper,
19-
ax: Optional[plt.Axes] = None,
18+
ax=None,
2019
# --- cosmetics --------------------------------------------------------------
2120
title: str = "Reconstruction",
2221
xlabel: str = 'x (")',
@@ -84,7 +83,7 @@ def plot_inversion_reconstruction(
8483
owns_figure = ax is None
8584
if owns_figure:
8685
figsize = figsize or conf_figsize("figures")
87-
fig, ax = plt.subplots(1, 1, figsize=figsize)
86+
fig, ax = subplots(1, 1, figsize=figsize)
8887
else:
8988
fig = ax.get_figure()
9089

0 commit comments

Comments
 (0)