Skip to content

Commit 69e7f6f

Browse files
committed
wip10
1 parent 0e414a5 commit 69e7f6f

2 files changed

Lines changed: 80 additions & 43 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ private predicate bodyReturns(Expr body, Expr e) {
916916
}
917917

918918
pragma[nomagic]
919-
private Type inferUnknownTypeFromAnnotationCand(AstNode n, TypePath path, TypePath prefix) {
919+
private Type inferTypeFromAnnotationTopDown(AstNode n, TypePath path) {
920920
// Normally, these are coercion sites, but in case a type is unknown we
921921
// allow for type information to flow from the type annotation.
922922
exists(TypeMention tm | result = tm.getTypeAt(path) |
@@ -925,17 +925,12 @@ private Type inferUnknownTypeFromAnnotationCand(AstNode n, TypePath path, TypePa
925925
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
926926
or
927927
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
928-
) and
929-
prefix = path.getAPrefix()
930-
}
931-
932-
private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
933-
exists(TypePath prefix |
934-
result = inferUnknownTypeFromAnnotationCand(n, path, prefix) and
935-
hasUnknownTypeAt(n, prefix)
936928
)
937929
}
938930

931+
private predicate inferUnknownTypeFromAnnotation =
932+
TopDownTyping<inferTypeFromAnnotationTopDown/2>::inferType/2;
933+
939934
pragma[nomagic]
940935
private TupleType inferTupleRootType(AstNode n) {
941936
// `typeEquality` handles the non-root cases
@@ -2819,7 +2814,7 @@ private Type inferFunctionCallType0(
28192814
call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and
28202815
result = TUnknownType()
28212816
or
2822-
result = inferCallTypeOut(call, pos, n, derefChainBorrow, path0)
2817+
result = inferCallTypeOut(call, derefChainBorrow, pos, n, path0)
28232818
)
28242819
|
28252820
if
@@ -3630,11 +3625,10 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
36303625
}
36313626

36323627
pragma[nomagic]
3633-
private Type inferClosureExprBodyTypeCand(AstNode n, TypePath path, TypePath prefix) {
3628+
private Type inferClosureExprBodyTypeTopDown(AstNode n, TypePath path) {
36343629
exists(ClosureExpr ce |
36353630
n = ce.getClosureBody() and
3636-
result = inferType(ce, closureReturnPath().appendInverse(path)) and
3637-
prefix = path.getAPrefix()
3631+
result = inferType(ce, closureReturnPath().appendInverse(path))
36383632
)
36393633
}
36403634

@@ -3661,10 +3655,7 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
36613655
)
36623656
)
36633657
or
3664-
exists(TypePath prefix |
3665-
result = inferClosureExprBodyTypeCand(n, path, prefix) and
3666-
hasUnknownTypeAt(n, prefix)
3667-
)
3658+
result = TopDownTyping<inferClosureExprBodyTypeTopDown/2>::inferType(n, path)
36683659
}
36693660

36703661
pragma[nomagic]

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ signature module InputSig1<LocationSig Location> {
147147

148148
/**
149149
* A special pseudo type used to represent cases where the actual type needs
150-
* to be inferred from the context. For example, in
150+
* to be inferred from the context in a top-down manner. For example, in
151151
*
152152
* ```rust
153153
* let x = Vec::new();
@@ -2146,13 +2146,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21462146
*/
21472147
signature module InputSig3 {
21482148
/**
2149-
* References to cached predicates that should be included to the cached
2150-
* stage of type inference. Such predicates should reference `CachedStage::ref`.
2149+
* A predicate used to reference cached predicates that should be included to the
2150+
* cached stage of type inference. Such predicates should themselves reference
2151+
* `CachedStage::ref`.
21512152
*/
21522153
default predicate cachedStageRevRef() { none() }
21532154

21542155
/**
2155-
* Point this predicate to the `inferType` predicate in the output of this module.
2156+
* Point this predicate to the `inferType` predicate from the output of this module.
21562157
*
21572158
* Needed to be able to refer to `inferType` in default signature implementations.
21582159
*/
@@ -2278,7 +2279,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22782279
}
22792280

22802281
/**
2281-
* A position where a callable can have a declared type.
2282+
* A position where a callable can have a declared type and a call can have
2283+
* an inferred type.
22822284
*/
22832285
class TypePosition {
22842286
/** Holds if this position represents the return type of a callable. */
@@ -2288,7 +2290,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22882290
string toString();
22892291
}
22902292

2291-
/** A context needed to resolve calls. */
2293+
/**
2294+
* A context needed to resolve calls.
2295+
*
2296+
* For example, in Rust, we need an additional context to represent the
2297+
* candidate receiver type when resolving method calls.
2298+
*
2299+
* When not used, simply instantiate this class with `Unit`.
2300+
*/
22922301
bindingset[this]
22932302
class CallResolutionContext {
22942303
/** Gets a textual representation of this context. */
@@ -2298,11 +2307,18 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22982307

22992308
/** A callable. */
23002309
class Callable {
2310+
/** Gets the type parameter at position `ppos` of this callable, if any. */
23012311
TypeParameter getTypeParameter(TypeParameterPosition ppos);
23022312

2313+
/**
2314+
* Gets an additional type parameter constraint for the given type parameter,
2315+
* which applies to this callable. For example, in Rust, a function can apply
2316+
* additional constraints on type parameters belonging to the `impl` block
2317+
* that the function is defined in.
2318+
*/
23032319
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp);
23042320

2305-
/* Gets the declared type of this callable at `path` for position `pos`. */
2321+
/** Gets the declared type of this callable at `path` for position `pos`. */
23062322
Type getDeclaredType(TypePosition pos, TypePath path);
23072323

23082324
/** Gets a textual representation of this callable. */
@@ -2312,19 +2328,31 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
23122328
Location getLocation();
23132329
}
23142330

2331+
/** A call expression. */
23152332
class Call extends Expr {
2333+
/** Gets the explicit type argument at position `apos` and `path` for this call, if any. */
23162334
Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
23172335

2336+
/** Gets the AST node corresponding to the position `pos` of this call. */
23182337
AstNode getNodeAt(TypePosition pos);
23192338

2320-
/** Gets the target of this call. */
2339+
/**
2340+
* Gets the target of this call, to be used when inferring certain types.
2341+
*/
23212342
Callable getTargetCertain();
23222343

2323-
/** Gets the target of this call. */
2344+
/** Gets the target of this call in the given context. */
23242345
Callable getTarget(CallResolutionContext ctx);
23252346
}
23262347

2327-
/** Gets the inferred type `call` at `path` for position `pos` in context `ctx`. */
2348+
/**
2349+
* Gets the inferred type of `call` at `path` and position `pos` in context `ctx`.
2350+
*
2351+
* By default, this is the inferred type of the node at the given position, but
2352+
* in for example Rust, the inferred type of the receiver of a method call needs
2353+
* to take the call context into account, in order to use the correct candidate
2354+
* receiver type.
2355+
*/
23282356
bindingset[ctx]
23292357
default Type inferCallTypeIn(
23302358
Call call, CallResolutionContext ctx, TypePosition pos, TypePath path
@@ -2556,9 +2584,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
25562584
path1.isEmpty() and
25572585
path2.isEmpty() and
25582586
(
2559-
exists(Assignment a |
2560-
a.getLeftOperand() = n1 and
2561-
a.getRightOperand() = n2
2587+
exists(AssignExpr ae |
2588+
ae.getLeftOperand() = n1 and
2589+
ae.getRightOperand() = n2
25622590
)
25632591
or
25642592
exists(LetDeclaration let |
@@ -2621,22 +2649,16 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
26212649
}
26222650

26232651
pragma[nomagic]
2624-
private Type inferTypeFromReverseLubStepCand(AstNode n, TypePath path, TypePath prefix) {
2652+
private Type inferTypeFromLubStepTopDown(AstNode n, TypePath path) {
26252653
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
26262654
result = inferType(n2, prefix2.appendInverse(suffix)) and
26272655
path = prefix1.append(suffix) and
2628-
lubStep(n, prefix1, n2, prefix2) and
2629-
prefix = path.getAPrefix()
2656+
lubStep(n, prefix1, n2, prefix2)
26302657
)
26312658
}
26322659

2633-
pragma[nomagic]
2634-
private Type inferTypeFromReverseLub(AstNode n, TypePath path) {
2635-
exists(TypePath prefix |
2636-
result = inferTypeFromReverseLubStepCand(n, path, prefix) and
2637-
hasUnknownTypeAt(n, prefix)
2638-
)
2639-
}
2660+
private predicate inferTypeFromReverseLub =
2661+
TopDownTyping<inferTypeFromLubStepTopDown/2>::inferType/2;
26402662

26412663
/**
26422664
* Gets the inferred type of `n` at `path`.
@@ -2705,22 +2727,46 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
27052727

27062728
private module CallMatching = MatchingWithEnvironment<CallMatchingInput>;
27072729

2708-
pragma[nomagic]
27092730
Type inferCallTypeOut(
2710-
Call call, TypePosition pos, AstNode n, CallResolutionContext ctx, TypePath path
2731+
Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path
27112732
) {
27122733
n = call.getNodeAt(pos) and
27132734
result = CallMatching::inferAccessType(call, ctx, pos, path)
27142735
}
27152736

27162737
pragma[nomagic]
2717-
predicate hasUnknownTypeAt(AstNode n, TypePath path) {
2738+
private predicate hasUnknownTypeAt(AstNode n, TypePath path) {
27182739
inferType(n, path) instanceof UnknownType
27192740
}
27202741

27212742
pragma[nomagic]
27222743
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
27232744

2745+
signature Type inferTypeTopDownSig(AstNode n, TypePath path);
2746+
2747+
/**
2748+
* Given a predicate `inferTypeTopDown` for inferring the type of an AST node `n`
2749+
* top-down from a context, this module exposes the predicate `inferType`, which
2750+
* restricts type information to only flow top-down into `n` when `n` has an
2751+
* explicit unknown type.
2752+
*/
2753+
module TopDownTyping<inferTypeTopDownSig/2 inferTypeTopDown> {
2754+
pragma[nomagic]
2755+
private Type inferTypeTopDown(AstNode n, TypePath prefix, TypePath path) {
2756+
result = inferTypeTopDown(n, path) and
2757+
hasUnknownType(n) and
2758+
prefix = path.getAPrefix()
2759+
}
2760+
2761+
pragma[nomagic]
2762+
Type inferType(AstNode n, TypePath path) {
2763+
exists(TypePath prefix |
2764+
result = inferTypeTopDown(n, prefix, path) and
2765+
hasUnknownTypeAt(n, prefix)
2766+
)
2767+
}
2768+
}
2769+
27242770
signature Type inferCallTypeSig(AstNode n, TypePosition pos, TypePath path);
27252771

27262772
/**

0 commit comments

Comments
 (0)