diff --git a/pyrefly/lib/alt/narrow.rs b/pyrefly/lib/alt/narrow.rs index 0f4930b885..71b35142df 100644 --- a/pyrefly/lib/alt/narrow.rs +++ b/pyrefly/lib/alt/narrow.rs @@ -21,6 +21,7 @@ use pyrefly_types::facet::UnresolvedFacetKind; use pyrefly_types::simplify::intersect; use pyrefly_types::simplify::simplify_tuples; use pyrefly_types::type_info::JoinStyle; +use pyrefly_types::typed_dict::ExtraItems; use pyrefly_util::prelude::SliceExt; use pyrefly_util::visit::Visit; use ruff_python_ast::Arguments; @@ -678,6 +679,55 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { }) } + fn narrow_key_membership(&self, ty: &Type, key: &Name, present: bool) -> Type { + self.distribute_over_union(ty, |member| match member { + _ if self.behaves_like_any(member) => member.clone(), + Type::TypedDict(typed_dict) => { + let fields = self.typed_dict_fields(typed_dict); + match fields.get(key) { + Some(field) if present || !field.required => member.clone(), + Some(_) => self.heap.mk_never(), + None if present => match self.typed_dict_extra_items(typed_dict) { + ExtraItems::Closed => self.heap.mk_never(), + ExtraItems::Default | ExtraItems::Extra(_) => member.clone(), + }, + None => member.clone(), + } + } + _ if self.is_dict_like(member) => member.clone(), + _ => member.clone(), + }) + } + + fn has_dict_like_member(&self, ty: &Type) -> bool { + match ty { + ty if self.is_dict_like(ty) => true, + Type::Union(union) => union + .members + .iter() + .any(|member| self.has_dict_like_member(member)), + _ => false, + } + } + + fn key_membership_value_type(&self, ty: &Type, key: &Name, range: TextRange) -> Type { + let slice = Ast::str_expr(key.as_str(), TextRange::empty(TextSize::from(0))); + let ignore_errors = self.error_swallower(); + self.distribute_over_union(ty, |member| match member { + Type::TypedDict(typed_dict) => { + let fields = self.typed_dict_fields(typed_dict); + if let Some(field) = fields.get(key) { + field.ty.clone() + } else { + self.typed_dict_extra_items(typed_dict) + .extra_item(self.stdlib) + .ty + } + } + _ => self.subscript_infer_for_type(member, &slice, range, &ignore_errors), + }) + } + /// Narrow a union by keeping only members whose facet is identity-compatible with `right`. fn narrow_facet_is( &self, @@ -1536,7 +1586,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { (Some(_), None) => return type_info.clone(), (None, _) => self.force_for_narrowing(type_info.ty(), range, errors), }; - if self.is_dict_like(&base_ty) { + let narrowed_base = self.narrow_key_membership(&base_ty, key, true); + let mut narrowed = match &resolved_chain { + Some(chain) => type_info.with_narrow(chain.facets(), narrowed_base.clone()), + None => type_info.clone().with_ty(narrowed_base.clone()), + }; + if self.has_dict_like_member(&narrowed_base) { let key_facet = FacetKind::Key(key.to_string()); let facets = match resolved_chain { Some(chain) => { @@ -1546,13 +1601,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } None => Vec1::new(key_facet), }; - let chain = FacetChain::new(facets); // Apply a facet narrow w/ that key's type, so that the usual subscript inference // code path which raises a warning for NotRequired keys does not execute later - let value_ty = self.get_facet_chain_type(type_info, &chain, range); - type_info.with_narrow(chain.facets(), value_ty) + let value_ty = self.key_membership_value_type(&narrowed_base, key, range); + narrowed = narrowed.with_narrow(&facets, value_ty); + narrowed } else { - type_info.clone() + narrowed } } NarrowOp::Atomic(subject, AtomicNarrowOp::NotHasKey(key)) => { @@ -1564,7 +1619,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { (Some(_), None) => return type_info.clone(), (None, _) => self.force_for_narrowing(type_info.ty(), range, errors), }; - if self.is_dict_like(&base_ty) { + let narrowed_base = self.narrow_key_membership(&base_ty, key, false); + let mut narrowed = match &resolved_chain { + Some(chain) => type_info.with_narrow(chain.facets(), narrowed_base), + None => type_info.clone().with_ty(narrowed_base), + }; + if self.has_dict_like_member(&base_ty) { let key_facet = FacetKind::Key(key.to_string()); let facets = match resolved_chain { Some(chain) => { @@ -1575,11 +1635,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { None => Vec1::new(key_facet), }; // Invalidate existing facet narrows - let mut type_info = type_info.clone(); - type_info.update_for_assignment(&facets, None); - type_info + narrowed.update_for_assignment(&facets, None); + narrowed } else { - type_info.clone() + narrowed } } NarrowOp::Atomic(subject, AtomicNarrowOp::HasAttr(attr)) => { diff --git a/pyrefly/lib/test/typed_dict.rs b/pyrefly/lib/test/typed_dict.rs index abb4e69a89..ba9b5a5add 100644 --- a/pyrefly/lib/test/typed_dict.rs +++ b/pyrefly/lib/test/typed_dict.rs @@ -2288,6 +2288,52 @@ def test_empty_not_in(e: Empty, k: str): "#, ); +testcase!( + test_typed_dict_union_subject_contains_narrowing, + r#" +from typing import TypedDict, assert_type + +class Foo(TypedDict, closed=True): + a: int + +class Bar(TypedDict, closed=True): + b: int + +def test(foo: Foo | Bar) -> None: + if "a" in foo: + assert_type(foo, Foo) + assert_type(foo["a"], int) + else: + assert_type(foo, Bar) + assert_type(foo["b"], int) +"#, +); + +testcase!( + test_typed_dict_open_contains_narrowing, + r#" +from typing import TypedDict, assert_type + +class TD(TypedDict): + x: int + +class TD2(TypedDict): + x: int + y: int + +def test(td: TD) -> None: + if "y" in td: + assert_type(td["y"], object) + +def test_union(td: TD | TD2) -> None: + if "y" in td: + assert_type(td, TD | TD2) + assert_type(td["y"], object) + else: + assert_type(td, TD) +"#, +); + testcase!( test_illegal_unpacking_in_def, r#"