@@ -6644,6 +6644,17 @@ def narrow_type_by_identity_equality(
66446644 # in the context of other type checker behaviour.
66456645 should_coerce_literals : bool
66466646
6647+ # custom_eq_indices:
6648+ # Operands at these indices define a custom `__eq__`. These can do arbitrary things, so we
6649+ # have to be more careful about what narrowing we can conclude from a successful comparison
6650+ custom_eq_indices : set [int ]
6651+
6652+ # enum_comparison_is_ambiguous:
6653+ # `if x is Fruits.APPLE` we know `x` is `Fruits.APPLE`, but `if x == Fruits.APPLE: ...`
6654+ # it could e.g. be an int or str if Fruits is an IntEnum or StrEnum.
6655+ # See ambiguous_enum_equality_keys for more details
6656+ enum_comparison_is_ambiguous : bool
6657+
66476658 if operator in {"is" , "is not" }:
66486659 is_target_for_value_narrowing = is_singleton_identity_type
66496660 should_coerce_literals = True
@@ -6665,91 +6676,103 @@ def narrow_type_by_identity_equality(
66656676 else :
66666677 raise AssertionError
66676678
6668- value_targets = []
6669- type_targets = []
6679+ partial_type_maps = []
6680+
6681+ # For each narrowable index, we see what we can narrow based on each relevant target
66706682 for i in expr_indices :
6671- expr_type = operand_types [i ]
6672- if should_coerce_literals :
6673- # TODO: doing this prevents narrowing a single-member Enum to literal
6674- # of its member, because we expand it here and then refuse to add equal
6675- # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6676- # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6677- # See testMatchEnumSingleChoice
6678- expr_type = coerce_to_literal (expr_type )
6683+ if i not in narrowable_indices :
6684+ continue
66796685 if i in custom_eq_indices :
6680- # We can't use types with custom __eq__ as targets for narrowing
6681- # E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None
6686+ # Handled later
66826687 continue
6683- if is_target_for_value_narrowing (get_proper_type (expr_type )):
6684- value_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
6685- else :
6686- type_targets .append ((i , TypeRange (expr_type , is_upper_bound = False )))
6687-
6688- partial_type_maps = []
66896688
6690- if value_targets :
6691- for i in expr_indices :
6692- if i not in narrowable_indices :
6689+ expr_type = operand_types [i ]
6690+ expanded_expr_type = try_expanding_sum_type_to_union (
6691+ coerce_to_literal (expr_type ), None
6692+ )
6693+ expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
6694+ for j in expr_indices :
6695+ if i == j :
66936696 continue
6694- if i in custom_eq_indices :
6695- # Handled later
6697+ if j in custom_eq_indices :
6698+ # We can't use types with custom __eq__ as targets for narrowing
6699+ # E.g. if (x: int | None) == (y: CustomEq | None), we cannot narrow x to None
66966700 continue
6697- expr_type = operand_types [i ]
6698- expr_type = coerce_to_literal (expr_type )
6699- expr_type = try_expanding_sum_type_to_union (expr_type , None )
6700- expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
6701- for j , target in value_targets :
6702- if i == j :
6703- continue
6704- if (
6705- # See comments in ambiguous_enum_equality_keys
6706- enum_comparison_is_ambiguous
6707- and len (expr_enum_keys | ambiguous_enum_equality_keys (target .item )) > 1
6708- ):
6709- continue
6701+ target_type = operand_types [j ]
6702+ if should_coerce_literals :
6703+ # TODO: doing this prevents narrowing a single-member Enum to literal
6704+ # of its member, because we expand it here and then refuse to add equal
6705+ # types to typemaps. As a result, `x: Foo; x == Foo.A` does not narrow
6706+ # `x` to `Literal[Foo.A]` iff `Foo` has exactly one member.
6707+ # See testMatchEnumSingleChoice
6708+ target_type = coerce_to_literal (target_type )
6709+
6710+ if (
6711+ # See comments in ambiguous_enum_equality_keys
6712+ enum_comparison_is_ambiguous
6713+ and len (expr_enum_keys | ambiguous_enum_equality_keys (target_type )) > 1
6714+ ):
6715+ continue
6716+
6717+ target = TypeRange (target_type , is_upper_bound = False )
6718+ is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
6719+
6720+ if is_value_target :
67106721 if_map , else_map = conditional_types_to_typemaps (
6711- operands [i ], * conditional_types (expr_type , [target ])
6722+ operands [i ], * conditional_types (expanded_expr_type , [target ])
67126723 )
67136724 partial_type_maps .append ((if_map , else_map ))
6714-
6715- if type_targets :
6716- for i in expr_indices :
6717- if i not in narrowable_indices :
6718- continue
6719- if i in custom_eq_indices :
6720- # Handled later
6721- continue
6722- expr_type = operand_types [i ]
6723- for j , target in type_targets :
6724- if i == j :
6725- continue
6725+ else :
67266726 if_map , else_map = conditional_types_to_typemaps (
67276727 operands [i ], * conditional_types (expr_type , [target ])
67286728 )
6729+ # For value targets, it is safe to narrow in the negative case.
6730+ # e.g. if (x: Literal[5] | None) != (y: Literal[5]), we can narrow x to None
6731+ # However, for non-value targets, we cannot do this narrowing,
6732+ # and so we ignore else_map
6733+ # e.g. if (x: str | None) != (y: str), we cannot narrow x to None
67296734 if if_map :
6730- # For type_targets, we cannot narrow in the negative case
6731- # e.g. if (x: str | None) != (y: str), we cannot narrow x to None
6732- else_map = {}
6733- partial_type_maps .append ((if_map , else_map ))
6735+ partial_type_maps .append ((if_map , {}))
67346736
6737+ # Handle narrowing for operands with custom __eq__ methods specially
6738+ # In most cases, we won't be able to do any narrowing
67356739 for i in custom_eq_indices :
67366740 if i not in narrowable_indices :
67376741 continue
67386742 union_expr_type = get_proper_type (operand_types [i ])
67396743 if not isinstance (union_expr_type , UnionType ):
6744+ # Here we won't be able to do any positive narrowing, because we can't conclude
6745+ # anything from a custom __eq__ returning True.
6746+ # But we might be able to do some negative narrowing, since we can assume
6747+ # a custom __eq__ is reflexive. This should only apply to custom __eq__ enums,
6748+ # see testNarrowingEqualityCustomEqualityEnum
67406749 expr_type = operand_types [i ]
6741- for j , target in value_targets :
6742- _if_map , else_map = conditional_types_to_typemaps (
6743- operands [i ], * conditional_types (expr_type , [target ])
6744- )
6745- if else_map :
6746- partial_type_maps .append (({}, else_map ))
6750+ for j in expr_indices :
6751+ if j in custom_eq_indices :
6752+ continue
6753+ target_type = operand_types [j ]
6754+ if should_coerce_literals :
6755+ target_type = coerce_to_literal (target_type )
6756+ target = TypeRange (target_type , is_upper_bound = False )
6757+ is_value_target = is_target_for_value_narrowing (get_proper_type (target_type ))
6758+
6759+ if is_value_target :
6760+ if_map , else_map = conditional_types_to_typemaps (
6761+ operands [i ], * conditional_types (expr_type , [target ])
6762+ )
6763+ if else_map :
6764+ partial_type_maps .append (({}, else_map ))
67476765 continue
67486766
6767+ # If our operand with custom __eq__ is a union, where only some members of the union
6768+ # implement custom __eq__, then we can narrow down the other members as usual.
6769+ # This is basically the same logic as the main narrowing loop above.
67496770 or_if_maps : list [TypeMap ] = []
67506771 or_else_maps : list [TypeMap ] = []
67516772 for expr_type in union_expr_type .items :
67526773 if has_custom_eq_checks (expr_type ):
6774+ # Always include union items with custom __eq__ in the type
6775+ # we narrow to in the if_map
67536776 or_if_maps .append ({operands [i ]: expr_type })
67546777
67556778 for j in expr_indices :
@@ -6784,6 +6807,8 @@ def narrow_type_by_identity_equality(
67846807
67856808 partial_type_maps .append ((final_if_map , final_else_map ))
67866809
6810+ # Handle narrowing for comparisons that produce additional narrowing, like
6811+ # `type(x) == T` or `x.__class__ is T`
67876812 for i in expr_indices :
67886813 type_expr = operands [i ]
67896814 if (
0 commit comments