Skip to content

Commit edfe6c1

Browse files
authored
[ty] Narrow type context during collection literal inference (#23844)
Collection literals are effectively a special case of generic calls, so there's no reason we shouldn't be applying the same narrowing logic here. Part of astral-sh/ty#3001.
1 parent dd16d68 commit edfe6c1

3 files changed

Lines changed: 127 additions & 39 deletions

File tree

crates/ty_python_semantic/resources/mdtest/assignment/annotations.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -445,14 +445,21 @@ reveal_type(x8) # revealed: Literal[True]
445445
x9: int | str = f2(True)
446446
reveal_type(x9) # revealed: Literal[True]
447447

448-
# TODO: Should not error. We could choose a concrete type here (pyright arbitrarily picks the
449-
# first), or keep the union (pyrefly does this). Mypy infers `list[int]` and errors.
450-
# error: [invalid-assignment]
451448
x10: list[int | str] | list[int | None] = [1, 2, 3]
452-
reveal_type(x10) # revealed: list[int | str] | list[int | None]
449+
reveal_type(x10) # revealed: list[int | str]
453450

454451
x11: Sequence[int | str] | Sequence[int | None] = [1, 2, 3]
455452
reveal_type(x11) # revealed: list[int]
453+
454+
x12: list[int] | list[int | None] | list[str | None] = ["1", "2"]
455+
reveal_type(x12) # revealed: list[str | None]
456+
457+
x13: dict[str, list[int | None]] | dict[str, list[str | None]] = {"a": ["b"]}
458+
reveal_type(x13) # revealed: dict[str, list[str | None]]
459+
460+
x14 = [{"a": [1], "b": 1}, {"a": [1]}]
461+
x14.append(reveal_type({"b": 1})) # revealed: dict[str, list[int] | int]
462+
reveal_type(x14) # revealed: list[dict[str, list[int] | int] | dict[str, list[int]]]
456463
```
457464

458465
## Annotations influence generic call argument inference

crates/ty_python_semantic/resources/mdtest/bidirectional.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,26 @@ def _(x: int):
420420
y(lst(True), [z])
421421
```
422422

423+
```py
424+
def g[T](x: T, y: list[T | None]) -> T:
425+
return x
426+
427+
def _(flag: bool):
428+
if flag:
429+
x = 1
430+
431+
# error: [possibly-unresolved-reference]
432+
x1: int | str = g(x, [1])
433+
reveal_type(x1) # revealed: int
434+
435+
if flag:
436+
y = "1"
437+
438+
# error: [possibly-unresolved-reference]
439+
x2: list[int | None] | list[str | None] = [y]
440+
reveal_type(x2) # revealed: list[str | None]
441+
```
442+
423443
```py
424444
class Bar(TypedDict):
425445
bar: int

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 96 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4976,13 +4976,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
49764976
_ => &[],
49774977
};
49784978

4979-
// We silence diagnostics until we successfully narrow to a specific type.
4980-
let was_in_multi_inference = self.context.set_multi_inference(true);
4981-
49824979
let mut try_narrow = |narrowed_ty| {
49834980
let mut speculated_bindings = bindings.clone();
49844981
let narrowed_tcx = TypeContext::new(Some(narrowed_ty));
49854982

4983+
// We silence diagnostics until we successfully narrow to a specific type.
4984+
let was_in_multi_inference = self.context.set_multi_inference(true);
4985+
49864986
// Attempt to infer the argument types using the narrowed type context.
49874987
self.infer_all_argument_types(
49884988
ast_arguments.clone(),
@@ -4993,6 +4993,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
49934993
MultiInferenceState::Ignore,
49944994
);
49954995

4996+
// Restore the multi-inference state.
4997+
self.context.set_multi_inference(was_in_multi_inference);
4998+
49964999
// Ensure the argument types match their annotated types.
49975000
if speculated_bindings
49985001
.check_types_impl(
@@ -5023,8 +5026,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
50235026
//
50245027
// If necessary, infer the argument types again with diagnostics enabled.
50255028
if !was_in_multi_inference {
5026-
self.context.set_multi_inference(was_in_multi_inference);
5027-
50285029
self.infer_all_argument_types(
50295030
ast_arguments.clone(),
50305031
argument_types,
@@ -5067,9 +5068,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
50675068
}
50685069
}
50695070

5070-
// Re-enable diagnostics, and infer against the entire union as a fallback.
5071-
self.context.set_multi_inference(was_in_multi_inference);
5072-
5071+
// Infer against the entire union as a fallback.
50735072
self.infer_all_argument_types(
50745073
ast_arguments,
50755074
argument_types,
@@ -5784,11 +5783,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
57845783
ctx: _,
57855784
} = list;
57865785

5787-
let mut elts = elts.iter().map(|elt| [Some(elt)]);
5786+
let elts = elts.iter().map(|elt| [Some(elt)]).collect_vec();
57885787
let mut infer_elt_ty =
57895788
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
57905789

5791-
self.infer_collection_literal(KnownClass::List, &mut elts, &mut infer_elt_ty, tcx)
5790+
self.infer_collection_literal(KnownClass::List, &elts, &mut infer_elt_ty, tcx)
57925791
.unwrap_or_else(|| {
57935792
KnownClass::List.to_specialized_instance(self.db(), &[Type::unknown()])
57945793
})
@@ -5801,11 +5800,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58015800
elts,
58025801
} = set;
58035802

5804-
let mut elts = elts.iter().map(|elt| [Some(elt)]);
5803+
let elts = elts.iter().map(|elt| [Some(elt)]).collect_vec();
58055804
let mut infer_elt_ty =
58065805
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
58075806

5808-
self.infer_collection_literal(KnownClass::Set, &mut elts, &mut infer_elt_ty, tcx)
5807+
self.infer_collection_literal(KnownClass::Set, &elts, &mut infer_elt_ty, tcx)
58095808
.unwrap_or_else(|| {
58105809
KnownClass::Set.to_specialized_instance(self.db(), &[Type::unknown()])
58115810
})
@@ -5888,9 +5887,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58885887
.to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()]);
58895888
}
58905889

5891-
let mut items = items
5890+
let items = items
58925891
.iter()
5893-
.map(|item| [item.key.as_ref(), Some(&item.value)]);
5892+
.map(|item| [item.key.as_ref(), Some(&item.value)])
5893+
.collect_vec();
58945894

58955895
// Avoid inferring the items multiple times if we already attempted to infer the
58965896
// dictionary literal as a `TypedDict`. This also allows us to infer using the
@@ -5902,7 +5902,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59025902
.unwrap_or_else(|| builder.infer_expression(elt, tcx))
59035903
};
59045904

5905-
self.infer_collection_literal(KnownClass::Dict, &mut items, &mut infer_elt_ty, tcx)
5905+
self.infer_collection_literal(KnownClass::Dict, &items, &mut infer_elt_ty, tcx)
59065906
.unwrap_or_else(|| {
59075907
KnownClass::Dict
59085908
.to_specialized_instance(self.db(), &[Type::unknown(), Type::unknown()])
@@ -5955,7 +5955,74 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59555955
fn infer_collection_literal<'expr, const N: usize>(
59565956
&mut self,
59575957
collection_class: KnownClass,
5958-
elts: &mut dyn Iterator<Item = [Option<&'expr ast::Expr>; N]>,
5958+
elts: &[[Option<&'expr ast::Expr>; N]],
5959+
infer_elt_expression: &mut dyn FnMut(&mut Self, ArgExpr<'db, 'expr>) -> Type<'db>,
5960+
tcx: TypeContext<'db>,
5961+
) -> Option<Type<'db>> {
5962+
let db = self.db();
5963+
5964+
// If the type context is a union, attempt to narrow to a specific element.
5965+
let narrow_targets: &[_] = match tcx.annotation {
5966+
Some(Type::Union(union)) => union.elements(db),
5967+
_ => &[],
5968+
};
5969+
5970+
let mut try_narrow = |narrowed_ty| {
5971+
let narrowed_tcx = TypeContext::new(Some(narrowed_ty));
5972+
5973+
// We silence diagnostics until we successfully narrow to a specific type.
5974+
let prev_multi_inference = self.set_multi_inference_state(MultiInferenceState::Ignore);
5975+
let was_in_multi_inference = self.context.set_multi_inference(true);
5976+
5977+
// Attempt to infer the collection literal using the narrowed type context.
5978+
let inferred_ty = self.infer_collection_literal_impl(
5979+
collection_class,
5980+
elts,
5981+
infer_elt_expression,
5982+
narrowed_tcx,
5983+
)?;
5984+
5985+
// Restore the multi-inference state.
5986+
self.context.set_multi_inference(was_in_multi_inference);
5987+
self.set_multi_inference_state(prev_multi_inference);
5988+
5989+
// Ensure the inferred return type is assignable to the (narrowed) declared type.
5990+
if !inferred_ty.is_assignable_to(db, narrowed_ty) {
5991+
return None;
5992+
}
5993+
5994+
// Successfully narrowed to an element of the union.
5995+
//
5996+
// If necessary, infer the collection literal again with diagnostics enabled.
5997+
if !was_in_multi_inference {
5998+
self.infer_collection_literal_impl(
5999+
collection_class,
6000+
elts,
6001+
infer_elt_expression,
6002+
narrowed_tcx,
6003+
);
6004+
}
6005+
6006+
Some(inferred_ty)
6007+
};
6008+
6009+
for narrowed_ty in narrow_targets
6010+
.iter()
6011+
.filter(|ty| ty.class_specialization(db).is_some())
6012+
{
6013+
if let Some(result) = try_narrow(*narrowed_ty) {
6014+
return Some(result);
6015+
}
6016+
}
6017+
6018+
self.infer_collection_literal_impl(collection_class, elts, infer_elt_expression, tcx)
6019+
}
6020+
6021+
// Infer the type of a collection literal expression.
6022+
fn infer_collection_literal_impl<'expr, const N: usize>(
6023+
&mut self,
6024+
collection_class: KnownClass,
6025+
elts: &[[Option<&'expr ast::Expr>; N]],
59596026
infer_elt_expression: &mut dyn FnMut(&mut Self, ArgExpr<'db, 'expr>) -> Type<'db>,
59606027
tcx: TypeContext<'db>,
59616028
) -> Option<Type<'db>> {
@@ -5980,7 +6047,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59806047
let Some((collection_alias, generic_context, elt_tys)) = elt_tys(collection_class) else {
59816048
// Infer the element types without type context, and fallback to `Unknown` for
59826049
// custom typesheds.
5983-
for (i, elt) in elts.flatten().flatten().enumerate() {
6050+
for (i, elt) in elts.iter().flatten().flatten().enumerate() {
59846051
infer_elt_expression(self, (i, elt, TypeContext::default()));
59856052
}
59866053

@@ -6009,12 +6076,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
60096076
FxHashMap::default();
60106077

60116078
if let Some(tcx) = tcx.annotation
6012-
// If there are multiple potential type contexts, we fallback to `Unknown`.
6013-
// TODO: We could perform multi-inference here.
6014-
&& tcx
6015-
.filter_union(self.db(), |ty| ty.class_specialization(self.db()).is_some())
6016-
.class_specialization(self.db())
6017-
.is_some()
6079+
&& tcx.class_specialization(self.db()).is_some()
60186080
{
60196081
let collection_instance =
60206082
Type::instance(self.db(), ClassType::Generic(collection_alias));
@@ -6245,15 +6307,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62456307
fn infer_comprehension_specialization<const N: usize>(
62466308
&mut self,
62476309
collection_class: KnownClass,
6248-
elements: &[Option<&ast::Expr>; N],
6310+
elements: [Option<&ast::Expr>; N],
62496311
inference: &ScopeInference<'db>,
62506312
tcx: TypeContext<'db>,
62516313
) -> Option<Type<'db>> {
6252-
let mut elements = [elements].into_iter().copied();
62536314
let mut infer_element_ty =
62546315
|_builder: &mut Self, (_, elt, _)| inference.expression_type(elt);
62556316

6256-
self.infer_collection_literal(collection_class, &mut elements, &mut infer_element_ty, tcx)
6317+
self.infer_collection_literal(collection_class, &[elements], &mut infer_element_ty, tcx)
62576318
}
62586319

62596320
fn infer_list_comprehension_expression(
@@ -6280,7 +6341,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62806341
let inference = infer_scope_types(self.db(), scope, tcx);
62816342
self.extend_scope(inference);
62826343

6283-
self.infer_comprehension_specialization(KnownClass::List, &[Some(elt)], inference, tcx)
6344+
self.infer_comprehension_specialization(KnownClass::List, [Some(elt)], inference, tcx)
62846345
.unwrap_or_else(|| {
62856346
KnownClass::List.to_specialized_instance(self.db(), &[Type::unknown()])
62866347
})
@@ -6310,7 +6371,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63106371
let inference = infer_scope_types(self.db(), scope, tcx);
63116372
self.extend_scope(inference);
63126373

6313-
self.infer_comprehension_specialization(KnownClass::Set, &[Some(elt)], inference, tcx)
6374+
self.infer_comprehension_specialization(KnownClass::Set, [Some(elt)], inference, tcx)
63146375
.unwrap_or_else(|| {
63156376
KnownClass::Set.to_specialized_instance(self.db(), &[Type::unknown()])
63166377
})
@@ -6343,7 +6404,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63436404

63446405
self.infer_comprehension_specialization(
63456406
KnownClass::Dict,
6346-
&[Some(key), Some(value)],
6407+
[Some(key), Some(value)],
63476408
inference,
63486409
tcx,
63496410
)
@@ -6378,10 +6439,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63786439
} = listcomp;
63796440

63806441
// Infer the element type using the outer type context.
6381-
let mut elts = [[Some(elt.as_ref())]].into_iter();
6442+
let elts = [[Some(elt.as_ref())]];
63826443
let mut infer_elt_ty =
63836444
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
6384-
self.infer_collection_literal(KnownClass::List, &mut elts, &mut infer_elt_ty, tcx);
6445+
self.infer_collection_literal(KnownClass::List, &elts, &mut infer_elt_ty, tcx);
63856446

63866447
self.infer_comprehensions(generators);
63876448
}
@@ -6399,10 +6460,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63996460
} = setcomp;
64006461

64016462
// Infer the element type using the outer type context.
6402-
let mut elts = [[Some(elt.as_ref())]].into_iter();
6463+
let elts = [[Some(elt.as_ref())]];
64036464
let mut infer_elt_ty =
64046465
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
6405-
self.infer_collection_literal(KnownClass::Set, &mut elts, &mut infer_elt_ty, tcx);
6466+
self.infer_collection_literal(KnownClass::Set, &elts, &mut infer_elt_ty, tcx);
64066467

64076468
self.infer_comprehensions(generators);
64086469
}
@@ -6421,10 +6482,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
64216482
} = dictcomp;
64226483

64236484
// Infer the key and value types using the outer type context.
6424-
let mut elts = [[Some(key.as_ref()), Some(value.as_ref())]].into_iter();
6485+
let elts = [[Some(key.as_ref()), Some(value.as_ref())]];
64256486
let mut infer_elt_ty =
64266487
|builder: &mut Self, (_, elt, tcx)| builder.infer_expression(elt, tcx);
6427-
self.infer_collection_literal(KnownClass::Dict, &mut elts, &mut infer_elt_ty, tcx);
6488+
self.infer_collection_literal(KnownClass::Dict, &elts, &mut infer_elt_ty, tcx);
64286489

64296490
self.infer_comprehensions(generators);
64306491
}

0 commit comments

Comments
 (0)