@@ -4976,13 +4976,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
49764976 _ => & [ ] ,
49774977 } ;
49784978
4979- // We silence diagnostics until we successfully narrow to a specific type.
4980- let was_in_multi_inference = self . context . set_multi_inference ( true ) ;
4981-
49824979 let mut try_narrow = |narrowed_ty| {
49834980 let mut speculated_bindings = bindings. clone ( ) ;
49844981 let narrowed_tcx = TypeContext :: new ( Some ( narrowed_ty) ) ;
49854982
4983+ // We silence diagnostics until we successfully narrow to a specific type.
4984+ let was_in_multi_inference = self . context . set_multi_inference ( true ) ;
4985+
49864986 // Attempt to infer the argument types using the narrowed type context.
49874987 self . infer_all_argument_types (
49884988 ast_arguments. clone ( ) ,
@@ -4993,6 +4993,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
49934993 MultiInferenceState :: Ignore ,
49944994 ) ;
49954995
4996+ // Restore the multi-inference state.
4997+ self . context . set_multi_inference ( was_in_multi_inference) ;
4998+
49964999 // Ensure the argument types match their annotated types.
49975000 if speculated_bindings
49985001 . check_types_impl (
@@ -5023,8 +5026,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
50235026 //
50245027 // If necessary, infer the argument types again with diagnostics enabled.
50255028 if !was_in_multi_inference {
5026- self . context . set_multi_inference ( was_in_multi_inference) ;
5027-
50285029 self . infer_all_argument_types (
50295030 ast_arguments. clone ( ) ,
50305031 argument_types,
@@ -5067,9 +5068,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
50675068 }
50685069 }
50695070
5070- // Re-enable diagnostics, and infer against the entire union as a fallback.
5071- self . context . set_multi_inference ( was_in_multi_inference) ;
5072-
5071+ // Infer against the entire union as a fallback.
50735072 self . infer_all_argument_types (
50745073 ast_arguments,
50755074 argument_types,
@@ -5784,11 +5783,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
57845783 ctx : _,
57855784 } = list;
57865785
5787- let mut elts = elts. iter ( ) . map ( |elt| [ Some ( elt) ] ) ;
5786+ let elts = elts. iter ( ) . map ( |elt| [ Some ( elt) ] ) . collect_vec ( ) ;
57885787 let mut infer_elt_ty =
57895788 |builder : & mut Self , ( _, elt, tcx) | builder. infer_expression ( elt, tcx) ;
57905789
5791- self . infer_collection_literal ( KnownClass :: List , & mut elts, & mut infer_elt_ty, tcx)
5790+ self . infer_collection_literal ( KnownClass :: List , & elts, & mut infer_elt_ty, tcx)
57925791 . unwrap_or_else ( || {
57935792 KnownClass :: List . to_specialized_instance ( self . db ( ) , & [ Type :: unknown ( ) ] )
57945793 } )
@@ -5801,11 +5800,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58015800 elts,
58025801 } = set;
58035802
5804- let mut elts = elts. iter ( ) . map ( |elt| [ Some ( elt) ] ) ;
5803+ let elts = elts. iter ( ) . map ( |elt| [ Some ( elt) ] ) . collect_vec ( ) ;
58055804 let mut infer_elt_ty =
58065805 |builder : & mut Self , ( _, elt, tcx) | builder. infer_expression ( elt, tcx) ;
58075806
5808- self . infer_collection_literal ( KnownClass :: Set , & mut elts, & mut infer_elt_ty, tcx)
5807+ self . infer_collection_literal ( KnownClass :: Set , & elts, & mut infer_elt_ty, tcx)
58095808 . unwrap_or_else ( || {
58105809 KnownClass :: Set . to_specialized_instance ( self . db ( ) , & [ Type :: unknown ( ) ] )
58115810 } )
@@ -5888,9 +5887,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58885887 . to_specialized_instance ( self . db ( ) , & [ Type :: unknown ( ) , Type :: unknown ( ) ] ) ;
58895888 }
58905889
5891- let mut items = items
5890+ let items = items
58925891 . iter ( )
5893- . map ( |item| [ item. key . as_ref ( ) , Some ( & item. value ) ] ) ;
5892+ . map ( |item| [ item. key . as_ref ( ) , Some ( & item. value ) ] )
5893+ . collect_vec ( ) ;
58945894
58955895 // Avoid inferring the items multiple times if we already attempted to infer the
58965896 // dictionary literal as a `TypedDict`. This also allows us to infer using the
@@ -5902,7 +5902,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59025902 . unwrap_or_else ( || builder. infer_expression ( elt, tcx) )
59035903 } ;
59045904
5905- self . infer_collection_literal ( KnownClass :: Dict , & mut items, & mut infer_elt_ty, tcx)
5905+ self . infer_collection_literal ( KnownClass :: Dict , & items, & mut infer_elt_ty, tcx)
59065906 . unwrap_or_else ( || {
59075907 KnownClass :: Dict
59085908 . to_specialized_instance ( self . db ( ) , & [ Type :: unknown ( ) , Type :: unknown ( ) ] )
@@ -5955,7 +5955,74 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59555955 fn infer_collection_literal < ' expr , const N : usize > (
59565956 & mut self ,
59575957 collection_class : KnownClass ,
5958- elts : & mut dyn Iterator < Item = [ Option < & ' expr ast:: Expr > ; N ] > ,
5958+ elts : & [ [ Option < & ' expr ast:: Expr > ; N ] ] ,
5959+ infer_elt_expression : & mut dyn FnMut ( & mut Self , ArgExpr < ' db , ' expr > ) -> Type < ' db > ,
5960+ tcx : TypeContext < ' db > ,
5961+ ) -> Option < Type < ' db > > {
5962+ let db = self . db ( ) ;
5963+
5964+ // If the type context is a union, attempt to narrow to a specific element.
5965+ let narrow_targets: & [ _ ] = match tcx. annotation {
5966+ Some ( Type :: Union ( union) ) => union. elements ( db) ,
5967+ _ => & [ ] ,
5968+ } ;
5969+
5970+ let mut try_narrow = |narrowed_ty| {
5971+ let narrowed_tcx = TypeContext :: new ( Some ( narrowed_ty) ) ;
5972+
5973+ // We silence diagnostics until we successfully narrow to a specific type.
5974+ let prev_multi_inference = self . set_multi_inference_state ( MultiInferenceState :: Ignore ) ;
5975+ let was_in_multi_inference = self . context . set_multi_inference ( true ) ;
5976+
5977+ // Attempt to infer the collection literal using the narrowed type context.
5978+ let inferred_ty = self . infer_collection_literal_impl (
5979+ collection_class,
5980+ elts,
5981+ infer_elt_expression,
5982+ narrowed_tcx,
5983+ ) ?;
5984+
5985+ // Restore the multi-inference state.
5986+ self . context . set_multi_inference ( was_in_multi_inference) ;
5987+ self . set_multi_inference_state ( prev_multi_inference) ;
5988+
5989+ // Ensure the inferred return type is assignable to the (narrowed) declared type.
5990+ if !inferred_ty. is_assignable_to ( db, narrowed_ty) {
5991+ return None ;
5992+ }
5993+
5994+ // Successfully narrowed to an element of the union.
5995+ //
5996+ // If necessary, infer the collection literal again with diagnostics enabled.
5997+ if !was_in_multi_inference {
5998+ self . infer_collection_literal_impl (
5999+ collection_class,
6000+ elts,
6001+ infer_elt_expression,
6002+ narrowed_tcx,
6003+ ) ;
6004+ }
6005+
6006+ Some ( inferred_ty)
6007+ } ;
6008+
6009+ for narrowed_ty in narrow_targets
6010+ . iter ( )
6011+ . filter ( |ty| ty. class_specialization ( db) . is_some ( ) )
6012+ {
6013+ if let Some ( result) = try_narrow ( * narrowed_ty) {
6014+ return Some ( result) ;
6015+ }
6016+ }
6017+
6018+ self . infer_collection_literal_impl ( collection_class, elts, infer_elt_expression, tcx)
6019+ }
6020+
6021+ // Infer the type of a collection literal expression.
6022+ fn infer_collection_literal_impl < ' expr , const N : usize > (
6023+ & mut self ,
6024+ collection_class : KnownClass ,
6025+ elts : & [ [ Option < & ' expr ast:: Expr > ; N ] ] ,
59596026 infer_elt_expression : & mut dyn FnMut ( & mut Self , ArgExpr < ' db , ' expr > ) -> Type < ' db > ,
59606027 tcx : TypeContext < ' db > ,
59616028 ) -> Option < Type < ' db > > {
@@ -5980,7 +6047,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
59806047 let Some ( ( collection_alias, generic_context, elt_tys) ) = elt_tys ( collection_class) else {
59816048 // Infer the element types without type context, and fallback to `Unknown` for
59826049 // custom typesheds.
5983- for ( i, elt) in elts. flatten ( ) . flatten ( ) . enumerate ( ) {
6050+ for ( i, elt) in elts. iter ( ) . flatten ( ) . flatten ( ) . enumerate ( ) {
59846051 infer_elt_expression ( self , ( i, elt, TypeContext :: default ( ) ) ) ;
59856052 }
59866053
@@ -6009,12 +6076,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
60096076 FxHashMap :: default ( ) ;
60106077
60116078 if let Some ( tcx) = tcx. annotation
6012- // If there are multiple potential type contexts, we fallback to `Unknown`.
6013- // TODO: We could perform multi-inference here.
6014- && tcx
6015- . filter_union ( self . db ( ) , |ty| ty. class_specialization ( self . db ( ) ) . is_some ( ) )
6016- . class_specialization ( self . db ( ) )
6017- . is_some ( )
6079+ && tcx. class_specialization ( self . db ( ) ) . is_some ( )
60186080 {
60196081 let collection_instance =
60206082 Type :: instance ( self . db ( ) , ClassType :: Generic ( collection_alias) ) ;
@@ -6245,15 +6307,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62456307 fn infer_comprehension_specialization < const N : usize > (
62466308 & mut self ,
62476309 collection_class : KnownClass ,
6248- elements : & [ Option < & ast:: Expr > ; N ] ,
6310+ elements : [ Option < & ast:: Expr > ; N ] ,
62496311 inference : & ScopeInference < ' db > ,
62506312 tcx : TypeContext < ' db > ,
62516313 ) -> Option < Type < ' db > > {
6252- let mut elements = [ elements] . into_iter ( ) . copied ( ) ;
62536314 let mut infer_element_ty =
62546315 |_builder : & mut Self , ( _, elt, _) | inference. expression_type ( elt) ;
62556316
6256- self . infer_collection_literal ( collection_class, & mut elements, & mut infer_element_ty, tcx)
6317+ self . infer_collection_literal ( collection_class, & [ elements] , & mut infer_element_ty, tcx)
62576318 }
62586319
62596320 fn infer_list_comprehension_expression (
@@ -6280,7 +6341,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
62806341 let inference = infer_scope_types ( self . db ( ) , scope, tcx) ;
62816342 self . extend_scope ( inference) ;
62826343
6283- self . infer_comprehension_specialization ( KnownClass :: List , & [ Some ( elt) ] , inference, tcx)
6344+ self . infer_comprehension_specialization ( KnownClass :: List , [ Some ( elt) ] , inference, tcx)
62846345 . unwrap_or_else ( || {
62856346 KnownClass :: List . to_specialized_instance ( self . db ( ) , & [ Type :: unknown ( ) ] )
62866347 } )
@@ -6310,7 +6371,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63106371 let inference = infer_scope_types ( self . db ( ) , scope, tcx) ;
63116372 self . extend_scope ( inference) ;
63126373
6313- self . infer_comprehension_specialization ( KnownClass :: Set , & [ Some ( elt) ] , inference, tcx)
6374+ self . infer_comprehension_specialization ( KnownClass :: Set , [ Some ( elt) ] , inference, tcx)
63146375 . unwrap_or_else ( || {
63156376 KnownClass :: Set . to_specialized_instance ( self . db ( ) , & [ Type :: unknown ( ) ] )
63166377 } )
@@ -6343,7 +6404,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63436404
63446405 self . infer_comprehension_specialization (
63456406 KnownClass :: Dict ,
6346- & [ Some ( key) , Some ( value) ] ,
6407+ [ Some ( key) , Some ( value) ] ,
63476408 inference,
63486409 tcx,
63496410 )
@@ -6378,10 +6439,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63786439 } = listcomp;
63796440
63806441 // Infer the element type using the outer type context.
6381- let mut elts = [ [ Some ( elt. as_ref ( ) ) ] ] . into_iter ( ) ;
6442+ let elts = [ [ Some ( elt. as_ref ( ) ) ] ] ;
63826443 let mut infer_elt_ty =
63836444 |builder : & mut Self , ( _, elt, tcx) | builder. infer_expression ( elt, tcx) ;
6384- self . infer_collection_literal ( KnownClass :: List , & mut elts, & mut infer_elt_ty, tcx) ;
6445+ self . infer_collection_literal ( KnownClass :: List , & elts, & mut infer_elt_ty, tcx) ;
63856446
63866447 self . infer_comprehensions ( generators) ;
63876448 }
@@ -6399,10 +6460,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
63996460 } = setcomp;
64006461
64016462 // Infer the element type using the outer type context.
6402- let mut elts = [ [ Some ( elt. as_ref ( ) ) ] ] . into_iter ( ) ;
6463+ let elts = [ [ Some ( elt. as_ref ( ) ) ] ] ;
64036464 let mut infer_elt_ty =
64046465 |builder : & mut Self , ( _, elt, tcx) | builder. infer_expression ( elt, tcx) ;
6405- self . infer_collection_literal ( KnownClass :: Set , & mut elts, & mut infer_elt_ty, tcx) ;
6466+ self . infer_collection_literal ( KnownClass :: Set , & elts, & mut infer_elt_ty, tcx) ;
64066467
64076468 self . infer_comprehensions ( generators) ;
64086469 }
@@ -6421,10 +6482,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
64216482 } = dictcomp;
64226483
64236484 // Infer the key and value types using the outer type context.
6424- let mut elts = [ [ Some ( key. as_ref ( ) ) , Some ( value. as_ref ( ) ) ] ] . into_iter ( ) ;
6485+ let elts = [ [ Some ( key. as_ref ( ) ) , Some ( value. as_ref ( ) ) ] ] ;
64256486 let mut infer_elt_ty =
64266487 |builder : & mut Self , ( _, elt, tcx) | builder. infer_expression ( elt, tcx) ;
6427- self . infer_collection_literal ( KnownClass :: Dict , & mut elts, & mut infer_elt_ty, tcx) ;
6488+ self . infer_collection_literal ( KnownClass :: Dict , & elts, & mut infer_elt_ty, tcx) ;
64286489
64296490 self . infer_comprehensions ( generators) ;
64306491 }
0 commit comments