Skip to content

Commit 2041b4a

Browse files
committed
wip10
1 parent 249f7c3 commit 2041b4a

2 files changed

Lines changed: 80 additions & 43 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,7 @@ private predicate bodyReturns(Expr body, Expr e) {
916916
}
917917

918918
pragma[nomagic]
919-
private Type inferUnknownTypeFromAnnotationCand(AstNode n, TypePath path, TypePath prefix) {
919+
private Type inferTypeFromAnnotationTopDown(AstNode n, TypePath path) {
920920
// Normally, these are coercion sites, but in case a type is unknown we
921921
// allow for type information to flow from the type annotation.
922922
exists(TypeMention tm | result = tm.getTypeAt(path) |
@@ -925,17 +925,12 @@ private Type inferUnknownTypeFromAnnotationCand(AstNode n, TypePath path, TypePa
925925
tm = any(ClosureExpr ce | n = ce.getBody()).getRetType().getTypeRepr()
926926
or
927927
tm = getReturnTypeMention(any(Function f | n = f.getBody()))
928-
) and
929-
prefix = path.getAPrefix()
930-
}
931-
932-
private Type inferUnknownTypeFromAnnotation(AstNode n, TypePath path) {
933-
exists(TypePath prefix |
934-
result = inferUnknownTypeFromAnnotationCand(n, path, prefix) and
935-
hasUnknownTypeAt(n, prefix)
936928
)
937929
}
938930

931+
private predicate inferUnknownTypeFromAnnotation =
932+
TopDownTyping<inferTypeFromAnnotationTopDown/2>::inferType/2;
933+
939934
pragma[nomagic]
940935
private TupleType inferTupleRootType(AstNode n) {
941936
// `typeEquality` handles the non-root cases
@@ -2819,7 +2814,7 @@ private Type inferFunctionCallType0(
28192814
call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and
28202815
result = TUnknownType()
28212816
or
2822-
result = inferCallTypeOut(call, pos, n, derefChainBorrow, path0)
2817+
result = inferCallTypeOut(call, derefChainBorrow, pos, n, path0)
28232818
)
28242819
|
28252820
if
@@ -3630,11 +3625,10 @@ private Type inferForLoopExprType(AstNode n, TypePath path) {
36303625
}
36313626

36323627
pragma[nomagic]
3633-
private Type inferClosureExprBodyTypeCand(AstNode n, TypePath path, TypePath prefix) {
3628+
private Type inferClosureExprBodyTypeTopDown(AstNode n, TypePath path) {
36343629
exists(ClosureExpr ce |
36353630
n = ce.getClosureBody() and
3636-
result = inferType(ce, closureReturnPath().appendInverse(path)) and
3637-
prefix = path.getAPrefix()
3631+
result = inferType(ce, closureReturnPath().appendInverse(path))
36383632
)
36393633
}
36403634

@@ -3661,10 +3655,7 @@ private Type inferClosureExprType(AstNode n, TypePath path) {
36613655
)
36623656
)
36633657
or
3664-
exists(TypePath prefix |
3665-
result = inferClosureExprBodyTypeCand(n, path, prefix) and
3666-
hasUnknownTypeAt(n, prefix)
3667-
)
3658+
result = TopDownTyping<inferClosureExprBodyTypeTopDown/2>::inferType(n, path)
36683659
}
36693660

36703661
pragma[nomagic]

shared/typeinference/codeql/typeinference/internal/TypeInference.qll

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ signature module InputSig1<LocationSig Location> {
147147

148148
/**
149149
* A special pseudo type used to represent cases where the actual type needs
150-
* to be inferred from the context. For example, in
150+
* to be inferred from the context in a top-down manner. For example, in
151151
*
152152
* ```rust
153153
* let x = Vec::new();
@@ -2119,13 +2119,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
21192119
*/
21202120
signature module InputSig3 {
21212121
/**
2122-
* References to cached predicates that should be included to the cached
2123-
* stage of type inference. Such predicates should reference `CachedStage::ref`.
2122+
* A predicate used to reference cached predicates that should be included to the
2123+
* cached stage of type inference. Such predicates should themselves reference
2124+
* `CachedStage::ref`.
21242125
*/
21252126
default predicate cachedStageRevRef() { none() }
21262127

21272128
/**
2128-
* Point this predicate to the `inferType` predicate in the output of this module.
2129+
* Point this predicate to the `inferType` predicate from the output of this module.
21292130
*
21302131
* Needed to be able to refer to `inferType` in default signature implementations.
21312132
*/
@@ -2251,7 +2252,8 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22512252
}
22522253

22532254
/**
2254-
* A position where a callable can have a declared type.
2255+
* A position where a callable can have a declared type and a call can have
2256+
* an inferred type.
22552257
*/
22562258
class TypePosition {
22572259
/** Holds if this position represents the return type of a callable. */
@@ -2261,7 +2263,14 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22612263
string toString();
22622264
}
22632265

2264-
/** A context needed to resolve calls. */
2266+
/**
2267+
* A context needed to resolve calls.
2268+
*
2269+
* For example, in Rust, we need an additional context to represent the
2270+
* candidate receiver type when resolving method calls.
2271+
*
2272+
* When not used, simply instantiate this class with `Unit`.
2273+
*/
22652274
bindingset[this]
22662275
class CallResolutionContext {
22672276
/** Gets a textual representation of this context. */
@@ -2271,11 +2280,18 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22712280

22722281
/** A callable. */
22732282
class Callable {
2283+
/** Gets the type parameter at position `ppos` of this callable, if any. */
22742284
TypeParameter getTypeParameter(TypeParameterPosition ppos);
22752285

2286+
/**
2287+
* Gets an additional type parameter constraint for the given type parameter,
2288+
* which applies to this callable. For example, in Rust, a function can apply
2289+
* additional constraints on type parameters belonging to the `impl` block
2290+
* that the function is defined in.
2291+
*/
22762292
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp);
22772293

2278-
/* Gets the declared type of this callable at `path` for position `pos`. */
2294+
/** Gets the declared type of this callable at `path` for position `pos`. */
22792295
Type getDeclaredType(TypePosition pos, TypePath path);
22802296

22812297
/** Gets a textual representation of this callable. */
@@ -2285,19 +2301,31 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
22852301
Location getLocation();
22862302
}
22872303

2304+
/** A call expression. */
22882305
class Call extends Expr {
2306+
/** Gets the explicit type argument at position `apos` and `path` for this call, if any. */
22892307
Type getTypeArgument(TypeArgumentPosition apos, TypePath path);
22902308

2309+
/** Gets the AST node corresponding to the position `pos` of this call. */
22912310
AstNode getNodeAt(TypePosition pos);
22922311

2293-
/** Gets the target of this call. */
2312+
/**
2313+
* Gets the target of this call, to be used when inferring certain types.
2314+
*/
22942315
Callable getTargetCertain();
22952316

2296-
/** Gets the target of this call. */
2317+
/** Gets the target of this call in the given context. */
22972318
Callable getTarget(CallResolutionContext ctx);
22982319
}
22992320

2300-
/** Gets the inferred type `call` at `path` for position `pos` in context `ctx`. */
2321+
/**
2322+
* Gets the inferred type of `call` at `path` and position `pos` in context `ctx`.
2323+
*
2324+
* By default, this is the inferred type of the node at the given position, but
2325+
* in for example Rust, the inferred type of the receiver of a method call needs
2326+
* to take the call context into account, in order to use the correct candidate
2327+
* receiver type.
2328+
*/
23012329
bindingset[ctx]
23022330
default Type inferCallTypeIn(
23032331
Call call, CallResolutionContext ctx, TypePosition pos, TypePath path
@@ -2529,9 +2557,9 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
25292557
path1.isEmpty() and
25302558
path2.isEmpty() and
25312559
(
2532-
exists(Assignment a |
2533-
a.getLeftOperand() = n1 and
2534-
a.getRightOperand() = n2
2560+
exists(AssignExpr ae |
2561+
ae.getLeftOperand() = n1 and
2562+
ae.getRightOperand() = n2
25352563
)
25362564
or
25372565
exists(LetDeclaration let |
@@ -2594,22 +2622,16 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
25942622
}
25952623

25962624
pragma[nomagic]
2597-
private Type inferTypeFromReverseLubStepCand(AstNode n, TypePath path, TypePath prefix) {
2625+
private Type inferTypeFromLubStepTopDown(AstNode n, TypePath path) {
25982626
exists(TypePath prefix1, AstNode n2, TypePath prefix2, TypePath suffix |
25992627
result = inferType(n2, prefix2.appendInverse(suffix)) and
26002628
path = prefix1.append(suffix) and
2601-
lubStep(n, prefix1, n2, prefix2) and
2602-
prefix = path.getAPrefix()
2629+
lubStep(n, prefix1, n2, prefix2)
26032630
)
26042631
}
26052632

2606-
pragma[nomagic]
2607-
private Type inferTypeFromReverseLub(AstNode n, TypePath path) {
2608-
exists(TypePath prefix |
2609-
result = inferTypeFromReverseLubStepCand(n, path, prefix) and
2610-
hasUnknownTypeAt(n, prefix)
2611-
)
2612-
}
2633+
private predicate inferTypeFromReverseLub =
2634+
TopDownTyping<inferTypeFromLubStepTopDown/2>::inferType/2;
26132635

26142636
/**
26152637
* Gets the inferred type of `n` at `path`.
@@ -2678,22 +2700,46 @@ module Make1<LocationSig Location, InputSig1<Location> Input1> {
26782700

26792701
private module CallMatching = MatchingWithEnvironment<CallMatchingInput>;
26802702

2681-
pragma[nomagic]
26822703
Type inferCallTypeOut(
2683-
Call call, TypePosition pos, AstNode n, CallResolutionContext ctx, TypePath path
2704+
Call call, CallResolutionContext ctx, TypePosition pos, AstNode n, TypePath path
26842705
) {
26852706
n = call.getNodeAt(pos) and
26862707
result = CallMatching::inferAccessType(call, ctx, pos, path)
26872708
}
26882709

26892710
pragma[nomagic]
2690-
predicate hasUnknownTypeAt(AstNode n, TypePath path) {
2711+
private predicate hasUnknownTypeAt(AstNode n, TypePath path) {
26912712
inferType(n, path) instanceof UnknownType
26922713
}
26932714

26942715
pragma[nomagic]
26952716
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
26962717

2718+
signature Type inferTypeTopDownSig(AstNode n, TypePath path);
2719+
2720+
/**
2721+
* Given a predicate `inferTypeTopDown` for inferring the type of an AST node `n`
2722+
* top-down from a context, this module exposes the predicate `inferType`, which
2723+
* restricts type information to only flow top-down into `n` when `n` has an
2724+
* explicit unknown type.
2725+
*/
2726+
module TopDownTyping<inferTypeTopDownSig/2 inferTypeTopDown> {
2727+
pragma[nomagic]
2728+
private Type inferTypeTopDown(AstNode n, TypePath prefix, TypePath path) {
2729+
result = inferTypeTopDown(n, path) and
2730+
hasUnknownType(n) and
2731+
prefix = path.getAPrefix()
2732+
}
2733+
2734+
pragma[nomagic]
2735+
Type inferType(AstNode n, TypePath path) {
2736+
exists(TypePath prefix |
2737+
result = inferTypeTopDown(n, prefix, path) and
2738+
hasUnknownTypeAt(n, prefix)
2739+
)
2740+
}
2741+
}
2742+
26972743
signature Type inferCallTypeSig(AstNode n, TypePosition pos, TypePath path);
26982744

26992745
/**

0 commit comments

Comments
 (0)