Skip to content

Commit 4b5cb84

Browse files
Rewrite to make it pass tests
There was an issue with the ordering in finding the next axis, so I ended up rewriting the section to try to fix the error - One of the tests in the make_subplots fails, but unsure if that is unintended, as by the spec, the xaxis should be shared when using shared_axis (but it isn't in the test; each subplot has its own axis) - Also the ruff check fails, but it is because the Dictionary typed is matching the plotly
1 parent a243cee commit 4b5cb84

1 file changed

Lines changed: 174 additions & 89 deletions

File tree

plotly/_subplots.py

Lines changed: 174 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
# little differently.
99
import collections
1010

11+
import plotly.graph_objects as go
12+
from typing import Literal, Optional, Tuple, TypedDict, Iterable
13+
1114
_single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"}
1215
_subplot_types = set.union(_single_subplot_types, {"xy", "domain"})
1316

@@ -31,6 +34,16 @@
3134
"SubplotRef", ("subplot_type", "layout_keys", "trace_kwargs")
3235
)
3336

37+
class SubplotSpec(TypedDict):
38+
type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str
39+
secondary_y : bool
40+
colspan : int
41+
rowspan : int
42+
l : float
43+
r : float
44+
t : float
45+
b : float
46+
3447

3548
def _get_initial_max_subplot_ids():
3649
max_subplot_ids = {subplot_type: 0 for subplot_type in _single_subplot_types}
@@ -889,103 +902,175 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname):
889902

890903
return figure
891904

892-
893905
def _configure_shared_axes(
894-
layout, grid_ref, specs, x_or_y, shared, row_dir, secondary_y
895-
):
896-
rows = len(grid_ref)
897-
cols = len(grid_ref[0])
898-
899-
layout_key_ind = ["x", "y"].index(x_or_y)
900-
901-
if row_dir < 0:
902-
rows_iter = range(rows - 1, -1, -1)
903-
else:
904-
rows_iter = range(rows)
905-
906-
if secondary_y:
907-
cols_iter = range(cols - 1, -1, -1)
908-
axis_index = 1
909-
else:
910-
cols_iter = range(cols)
911-
axis_index = 0
912-
913-
def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label):
914-
if subplot_ref is None:
915-
return first_axis_id
916-
917-
if x_or_y == "x":
918-
span = spec["colspan"]
919-
match_axis = 'xaxis'
920-
else:
921-
span = spec["rowspan"]
922-
match_axis = 'yaxis'
906+
layout : go.Layout,
907+
grid_ref : Tuple[Tuple[SubplotRef]],
908+
specs : Tuple[Tuple[SubplotSpec]],
909+
x_or_y : Literal['x', 'y'],
910+
shared : bool | Literal['rows', 'columns', 'all'],
911+
row_direction : Literal[1, -1],
912+
secondary_y : bool
913+
) -> None:
914+
'''
915+
Sets the axes to be shared, making them use the same axis
916+
917+
Parameters:
918+
-----------
919+
layout (go.Layout) : The layout of the figure to be updating
920+
grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate
921+
specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate
922+
x_or_y ('x' | 'y') : The axis to make shared (x-axis or y-axis)
923+
shared ('rows' | 'columns' | 'all' | bool) : Share the axis within the row, column, or across all of the subplots (True defaults to columns mode)
924+
row_direction (1 | -1) : The directional that the rows go
925+
secondary_y (bool) : Whether there are different or shared y-axis
926+
'''
927+
928+
row_count : int = len(grid_ref)
929+
column_count : int = len(grid_ref[0])
923930

924-
if subplot_ref.subplot_type == "xy" and span == 1:
925-
if first_axis_id is None:
926-
first_axis_name = subplot_ref.layout_keys[layout_key_ind]
927-
first_axis_id = first_axis_name.replace("axis", "")
928-
else:
929-
axis_name = subplot_ref.layout_keys[layout_key_ind]
930-
axis_to_match = layout[axis_name]
931-
subplot_ref.trace_kwargs[match_axis] = first_axis_id # Changes the reference axis in the set up to the initial axis (the axis to match)
932-
axis_to_match.matches = first_axis_id
933-
if remove_label:
934-
axis_to_match.showticklabels = False
935-
936-
return first_axis_id
937-
938-
if shared == "columns" or (x_or_y == "x" and shared is True):
939-
for c in cols_iter:
940-
first_axis_id = None
941-
ok_to_remove_label = x_or_y == "x"
942-
for r in rows_iter:
943-
if not grid_ref[r][c]:
944-
continue
945-
if axis_index >= len(grid_ref[r][c]):
946-
continue
947-
subplot_ref = grid_ref[r][c][axis_index]
948-
spec = specs[r][c]
949-
first_axis_id = update_axis_matches(
950-
first_axis_id, subplot_ref, spec, ok_to_remove_label
951-
)
931+
rows : Iterable[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count))
932+
columns : Iterable[int] = tuple(range(column_count - 1, -1, -1)) if secondary_y else tuple(range(column_count))
952933

953-
elif shared == "rows" or (x_or_y == "y" and shared is True):
954-
for r in rows_iter:
955-
first_axis_id = None
956-
ok_to_remove_label = x_or_y == "y"
957-
for c in cols_iter:
958-
if not grid_ref[r][c]:
959-
continue
960-
if axis_index >= len(grid_ref[r][c]):
934+
axis_index : int = 1 if secondary_y else 0
935+
layout_axis_index : int = 0 if x_or_y == 'x' else 1
936+
937+
def find_label_and_index(row_order : int | Iterable[int], column_order : int | Iterable[int]) -> Optional[Tuple[str, Tuple[int, int]]]:
938+
'''
939+
Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists
940+
941+
Parameters:
942+
-----------
943+
row_order (int | Iterable[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable
944+
column_order (int | Iterable[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable
945+
946+
Return:
947+
-------
948+
Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull)
949+
Return (None): No label was found
950+
'''
951+
952+
# Turn them into lists with one element, so that both row_order and column_order are iterables
953+
row_order : Iterable[int] = [row_order] if isinstance(row_order, int) else row_order
954+
column_order : Iterable[int] = [column_order] if isinstance(column_order, int) else column_order
955+
956+
957+
# Iterate through the rows and columns
958+
for row in row_order:
959+
for column in column_order:
960+
if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]):
961+
continue
962+
963+
subplot_reference : SubplotRef = grid_ref[row][column][axis_index]
964+
spec : SubplotSpec = specs[row][column]
965+
966+
if subplot_reference is None:
967+
continue
968+
969+
span = spec['colspan'] if x_or_y == 'x' else spec['rowspan']
970+
if subplot_reference.subplot_type != 'xy' or span != 1:
971+
continue
972+
973+
label_name : str = subplot_reference.layout_keys[layout_axis_index]
974+
label : str = label_name.replace("axis", "")
975+
return label, (row, column)
976+
return None
977+
978+
979+
def update_trace_axis(matched_label : str, row : int, column : int, can_remove_label : bool) -> None:
980+
'''
981+
Updates the trace at the given row and column with the given label, and removes the label visibility if necessary
982+
983+
Parameters:
984+
-----------
985+
matched_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location
986+
row (int) : The row of the subplot within grid_ref to update
987+
column (int) : The column of the subplot within grid_ref to update
988+
can_remove_label (bool): Whether the label should be visible (only the bottom label should be visible)
989+
can_change_trace_kwargs (bool): If True the label itself can be changed directly to be the exact same axis (ie use the exact same axis in the trace keyword arguments), or if False, can only mark as matching (ie don't change the trace keyword args)
990+
'''
991+
if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]):
992+
return
993+
994+
subplot_reference : SubplotRef = grid_ref[row][column][axis_index]
995+
spec : SubplotSpec = specs[row][column]
996+
997+
if subplot_reference is None:
998+
return
999+
1000+
span = spec['colspan'] if x_or_y == 'x' else spec['rowspan']
1001+
if subplot_reference.subplot_type != 'xy' or span != 1:
1002+
return
1003+
1004+
axis_name : str = subplot_reference.layout_keys[layout_axis_index]
1005+
axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis'
1006+
axis : go.XAxis = layout[axis_name]
1007+
1008+
axis.matches = matched_label
1009+
subplot_reference.trace_kwargs[axis_dimension] = matched_label
1010+
1011+
if can_remove_label:
1012+
axis.showticklabels = False
1013+
1014+
def columns_mode():
1015+
for column in columns:
1016+
# Get the label used by all the rows in the column
1017+
label_data = find_label_and_index(rows, column)
1018+
if label_data is None:
1019+
continue
1020+
column_label, (label_row, _) = label_data
1021+
# Set all of the values in the column
1022+
1023+
can_remove_label : bool = (x_or_y == 'x')
1024+
1025+
for row in rows:
1026+
if row == label_row: # Don't update the figure that the label we are matching comes from
9611027
continue
962-
subplot_ref = grid_ref[r][c][axis_index]
963-
spec = specs[r][c]
964-
first_axis_id = update_axis_matches(
965-
first_axis_id, subplot_ref, spec, ok_to_remove_label
966-
)
1028+
1029+
update_trace_axis(column_label, row, column, can_remove_label)
9671030

968-
elif shared == "all":
969-
first_axis_id = None
970-
for ri, r in enumerate(rows_iter):
971-
for c in cols_iter:
972-
if not grid_ref[r][c]:
973-
continue
974-
if axis_index >= len(grid_ref[r][c]):
1031+
1032+
def rows_mode():
1033+
for row in rows:
1034+
label_data = find_label_and_index(row, columns)
1035+
if label_data is None:
1036+
continue
1037+
row_label, (_, label_column) = label_data
1038+
1039+
can_remove_label : bool = (x_or_y == 'y')
1040+
1041+
for column in columns:
1042+
if column == label_column: # Don't update the figure that the label we are matching comes from
1043+
continue
1044+
1045+
update_trace_axis(row_label, row, column, can_remove_label)
1046+
1047+
def all_mode():
1048+
label_data = find_label_and_index(rows, columns)
1049+
if label_data is None:
1050+
return
1051+
label, (label_row, label_column) = label_data
1052+
1053+
for row_index, row in enumerate(rows):
1054+
for column in columns:
1055+
if row == label_row and column == label_column: # Don't update the figure that the label we are matching comes from
9751056
continue
976-
subplot_ref = grid_ref[r][c][axis_index]
977-
spec = specs[r][c]
9781057

979-
if x_or_y == "y":
980-
ok_to_remove_label = c < cols - 1 if secondary_y else c > 0
1058+
if x_or_y == 'y':
1059+
can_remove_label : bool = (column < column_count - 1 if secondary_y else column > 0)
9811060
else:
982-
ok_to_remove_label = ri > 0 if row_dir > 0 else r < rows - 1
983-
984-
first_axis_id = update_axis_matches(
985-
first_axis_id, subplot_ref, spec, ok_to_remove_label
986-
)
987-
988-
1061+
can_remove_label : bool = (row_index > 0 if row_direction > 0 else row < row_count - 1)
1062+
1063+
update_trace_axis(label, row, column, can_remove_label)
1064+
1065+
match(shared, x_or_y, shared):
1066+
case ('columns', _, _) | (_, 'x', True): # If columns mode, or shared and x
1067+
columns_mode()
1068+
case ('rows', _, _) | (_, 'y', True): # If rows mode, or shared and y
1069+
rows_mode()
1070+
case ('all', _, _): # If all mode
1071+
all_mode()
1072+
case _: # If reached the other case
1073+
return
9891074

9901075
def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=None):
9911076
if max_subplot_ids is None:

0 commit comments

Comments
 (0)