|
8 | 8 | # little differently. |
9 | 9 | import collections |
10 | 10 |
|
| 11 | +import plotly.graph_objects as go |
| 12 | +from typing import Literal, Optional, Tuple, TypedDict, Iterable |
| 13 | + |
11 | 14 | _single_subplot_types = {"scene", "geo", "polar", "ternary", "map", "mapbox"} |
12 | 15 | _subplot_types = set.union(_single_subplot_types, {"xy", "domain"}) |
13 | 16 |
|
|
31 | 34 | "SubplotRef", ("subplot_type", "layout_keys", "trace_kwargs") |
32 | 35 | ) |
33 | 36 |
|
| 37 | +class SubplotSpec(TypedDict): |
| 38 | + type : Literal['xy', 'scene', 'polar', 'ternary', 'map', 'mapbox', 'domain'] | str |
| 39 | + secondary_y : bool |
| 40 | + colspan : int |
| 41 | + rowspan : int |
| 42 | + l : float |
| 43 | + r : float |
| 44 | + t : float |
| 45 | + b : float |
| 46 | + |
34 | 47 |
|
35 | 48 | def _get_initial_max_subplot_ids(): |
36 | 49 | max_subplot_ids = {subplot_type: 0 for subplot_type in _single_subplot_types} |
@@ -889,103 +902,175 @@ def _check_hv_spacing(dimsize, spacing, name, dimvarname, dimname): |
889 | 902 |
|
890 | 903 | return figure |
891 | 904 |
|
892 | | - |
893 | 905 | def _configure_shared_axes( |
894 | | - layout, grid_ref, specs, x_or_y, shared, row_dir, secondary_y |
895 | | -): |
896 | | - rows = len(grid_ref) |
897 | | - cols = len(grid_ref[0]) |
898 | | - |
899 | | - layout_key_ind = ["x", "y"].index(x_or_y) |
900 | | - |
901 | | - if row_dir < 0: |
902 | | - rows_iter = range(rows - 1, -1, -1) |
903 | | - else: |
904 | | - rows_iter = range(rows) |
905 | | - |
906 | | - if secondary_y: |
907 | | - cols_iter = range(cols - 1, -1, -1) |
908 | | - axis_index = 1 |
909 | | - else: |
910 | | - cols_iter = range(cols) |
911 | | - axis_index = 0 |
912 | | - |
913 | | - def update_axis_matches(first_axis_id, subplot_ref, spec, remove_label): |
914 | | - if subplot_ref is None: |
915 | | - return first_axis_id |
916 | | - |
917 | | - if x_or_y == "x": |
918 | | - span = spec["colspan"] |
919 | | - match_axis = 'xaxis' |
920 | | - else: |
921 | | - span = spec["rowspan"] |
922 | | - match_axis = 'yaxis' |
| 906 | + layout : go.Layout, |
| 907 | + grid_ref : Tuple[Tuple[SubplotRef]], |
| 908 | + specs : Tuple[Tuple[SubplotSpec]], |
| 909 | + x_or_y : Literal['x', 'y'], |
| 910 | + shared : bool | Literal['rows', 'columns', 'all'], |
| 911 | + row_direction : Literal[1, -1], |
| 912 | + secondary_y : bool |
| 913 | +) -> None: |
| 914 | + ''' |
| 915 | + Sets the axes to be shared, making them use the same axis |
| 916 | +
|
| 917 | + Parameters: |
| 918 | + ----------- |
| 919 | + layout (go.Layout) : The layout of the figure to be updating |
| 920 | + grid_ref (Tuple[Tuple[SubplotRef]]) : The grid of subplots within the figure; grid_ref[row][column] = subplot at that coordinate |
| 921 | + specs (Tuple[Tuple[SubplotSpec]]) : The specifications of each of the subplots within the figure; specs[row][column] = specs of the subplot at that coordinate |
| 922 | + x_or_y ('x' | 'y') : The axis to make shared (x-axis or y-axis) |
| 923 | + shared ('rows' | 'columns' | 'all' | bool) : Share the axis within the row, column, or across all of the subplots (True defaults to columns mode) |
| 924 | + row_direction (1 | -1) : The directional that the rows go |
| 925 | + secondary_y (bool) : Whether there are different or shared y-axis |
| 926 | + ''' |
| 927 | + |
| 928 | + row_count : int = len(grid_ref) |
| 929 | + column_count : int = len(grid_ref[0]) |
923 | 930 |
|
924 | | - if subplot_ref.subplot_type == "xy" and span == 1: |
925 | | - if first_axis_id is None: |
926 | | - first_axis_name = subplot_ref.layout_keys[layout_key_ind] |
927 | | - first_axis_id = first_axis_name.replace("axis", "") |
928 | | - else: |
929 | | - axis_name = subplot_ref.layout_keys[layout_key_ind] |
930 | | - axis_to_match = layout[axis_name] |
931 | | - subplot_ref.trace_kwargs[match_axis] = first_axis_id # Changes the reference axis in the set up to the initial axis (the axis to match) |
932 | | - axis_to_match.matches = first_axis_id |
933 | | - if remove_label: |
934 | | - axis_to_match.showticklabels = False |
935 | | - |
936 | | - return first_axis_id |
937 | | - |
938 | | - if shared == "columns" or (x_or_y == "x" and shared is True): |
939 | | - for c in cols_iter: |
940 | | - first_axis_id = None |
941 | | - ok_to_remove_label = x_or_y == "x" |
942 | | - for r in rows_iter: |
943 | | - if not grid_ref[r][c]: |
944 | | - continue |
945 | | - if axis_index >= len(grid_ref[r][c]): |
946 | | - continue |
947 | | - subplot_ref = grid_ref[r][c][axis_index] |
948 | | - spec = specs[r][c] |
949 | | - first_axis_id = update_axis_matches( |
950 | | - first_axis_id, subplot_ref, spec, ok_to_remove_label |
951 | | - ) |
| 931 | + rows : Iterable[int] = tuple(range(row_count - 1, -1, -1)) if row_direction < 0 else tuple(range(row_count)) |
| 932 | + columns : Iterable[int] = tuple(range(column_count - 1, -1, -1)) if secondary_y else tuple(range(column_count)) |
952 | 933 |
|
953 | | - elif shared == "rows" or (x_or_y == "y" and shared is True): |
954 | | - for r in rows_iter: |
955 | | - first_axis_id = None |
956 | | - ok_to_remove_label = x_or_y == "y" |
957 | | - for c in cols_iter: |
958 | | - if not grid_ref[r][c]: |
959 | | - continue |
960 | | - if axis_index >= len(grid_ref[r][c]): |
| 934 | + axis_index : int = 1 if secondary_y else 0 |
| 935 | + layout_axis_index : int = 0 if x_or_y == 'x' else 1 |
| 936 | + |
| 937 | + def find_label_and_index(row_order : int | Iterable[int], column_order : int | Iterable[int]) -> Optional[Tuple[str, Tuple[int, int]]]: |
| 938 | + ''' |
| 939 | + Searches the grid through the row, column order provided (doing row, then column); will only check things that appear in those lists |
| 940 | +
|
| 941 | + Parameters: |
| 942 | + ----------- |
| 943 | + row_order (int | Iterable[int]): If an int, will look only at the that row index, else it will look at all of the rows in the order of the iterable |
| 944 | + column_order (int | Iterable[int]): If an int, will only look at that column index, else it will look at all of the columns in the order of the iterable |
| 945 | +
|
| 946 | + Return: |
| 947 | + ------- |
| 948 | + Returns (Label : str, (Row : int, Column : int)): returning the label found, and the row and column it was found at (uses x_or_y to determine which of the axes' labels to pull) |
| 949 | + Return (None): No label was found |
| 950 | + ''' |
| 951 | + |
| 952 | + # Turn them into lists with one element, so that both row_order and column_order are iterables |
| 953 | + row_order : Iterable[int] = [row_order] if isinstance(row_order, int) else row_order |
| 954 | + column_order : Iterable[int] = [column_order] if isinstance(column_order, int) else column_order |
| 955 | + |
| 956 | + |
| 957 | + # Iterate through the rows and columns |
| 958 | + for row in row_order: |
| 959 | + for column in column_order: |
| 960 | + if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): |
| 961 | + continue |
| 962 | + |
| 963 | + subplot_reference : SubplotRef = grid_ref[row][column][axis_index] |
| 964 | + spec : SubplotSpec = specs[row][column] |
| 965 | + |
| 966 | + if subplot_reference is None: |
| 967 | + continue |
| 968 | + |
| 969 | + span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] |
| 970 | + if subplot_reference.subplot_type != 'xy' or span != 1: |
| 971 | + continue |
| 972 | + |
| 973 | + label_name : str = subplot_reference.layout_keys[layout_axis_index] |
| 974 | + label : str = label_name.replace("axis", "") |
| 975 | + return label, (row, column) |
| 976 | + return None |
| 977 | + |
| 978 | + |
| 979 | + def update_trace_axis(matched_label : str, row : int, column : int, can_remove_label : bool) -> None: |
| 980 | + ''' |
| 981 | + Updates the trace at the given row and column with the given label, and removes the label visibility if necessary |
| 982 | +
|
| 983 | + Parameters: |
| 984 | + ----------- |
| 985 | + matched_label (str) : The label to make the axis match (uses the x_or_y value to determine which of the axes to change), if there is a subplot at the given location |
| 986 | + row (int) : The row of the subplot within grid_ref to update |
| 987 | + column (int) : The column of the subplot within grid_ref to update |
| 988 | + can_remove_label (bool): Whether the label should be visible (only the bottom label should be visible) |
| 989 | + can_change_trace_kwargs (bool): If True the label itself can be changed directly to be the exact same axis (ie use the exact same axis in the trace keyword arguments), or if False, can only mark as matching (ie don't change the trace keyword args) |
| 990 | + ''' |
| 991 | + if not grid_ref[row][column] or axis_index >= len(grid_ref[row][column]): |
| 992 | + return |
| 993 | + |
| 994 | + subplot_reference : SubplotRef = grid_ref[row][column][axis_index] |
| 995 | + spec : SubplotSpec = specs[row][column] |
| 996 | + |
| 997 | + if subplot_reference is None: |
| 998 | + return |
| 999 | + |
| 1000 | + span = spec['colspan'] if x_or_y == 'x' else spec['rowspan'] |
| 1001 | + if subplot_reference.subplot_type != 'xy' or span != 1: |
| 1002 | + return |
| 1003 | + |
| 1004 | + axis_name : str = subplot_reference.layout_keys[layout_axis_index] |
| 1005 | + axis_dimension : str = 'xaxis' if x_or_y == 'x' else 'yaxis' |
| 1006 | + axis : go.XAxis = layout[axis_name] |
| 1007 | + |
| 1008 | + axis.matches = matched_label |
| 1009 | + subplot_reference.trace_kwargs[axis_dimension] = matched_label |
| 1010 | + |
| 1011 | + if can_remove_label: |
| 1012 | + axis.showticklabels = False |
| 1013 | + |
| 1014 | + def columns_mode(): |
| 1015 | + for column in columns: |
| 1016 | + # Get the label used by all the rows in the column |
| 1017 | + label_data = find_label_and_index(rows, column) |
| 1018 | + if label_data is None: |
| 1019 | + continue |
| 1020 | + column_label, (label_row, _) = label_data |
| 1021 | + # Set all of the values in the column |
| 1022 | + |
| 1023 | + can_remove_label : bool = (x_or_y == 'x') |
| 1024 | + |
| 1025 | + for row in rows: |
| 1026 | + if row == label_row: # Don't update the figure that the label we are matching comes from |
961 | 1027 | continue |
962 | | - subplot_ref = grid_ref[r][c][axis_index] |
963 | | - spec = specs[r][c] |
964 | | - first_axis_id = update_axis_matches( |
965 | | - first_axis_id, subplot_ref, spec, ok_to_remove_label |
966 | | - ) |
| 1028 | + |
| 1029 | + update_trace_axis(column_label, row, column, can_remove_label) |
967 | 1030 |
|
968 | | - elif shared == "all": |
969 | | - first_axis_id = None |
970 | | - for ri, r in enumerate(rows_iter): |
971 | | - for c in cols_iter: |
972 | | - if not grid_ref[r][c]: |
973 | | - continue |
974 | | - if axis_index >= len(grid_ref[r][c]): |
| 1031 | + |
| 1032 | + def rows_mode(): |
| 1033 | + for row in rows: |
| 1034 | + label_data = find_label_and_index(row, columns) |
| 1035 | + if label_data is None: |
| 1036 | + continue |
| 1037 | + row_label, (_, label_column) = label_data |
| 1038 | + |
| 1039 | + can_remove_label : bool = (x_or_y == 'y') |
| 1040 | + |
| 1041 | + for column in columns: |
| 1042 | + if column == label_column: # Don't update the figure that the label we are matching comes from |
| 1043 | + continue |
| 1044 | + |
| 1045 | + update_trace_axis(row_label, row, column, can_remove_label) |
| 1046 | + |
| 1047 | + def all_mode(): |
| 1048 | + label_data = find_label_and_index(rows, columns) |
| 1049 | + if label_data is None: |
| 1050 | + return |
| 1051 | + label, (label_row, label_column) = label_data |
| 1052 | + |
| 1053 | + for row_index, row in enumerate(rows): |
| 1054 | + for column in columns: |
| 1055 | + if row == label_row and column == label_column: # Don't update the figure that the label we are matching comes from |
975 | 1056 | continue |
976 | | - subplot_ref = grid_ref[r][c][axis_index] |
977 | | - spec = specs[r][c] |
978 | 1057 |
|
979 | | - if x_or_y == "y": |
980 | | - ok_to_remove_label = c < cols - 1 if secondary_y else c > 0 |
| 1058 | + if x_or_y == 'y': |
| 1059 | + can_remove_label : bool = (column < column_count - 1 if secondary_y else column > 0) |
981 | 1060 | else: |
982 | | - ok_to_remove_label = ri > 0 if row_dir > 0 else r < rows - 1 |
983 | | - |
984 | | - first_axis_id = update_axis_matches( |
985 | | - first_axis_id, subplot_ref, spec, ok_to_remove_label |
986 | | - ) |
987 | | - |
988 | | - |
| 1061 | + can_remove_label : bool = (row_index > 0 if row_direction > 0 else row < row_count - 1) |
| 1062 | + |
| 1063 | + update_trace_axis(label, row, column, can_remove_label) |
| 1064 | + |
| 1065 | + match(shared, x_or_y, shared): |
| 1066 | + case ('columns', _, _) | (_, 'x', True): # If columns mode, or shared and x |
| 1067 | + columns_mode() |
| 1068 | + case ('rows', _, _) | (_, 'y', True): # If rows mode, or shared and y |
| 1069 | + rows_mode() |
| 1070 | + case ('all', _, _): # If all mode |
| 1071 | + all_mode() |
| 1072 | + case _: # If reached the other case |
| 1073 | + return |
989 | 1074 |
|
990 | 1075 | def _init_subplot_xy(layout, secondary_y, x_domain, y_domain, max_subplot_ids=None): |
991 | 1076 | if max_subplot_ids is None: |
|
0 commit comments