From f3affe7c97e2f3e6ae059c3ff11883cc8a1732db Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 9 Apr 2026 11:57:34 +0900 Subject: [PATCH 1/2] fix --- pyrefly/lib/alt/narrow.rs | 66 ++++++++++++++++++++++++++++------ pyrefly/lib/test/typed_dict.rs | 21 +++++++++++ 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/pyrefly/lib/alt/narrow.rs b/pyrefly/lib/alt/narrow.rs index 0f4930b885..5a6dc6a698 100644 --- a/pyrefly/lib/alt/narrow.rs +++ b/pyrefly/lib/alt/narrow.rs @@ -21,6 +21,7 @@ use pyrefly_types::facet::UnresolvedFacetKind; use pyrefly_types::simplify::intersect; use pyrefly_types::simplify::simplify_tuples; use pyrefly_types::type_info::JoinStyle; +use pyrefly_types::typed_dict::ExtraItems; use pyrefly_util::prelude::SliceExt; use pyrefly_util::visit::Visit; use ruff_python_ast::Arguments; @@ -678,6 +679,41 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { }) } + fn narrow_key_membership(&self, ty: &Type, key: &Name, present: bool) -> Type { + self.distribute_over_union(ty, |member| match member { + _ if self.behaves_like_any(member) => member.clone(), + Type::TypedDict(typed_dict) => { + let fields = self.typed_dict_fields(typed_dict); + match fields.get(key) { + Some(field) if present || !field.required => member.clone(), + Some(_) => self.heap.mk_never(), + None if matches!( + self.typed_dict_extra_items(typed_dict), + ExtraItems::Extra(_) + ) => + { + member.clone() + } + None if present => self.heap.mk_never(), + None => member.clone(), + } + } + _ if self.is_dict_like(member) => member.clone(), + _ => member.clone(), + }) + } + + fn has_dict_like_member(&self, ty: &Type) -> bool { + match ty { + ty if self.is_dict_like(ty) => true, + Type::Union(union) => union + .members + .iter() + .any(|member| self.has_dict_like_member(member)), + _ => false, + } + } + /// Narrow a union by keeping only members whose facet is identity-compatible with `right`. fn narrow_facet_is( &self, @@ -1536,7 +1572,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { (Some(_), None) => return type_info.clone(), (None, _) => self.force_for_narrowing(type_info.ty(), range, errors), }; - if self.is_dict_like(&base_ty) { + let narrowed_base = self.narrow_key_membership(&base_ty, key, true); + let mut narrowed = match &resolved_chain { + Some(chain) => type_info.with_narrow(chain.facets(), narrowed_base.clone()), + None => type_info.clone().with_ty(narrowed_base.clone()), + }; + if self.has_dict_like_member(&narrowed_base) { let key_facet = FacetKind::Key(key.to_string()); let facets = match resolved_chain { Some(chain) => { @@ -1546,13 +1587,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } None => Vec1::new(key_facet), }; - let chain = FacetChain::new(facets); + let chain = FacetChain::new(facets.clone()); // Apply a facet narrow w/ that key's type, so that the usual subscript inference // code path which raises a warning for NotRequired keys does not execute later - let value_ty = self.get_facet_chain_type(type_info, &chain, range); - type_info.with_narrow(chain.facets(), value_ty) + let value_ty = self.get_facet_chain_type(&narrowed, &chain, range); + narrowed = narrowed.with_narrow(&facets, value_ty); + narrowed } else { - type_info.clone() + narrowed } } NarrowOp::Atomic(subject, AtomicNarrowOp::NotHasKey(key)) => { @@ -1564,7 +1606,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { (Some(_), None) => return type_info.clone(), (None, _) => self.force_for_narrowing(type_info.ty(), range, errors), }; - if self.is_dict_like(&base_ty) { + let narrowed_base = self.narrow_key_membership(&base_ty, key, false); + let mut narrowed = match &resolved_chain { + Some(chain) => type_info.with_narrow(chain.facets(), narrowed_base), + None => type_info.clone().with_ty(narrowed_base), + }; + if self.has_dict_like_member(&base_ty) { let key_facet = FacetKind::Key(key.to_string()); let facets = match resolved_chain { Some(chain) => { @@ -1575,11 +1622,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { None => Vec1::new(key_facet), }; // Invalidate existing facet narrows - let mut type_info = type_info.clone(); - type_info.update_for_assignment(&facets, None); - type_info + narrowed.update_for_assignment(&facets, None); + narrowed } else { - type_info.clone() + narrowed } } NarrowOp::Atomic(subject, AtomicNarrowOp::HasAttr(attr)) => { diff --git a/pyrefly/lib/test/typed_dict.rs b/pyrefly/lib/test/typed_dict.rs index abb4e69a89..7bdd8184c7 100644 --- a/pyrefly/lib/test/typed_dict.rs +++ b/pyrefly/lib/test/typed_dict.rs @@ -2288,6 +2288,27 @@ def test_empty_not_in(e: Empty, k: str): "#, ); +testcase!( + test_typed_dict_union_subject_contains_narrowing, + r#" +from typing import TypedDict, assert_type + +class Foo(TypedDict): + a: int + +class Bar(TypedDict): + b: int + +def test(foo: Foo | Bar) -> None: + if "a" in foo: + assert_type(foo, Foo) + assert_type(foo["a"], int) + else: + assert_type(foo, Bar) + assert_type(foo["b"], int) +"#, +); + testcase!( test_illegal_unpacking_in_def, r#" From 4253034debab8e0105e88fe95172f63fbe8f3891 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Thu, 9 Apr 2026 15:28:33 +0900 Subject: [PATCH 2/2] a try --- pyrefly/lib/alt/narrow.rs | 33 +++++++++++++++++++++++---------- pyrefly/lib/test/typed_dict.rs | 29 +++++++++++++++++++++++++++-- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/pyrefly/lib/alt/narrow.rs b/pyrefly/lib/alt/narrow.rs index 5a6dc6a698..71b35142df 100644 --- a/pyrefly/lib/alt/narrow.rs +++ b/pyrefly/lib/alt/narrow.rs @@ -687,14 +687,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { match fields.get(key) { Some(field) if present || !field.required => member.clone(), Some(_) => self.heap.mk_never(), - None if matches!( - self.typed_dict_extra_items(typed_dict), - ExtraItems::Extra(_) - ) => - { - member.clone() - } - None if present => self.heap.mk_never(), + None if present => match self.typed_dict_extra_items(typed_dict) { + ExtraItems::Closed => self.heap.mk_never(), + ExtraItems::Default | ExtraItems::Extra(_) => member.clone(), + }, None => member.clone(), } } @@ -714,6 +710,24 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } + fn key_membership_value_type(&self, ty: &Type, key: &Name, range: TextRange) -> Type { + let slice = Ast::str_expr(key.as_str(), TextRange::empty(TextSize::from(0))); + let ignore_errors = self.error_swallower(); + self.distribute_over_union(ty, |member| match member { + Type::TypedDict(typed_dict) => { + let fields = self.typed_dict_fields(typed_dict); + if let Some(field) = fields.get(key) { + field.ty.clone() + } else { + self.typed_dict_extra_items(typed_dict) + .extra_item(self.stdlib) + .ty + } + } + _ => self.subscript_infer_for_type(member, &slice, range, &ignore_errors), + }) + } + /// Narrow a union by keeping only members whose facet is identity-compatible with `right`. fn narrow_facet_is( &self, @@ -1587,10 +1601,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } None => Vec1::new(key_facet), }; - let chain = FacetChain::new(facets.clone()); // Apply a facet narrow w/ that key's type, so that the usual subscript inference // code path which raises a warning for NotRequired keys does not execute later - let value_ty = self.get_facet_chain_type(&narrowed, &chain, range); + let value_ty = self.key_membership_value_type(&narrowed_base, key, range); narrowed = narrowed.with_narrow(&facets, value_ty); narrowed } else { diff --git a/pyrefly/lib/test/typed_dict.rs b/pyrefly/lib/test/typed_dict.rs index 7bdd8184c7..ba9b5a5add 100644 --- a/pyrefly/lib/test/typed_dict.rs +++ b/pyrefly/lib/test/typed_dict.rs @@ -2293,10 +2293,10 @@ testcase!( r#" from typing import TypedDict, assert_type -class Foo(TypedDict): +class Foo(TypedDict, closed=True): a: int -class Bar(TypedDict): +class Bar(TypedDict, closed=True): b: int def test(foo: Foo | Bar) -> None: @@ -2309,6 +2309,31 @@ def test(foo: Foo | Bar) -> None: "#, ); +testcase!( + test_typed_dict_open_contains_narrowing, + r#" +from typing import TypedDict, assert_type + +class TD(TypedDict): + x: int + +class TD2(TypedDict): + x: int + y: int + +def test(td: TD) -> None: + if "y" in td: + assert_type(td["y"], object) + +def test_union(td: TD | TD2) -> None: + if "y" in td: + assert_type(td, TD | TD2) + assert_type(td["y"], object) + else: + assert_type(td, TD) +"#, +); + testcase!( test_illegal_unpacking_in_def, r#"