@@ -14,7 +14,6 @@ private import FunctionType
1414private import FunctionOverloading as FunctionOverloading
1515private import BlanketImplementation as BlanketImplementation
1616private import codeql.rust.elements.internal.VariableImpl:: Impl as VariableImpl
17- private import codeql.rust.internal.CachedStages
1817private import codeql.typeinference.internal.TypeInference
1918private import codeql.rust.frameworks.stdlib.Stdlib
2019private import codeql.rust.frameworks.stdlib.Builtins as Builtins
@@ -394,13 +393,7 @@ private module Input3 implements InputSig3 {
394393 }
395394 }
396395
397- class Call extends Expr instanceof FunctionCallMatchingInput:: Access {
398- Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) {
399- result = super .getTypeArgument ( apos , path )
400- }
401-
402- AstNode getNodeAt ( TypePosition pos ) { result = super .getNodeAt ( pos ) }
403-
396+ class Call extends FunctionCallMatchingInput:: Access {
404397 /** Gets the target of this call. */
405398 Callable getTargetCertain ( ) {
406399 exists ( ImplOrTraitItemNodeOption i , FunctionDeclaration f , Path p |
@@ -421,7 +414,7 @@ private module Input3 implements InputSig3 {
421414
422415 Type inferCallReturnType ( AstNode n , TypePath path ) {
423416 exists ( Call call , TypePath path0 |
424- result = inferCallReturnType ( call , _, n , path0 ) and
417+ result = M3 :: inferCallReturnType ( call , _, n , path0 ) and
425418 if
426419 // index expression `x[i]` desugars to `*x.index(i)`, so we must account for
427420 // the implicit deref
@@ -598,11 +591,15 @@ private module Input3 implements InputSig3 {
598591 predicate inferLubStep ( AstNode n1 , TypePath path1 , AstNode n2 , TypePath path2 ) {
599592 path1 .isEmpty ( ) and
600593 (
601- n2 = any ( ArrayListExpr ale | n1 = ale .getAnExpr ( ) ) and
594+ n2 = n1 . ( ArrayListExpr ) .getAnExpr ( ) and
602595 path2 = TypePath:: singleton ( getArrayTypeParameter ( ) )
603596 or
604- bodyReturns ( n2 , n1 ) and
605- path2 .isEmpty ( )
597+ exists ( ReturnExpr re , Rust:: Callable c |
598+ n1 = re .getExpr ( ) and
599+ c = re .getEnclosingCallable ( ) and
600+ n2 = c .getBody ( ) and
601+ path2 .isEmpty ( )
602+ )
606603 or
607604 exists ( Struct s |
608605 n1 = [ n2 .( RangeExpr ) .getStart ( ) , n2 .( RangeExpr ) .getEnd ( ) ] and
@@ -617,14 +614,22 @@ private module Input3 implements InputSig3 {
617614 result = inferTypeFromAnnotationTopDown ( n , path )
618615 or
619616 result = inferClosureExprBodyTypeTopDown ( n , path )
617+ or
618+ exists ( FunctionPosition pos | not pos .isReturn ( ) |
619+ result = inferConstructionType ( n , pos , path )
620+ or
621+ result = inferOperationType ( n , pos , path )
622+ )
620623 }
621624
622625 Type inferTypeSpecific ( AstNode n , TypePath path ) {
623626 result = inferAssignmentOperationType ( n , path )
624627 or
625- result = inferConstructionType ( n , path )
626- or
627- result = inferOperationType ( n , path )
628+ exists ( FunctionPosition pos | pos .isReturn ( ) |
629+ result = inferConstructionType ( n , pos , path )
630+ or
631+ result = inferOperationType ( n , pos , path )
632+ )
628633 or
629634 result = inferFieldExprType ( n , path )
630635 or
@@ -650,7 +655,12 @@ private module Input3 implements InputSig3 {
650655
651656private module M3 = Make3< Input3 > ;
652657
653- import M3
658+ // import M3
659+ predicate inferType = M3:: inferType / 1 ;
660+
661+ predicate inferType = M3:: inferType / 2 ;
662+
663+ predicate inferTypeCertain = M3:: inferTypeCertain / 2 ;
654664
655665module Consistency {
656666 import M2:: Consistency
@@ -917,14 +927,6 @@ private Struct getRangeType(RangeExpr re) {
917927 result instanceof RangeToInclusiveStruct
918928}
919929
920- private predicate bodyReturns ( Expr body , Expr e ) {
921- exists ( ReturnExpr re , Callable c |
922- e = re .getExpr ( ) and
923- c = re .getEnclosingCallable ( ) and
924- body = c .getBody ( )
925- )
926- }
927-
928930pragma [ nomagic]
929931private Type inferTypeFromAnnotationTopDown ( AstNode n , TypePath path ) {
930932 // Normally, these are coercion sites, but in case a type is unknown we
@@ -1082,7 +1084,7 @@ private module ContextTyping {
10821084 * context in which the call appears, for example a call like
10831085 * `Default::default()`.
10841086 */
1085- abstract class ContextTypedCallCand extends AstNode {
1087+ abstract class ContextTypedCallCand extends Expr {
10861088 abstract Type getTypeArgument ( TypeArgumentPosition apos , TypePath path ) ;
10871089
10881090 predicate hasTypeArgument ( TypeArgumentPosition apos ) { exists ( this .getTypeArgument ( apos , _) ) }
@@ -2653,7 +2655,9 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
26532655 )
26542656 }
26552657
2656- abstract class Access extends ContextTyping:: ContextTypedCallCand {
2658+ final class Access = AccessImpl ;
2659+
2660+ abstract private class AccessImpl extends ContextTyping:: ContextTypedCallCand {
26572661 abstract AstNode getNodeAt ( FunctionPosition pos ) ;
26582662
26592663 bindingset [ derefChainBorrow]
@@ -2668,7 +2672,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
26682672 abstract predicate hasUnknownTypeAt ( string derefChainBorrow , FunctionPosition pos , TypePath path ) ;
26692673 }
26702674
2671- private class AssocFunctionCallAccess extends Access instanceof AssocFunctionResolution:: AssocFunctionCall
2675+ private class AssocFunctionCallAccess extends AccessImpl instanceof AssocFunctionResolution:: AssocFunctionCall
26722676 {
26732677 AssocFunctionCallAccess ( ) {
26742678 // handled in the `OperationMatchingInput` module
@@ -2755,7 +2759,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
27552759 }
27562760 }
27572761
2758- private class NonAssocFunctionCallAccess extends Access instanceof NonAssocCallExpr ,
2762+ private class NonAssocFunctionCallAccess extends AccessImpl instanceof NonAssocCallExpr ,
27592763 CallExprImpl:: CallExprCall
27602764 {
27612765 pragma [ nomagic]
@@ -2815,7 +2819,7 @@ private Type inferCallArgumentTypeTopDown(
28152819) {
28162820 exists ( string derefChainBorrow |
28172821 FunctionCallMatchingInput:: decodeDerefChainBorrow ( derefChainBorrow , derefChain , borrow ) and
2818- result = inferCallArgumentTypeTopDown ( call , derefChainBorrow , pos , n , path )
2822+ result = M3 :: inferCallArgumentTypeTopDown ( call , derefChainBorrow , pos , n , path )
28192823 )
28202824}
28212825
@@ -3024,16 +3028,13 @@ private module ConstructionMatchingInput implements MatchingInputSig {
30243028private module ConstructionMatching = Matching< ConstructionMatchingInput > ;
30253029
30263030pragma [ nomagic]
3027- private Type inferConstructionTypePreCheck ( AstNode n , FunctionPosition pos , TypePath path ) {
3031+ private Type inferConstructionType ( AstNode n , FunctionPosition pos , TypePath path ) {
30283032 exists ( ConstructionMatchingInput:: Access a |
30293033 n = a .getNodeAt ( pos ) and
30303034 result = ConstructionMatching:: inferAccessType ( a , pos , path )
30313035 )
30323036}
30333037
3034- private predicate inferConstructionType =
3035- CheckContextTyping< inferConstructionTypePreCheck / 3 > :: check / 2 ;
3036-
30373038pragma [ nomagic]
30383039private Type inferUnknownType ( AstNode n , TypePath path ) {
30393040 result = TUnknownType ( ) and
@@ -3119,16 +3120,14 @@ private module OperationMatchingInput implements MatchingInputSig {
31193120private module OperationMatching = Matching< OperationMatchingInput > ;
31203121
31213122pragma [ nomagic]
3122- private Type inferOperationTypePreCheck ( AstNode n , FunctionPosition pos , TypePath path ) {
3123+ private Type inferOperationType ( AstNode n , FunctionPosition pos , TypePath path ) {
31233124 exists ( OperationMatchingInput:: Access a |
31243125 n = a .getNodeAt ( pos ) and
31253126 result = OperationMatching:: inferAccessType ( a , pos , path ) and
31263127 if pos .asPosition ( ) = 0 then not path .isEmpty ( ) else any ( )
31273128 )
31283129}
31293130
3130- private predicate inferOperationType = CheckContextTyping< inferOperationTypePreCheck / 3 > :: check / 2 ;
3131-
31323131pragma [ nomagic]
31333132private Type getFieldExprLookupType ( FieldExpr fe , string name , DerefChain derefChain ) {
31343133 exists ( TypePath path |
@@ -3153,6 +3152,7 @@ private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefC
31533152 */
31543153cached
31553154StructField resolveStructFieldExpr ( FieldExpr fe , DerefChain derefChain ) {
3155+ M3:: CachedStage:: ref ( ) and
31563156 exists ( string name , DataType ty |
31573157 ty = getFieldExprLookupType ( fe , pragma [ only_bind_into ] ( name ) , derefChain )
31583158 |
@@ -3174,6 +3174,7 @@ private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain deref
31743174 */
31753175cached
31763176TupleField resolveTupleFieldExpr ( FieldExpr fe , DerefChain derefChain ) {
3177+ M3:: CachedStage:: ref ( ) and
31773178 exists ( int i |
31783179 result =
31793180 getTupleFieldExprLookupType ( fe , pragma [ only_bind_into ] ( i ) , derefChain )
@@ -3664,7 +3665,7 @@ private Type inferCastExprType(CastExpr ce, TypePath path) {
36643665/** Holds if `n` is implicitly dereferenced and/or borrowed. */
36653666cached
36663667predicate implicitDerefChainBorrow ( Expr e , DerefChain derefChain , boolean borrow ) {
3667- CachedStage:: ref ( ) and
3668+ M3 :: CachedStage:: ref ( ) and
36683669 exists ( BorrowKind bk |
36693670 any ( AssocFunctionResolution:: AssocFunctionCall afc )
36703671 .argumentHasImplicitDerefChainBorrow ( e , derefChain , bk ) and
@@ -3691,6 +3692,7 @@ predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow
36913692 */
36923693cached
36933694Addressable resolveCallTarget ( InvocationExpr call , boolean dispatch ) {
3695+ M3:: CachedStage:: ref ( ) and
36943696 dispatch = false and
36953697 result = call .( NonAssocCallExpr ) .resolveCallTargetViaPathResolution ( )
36963698 or
@@ -3711,7 +3713,7 @@ private module Debug {
37113713 exists ( string filepath , int startline , int startcolumn , int endline , int endcolumn |
37123714 result .getLocation ( ) .hasLocationInfo ( filepath , startline , startcolumn , endline , endcolumn ) and
37133715 filepath .matches ( "%/main.rs" ) and
3714- startline = 1102
3716+ startline = 103
37153717 )
37163718 }
37173719
@@ -3737,24 +3739,11 @@ private module Debug {
37373739 t = self .getTypeAt ( path )
37383740 }
37393741
3740- // predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) {
3741- // n = getRelevantLocatable() and
3742- // t = inferFunctionCallType(n, path)
3743- // }
3744- predicate debugInferConstructionType ( AstNode n , TypePath path , Type t ) {
3745- n = getRelevantLocatable ( ) and
3746- t = inferConstructionType ( n , path )
3747- }
3748-
37493742 predicate debugTypeMention ( TypeMention tm , TypePath path , Type type ) {
37503743 tm = getRelevantLocatable ( ) and
37513744 tm .getTypeAt ( path ) = type
37523745 }
37533746
3754- // Type debugInferAnnotatedType(AstNode n, TypePath path) {
3755- // n = getRelevantLocatable() and
3756- // result = inferAnnotatedType(n, path)
3757- // }
37583747 pragma [ nomagic]
37593748 private int countTypesAtPath ( AstNode n , TypePath path , Type t ) {
37603749 t = inferType ( n , path ) and
@@ -3803,7 +3792,7 @@ private module Debug {
38033792 c = max ( countTypePaths ( _, _, _) )
38043793 }
38053794
3806- Type debuginferTypeCertain ( AstNode n , TypePath path ) {
3795+ Type debugInferTypeCertain ( AstNode n , TypePath path ) {
38073796 n = getRelevantLocatable ( ) and
38083797 result = inferTypeCertain ( n , path )
38093798 }
0 commit comments