Skip to content

Commit 12256f3

Browse files
committed
wip12
1 parent 0365a0c commit 12256f3

4 files changed

Lines changed: 137 additions & 206 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 41 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ private import FunctionType
1414
private import FunctionOverloading as FunctionOverloading
1515
private import BlanketImplementation as BlanketImplementation
1616
private import codeql.rust.elements.internal.VariableImpl::Impl as VariableImpl
17-
private import codeql.rust.internal.CachedStages
1817
private import codeql.typeinference.internal.TypeInference
1918
private import codeql.rust.frameworks.stdlib.Stdlib
2019
private import codeql.rust.frameworks.stdlib.Builtins as Builtins
@@ -394,13 +393,7 @@ private module Input3 implements InputSig3 {
394393
}
395394
}
396395

397-
class Call extends Expr instanceof FunctionCallMatchingInput::Access {
398-
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
399-
result = super.getTypeArgument(apos, path)
400-
}
401-
402-
AstNode getNodeAt(TypePosition pos) { result = super.getNodeAt(pos) }
403-
396+
class Call extends FunctionCallMatchingInput::Access {
404397
/** Gets the target of this call. */
405398
Callable getTargetCertain() {
406399
exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p |
@@ -421,7 +414,7 @@ private module Input3 implements InputSig3 {
421414

422415
Type inferCallReturnType(AstNode n, TypePath path) {
423416
exists(Call call, TypePath path0 |
424-
result = inferCallReturnType(call, _, n, path0) and
417+
result = M3::inferCallReturnType(call, _, n, path0) and
425418
if
426419
// index expression `x[i]` desugars to `*x.index(i)`, so we must account for
427420
// the implicit deref
@@ -598,11 +591,15 @@ private module Input3 implements InputSig3 {
598591
predicate inferLubStep(AstNode n1, TypePath path1, AstNode n2, TypePath path2) {
599592
path1.isEmpty() and
600593
(
601-
n2 = any(ArrayListExpr ale | n1 = ale.getAnExpr()) and
594+
n2 = n1.(ArrayListExpr).getAnExpr() and
602595
path2 = TypePath::singleton(getArrayTypeParameter())
603596
or
604-
bodyReturns(n2, n1) and
605-
path2.isEmpty()
597+
exists(ReturnExpr re, Rust::Callable c |
598+
n1 = re.getExpr() and
599+
c = re.getEnclosingCallable() and
600+
n2 = c.getBody() and
601+
path2.isEmpty()
602+
)
606603
or
607604
exists(Struct s |
608605
n1 = [n2.(RangeExpr).getStart(), n2.(RangeExpr).getEnd()] and
@@ -617,14 +614,22 @@ private module Input3 implements InputSig3 {
617614
result = inferTypeFromAnnotationTopDown(n, path)
618615
or
619616
result = inferClosureExprBodyTypeTopDown(n, path)
617+
or
618+
exists(FunctionPosition pos | not pos.isReturn() |
619+
result = inferConstructionType(n, pos, path)
620+
or
621+
result = inferOperationType(n, pos, path)
622+
)
620623
}
621624

622625
Type inferTypeSpecific(AstNode n, TypePath path) {
623626
result = inferAssignmentOperationType(n, path)
624627
or
625-
result = inferConstructionType(n, path)
626-
or
627-
result = inferOperationType(n, path)
628+
exists(FunctionPosition pos | pos.isReturn() |
629+
result = inferConstructionType(n, pos, path)
630+
or
631+
result = inferOperationType(n, pos, path)
632+
)
628633
or
629634
result = inferFieldExprType(n, path)
630635
or
@@ -650,7 +655,12 @@ private module Input3 implements InputSig3 {
650655

651656
private module M3 = Make3<Input3>;
652657

653-
import M3
658+
// import M3
659+
predicate inferType = M3::inferType/1;
660+
661+
predicate inferType = M3::inferType/2;
662+
663+
predicate inferTypeCertain = M3::inferTypeCertain/2;
654664

655665
module Consistency {
656666
import M2::Consistency
@@ -917,14 +927,6 @@ private Struct getRangeType(RangeExpr re) {
917927
result instanceof RangeToInclusiveStruct
918928
}
919929

920-
private predicate bodyReturns(Expr body, Expr e) {
921-
exists(ReturnExpr re, Callable c |
922-
e = re.getExpr() and
923-
c = re.getEnclosingCallable() and
924-
body = c.getBody()
925-
)
926-
}
927-
928930
pragma[nomagic]
929931
private Type inferTypeFromAnnotationTopDown(AstNode n, TypePath path) {
930932
// Normally, these are coercion sites, but in case a type is unknown we
@@ -1082,7 +1084,7 @@ private module ContextTyping {
10821084
* context in which the call appears, for example a call like
10831085
* `Default::default()`.
10841086
*/
1085-
abstract class ContextTypedCallCand extends AstNode {
1087+
abstract class ContextTypedCallCand extends Expr {
10861088
abstract Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
10871089

10881090
predicate hasTypeArgument(TypeArgumentPosition apos) { exists(this.getTypeArgument(apos, _)) }
@@ -2653,7 +2655,9 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
26532655
)
26542656
}
26552657

2656-
abstract class Access extends ContextTyping::ContextTypedCallCand {
2658+
final class Access = AccessImpl;
2659+
2660+
abstract private class AccessImpl extends ContextTyping::ContextTypedCallCand {
26572661
abstract AstNode getNodeAt(FunctionPosition pos);
26582662

26592663
bindingset[derefChainBorrow]
@@ -2668,7 +2672,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
26682672
abstract predicate hasUnknownTypeAt(string derefChainBorrow, FunctionPosition pos, TypePath path);
26692673
}
26702674

2671-
private class AssocFunctionCallAccess extends Access instanceof AssocFunctionResolution::AssocFunctionCall
2675+
private class AssocFunctionCallAccess extends AccessImpl instanceof AssocFunctionResolution::AssocFunctionCall
26722676
{
26732677
AssocFunctionCallAccess() {
26742678
// handled in the `OperationMatchingInput` module
@@ -2755,7 +2759,7 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
27552759
}
27562760
}
27572761

2758-
private class NonAssocFunctionCallAccess extends Access instanceof NonAssocCallExpr,
2762+
private class NonAssocFunctionCallAccess extends AccessImpl instanceof NonAssocCallExpr,
27592763
CallExprImpl::CallExprCall
27602764
{
27612765
pragma[nomagic]
@@ -2815,7 +2819,7 @@ private Type inferCallArgumentTypeTopDown(
28152819
) {
28162820
exists(string derefChainBorrow |
28172821
FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow) and
2818-
result = inferCallArgumentTypeTopDown(call, derefChainBorrow, pos, n, path)
2822+
result = M3::inferCallArgumentTypeTopDown(call, derefChainBorrow, pos, n, path)
28192823
)
28202824
}
28212825

@@ -3024,16 +3028,13 @@ private module ConstructionMatchingInput implements MatchingInputSig {
30243028
private module ConstructionMatching = Matching<ConstructionMatchingInput>;
30253029

30263030
pragma[nomagic]
3027-
private Type inferConstructionTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
3031+
private Type inferConstructionType(AstNode n, FunctionPosition pos, TypePath path) {
30283032
exists(ConstructionMatchingInput::Access a |
30293033
n = a.getNodeAt(pos) and
30303034
result = ConstructionMatching::inferAccessType(a, pos, path)
30313035
)
30323036
}
30333037

3034-
private predicate inferConstructionType =
3035-
CheckContextTyping<inferConstructionTypePreCheck/3>::check/2;
3036-
30373038
pragma[nomagic]
30383039
private Type inferUnknownType(AstNode n, TypePath path) {
30393040
result = TUnknownType() and
@@ -3119,16 +3120,14 @@ private module OperationMatchingInput implements MatchingInputSig {
31193120
private module OperationMatching = Matching<OperationMatchingInput>;
31203121

31213122
pragma[nomagic]
3122-
private Type inferOperationTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
3123+
private Type inferOperationType(AstNode n, FunctionPosition pos, TypePath path) {
31233124
exists(OperationMatchingInput::Access a |
31243125
n = a.getNodeAt(pos) and
31253126
result = OperationMatching::inferAccessType(a, pos, path) and
31263127
if pos.asPosition() = 0 then not path.isEmpty() else any()
31273128
)
31283129
}
31293130

3130-
private predicate inferOperationType = CheckContextTyping<inferOperationTypePreCheck/3>::check/2;
3131-
31323131
pragma[nomagic]
31333132
private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) {
31343133
exists(TypePath path |
@@ -3153,6 +3152,7 @@ private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefC
31533152
*/
31543153
cached
31553154
StructField resolveStructFieldExpr(FieldExpr fe, DerefChain derefChain) {
3155+
M3::CachedStage::ref() and
31563156
exists(string name, DataType ty |
31573157
ty = getFieldExprLookupType(fe, pragma[only_bind_into](name), derefChain)
31583158
|
@@ -3174,6 +3174,7 @@ private Type getTupleFieldExprLookupType(FieldExpr fe, int pos, DerefChain deref
31743174
*/
31753175
cached
31763176
TupleField resolveTupleFieldExpr(FieldExpr fe, DerefChain derefChain) {
3177+
M3::CachedStage::ref() and
31773178
exists(int i |
31783179
result =
31793180
getTupleFieldExprLookupType(fe, pragma[only_bind_into](i), derefChain)
@@ -3664,7 +3665,7 @@ private Type inferCastExprType(CastExpr ce, TypePath path) {
36643665
/** Holds if `n` is implicitly dereferenced and/or borrowed. */
36653666
cached
36663667
predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow) {
3667-
CachedStage::ref() and
3668+
M3::CachedStage::ref() and
36683669
exists(BorrowKind bk |
36693670
any(AssocFunctionResolution::AssocFunctionCall afc)
36703671
.argumentHasImplicitDerefChainBorrow(e, derefChain, bk) and
@@ -3691,6 +3692,7 @@ predicate implicitDerefChainBorrow(Expr e, DerefChain derefChain, boolean borrow
36913692
*/
36923693
cached
36933694
Addressable resolveCallTarget(InvocationExpr call, boolean dispatch) {
3695+
M3::CachedStage::ref() and
36943696
dispatch = false and
36953697
result = call.(NonAssocCallExpr).resolveCallTargetViaPathResolution()
36963698
or
@@ -3711,7 +3713,7 @@ private module Debug {
37113713
exists(string filepath, int startline, int startcolumn, int endline, int endcolumn |
37123714
result.getLocation().hasLocationInfo(filepath, startline, startcolumn, endline, endcolumn) and
37133715
filepath.matches("%/main.rs") and
3714-
startline = 1102
3716+
startline = 103
37153717
)
37163718
}
37173719

@@ -3737,24 +3739,11 @@ private module Debug {
37373739
t = self.getTypeAt(path)
37383740
}
37393741

3740-
// predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) {
3741-
// n = getRelevantLocatable() and
3742-
// t = inferFunctionCallType(n, path)
3743-
// }
3744-
predicate debugInferConstructionType(AstNode n, TypePath path, Type t) {
3745-
n = getRelevantLocatable() and
3746-
t = inferConstructionType(n, path)
3747-
}
3748-
37493742
predicate debugTypeMention(TypeMention tm, TypePath path, Type type) {
37503743
tm = getRelevantLocatable() and
37513744
tm.getTypeAt(path) = type
37523745
}
37533746

3754-
// Type debugInferAnnotatedType(AstNode n, TypePath path) {
3755-
// n = getRelevantLocatable() and
3756-
// result = inferAnnotatedType(n, path)
3757-
// }
37583747
pragma[nomagic]
37593748
private int countTypesAtPath(AstNode n, TypePath path, Type t) {
37603749
t = inferType(n, path) and
@@ -3803,7 +3792,7 @@ private module Debug {
38033792
c = max(countTypePaths(_, _, _))
38043793
}
38053794

3806-
Type debuginferTypeCertain(AstNode n, TypePath path) {
3795+
Type debugInferTypeCertain(AstNode n, TypePath path) {
38073796
n = getRelevantLocatable() and
38083797
result = inferTypeCertain(n, path)
38093798
}

0 commit comments

Comments
 (0)