Skip to content

Commit 15c4c30

Browse files
committed
wip7
1 parent 3696436 commit 15c4c30

3 files changed

Lines changed: 191 additions & 129 deletions

File tree

rust/ql/lib/codeql/rust/internal/typeinference/TypeInference.qll

Lines changed: 42 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,8 @@ private module Input3 implements InputSig3 {
286286
(exists(resolveTupleFieldExpr(_, _)) implies any())
287287
}
288288

289+
predicate inferType = M3::inferType/2;
290+
289291
class BoolType extends DataType {
290292
BoolType() { this.getTypeItem() instanceof Builtins::Bool }
291293
}
@@ -366,41 +368,54 @@ private module Input3 implements InputSig3 {
366368
override AstNode getRightOperand() { result = this.getInitializer() }
367369
}
368370

369-
class CallTarget extends FunctionCallMatchingInput::Declaration {
371+
class CallResolutionContext = FunctionCallMatchingInput::AccessEnvironment;
372+
373+
class TypePosition = FunctionPosition;
374+
375+
class Callable extends FunctionCallMatchingInput::Declaration {
370376
TypeMention getAdditionalTypeParameterConstraint(TypeParameter tp) {
371377
result =
372378
tp.(TypeParamTypeParameter)
373379
.getTypeParam()
374380
.getAdditionalTypeBound(this.getFunction(), _)
375381
.getTypeRepr()
376382
}
377-
378-
Type getReturnType(TypePath path) {
379-
exists(FunctionPosition pos |
380-
pos.isReturn() and
381-
result = super.getDeclaredType(pos, path)
382-
)
383-
}
384-
385-
Type getParameterType(int index, TypePath path) {
386-
none() // todo
387-
}
388383
}
389384

390385
class Call extends Expr instanceof FunctionCallMatchingInput::Access {
391386
Type getTypeArgument(TypeArgumentPosition apos, TypePath path) {
392387
result = super.getTypeArgument(apos, path)
393388
}
394389

390+
AstNode getNodeAt(TypePosition pos) { result = super.getNodeAt(pos) }
391+
395392
/** Gets the target of this call. */
396-
CallTarget getTargetCertain() {
393+
Callable getTargetCertain() {
397394
exists(ImplOrTraitItemNodeOption i, FunctionDeclaration f, Path p |
398395
result.isFunction(i, f) and
399396
p = CallExprImpl::getFunctionPath(this) and
400397
f = resolvePath(p) and
401398
f.isDirectlyFor(i)
402399
)
403400
}
401+
402+
Callable getTarget(string derefChainBorrow) { result = super.getTarget(derefChainBorrow) }
403+
}
404+
405+
bindingset[derefChainBorrow]
406+
Type inferCallTypeIn(Call call, string derefChainBorrow, FunctionPosition pos, TypePath path) {
407+
result = call.(FunctionCallMatchingInput::Access).getInferredType(derefChainBorrow, pos, path)
408+
}
409+
410+
Type inferCallTypeOut(AstNode n, TypePosition pos, TypePath path) {
411+
result = inferFunctionCallTypeNonSelf(n, pos, path)
412+
or
413+
exists(FunctionCallMatchingInput::Access a |
414+
result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and
415+
if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver()
416+
then not path.isEmpty()
417+
else any()
418+
)
404419
}
405420

406421
predicate certainTypeEqualityInput(AstNode n1, TypePath prefix1, AstNode n2, TypePath prefix2) {
@@ -553,8 +568,6 @@ private module Input3 implements InputSig3 {
553568
Type inferTypeInput(AstNode n, TypePath path) {
554569
result = inferAssignmentOperationType(n, path)
555570
or
556-
result = inferFunctionCallType(n, path)
557-
or
558571
result = inferConstructionType(n, path)
559572
or
560573
result = inferOperationType(n, path)
@@ -1094,53 +1107,6 @@ private module ContextTyping {
10941107
)
10951108
}
10961109
}
1097-
1098-
pragma[nomagic]
1099-
private predicate hasUnknownTypeAt(AstNode n, TypePath path) {
1100-
inferType(n, path) = TUnknownType()
1101-
}
1102-
1103-
pragma[nomagic]
1104-
private predicate hasUnknownType(AstNode n) { hasUnknownTypeAt(n, _) }
1105-
1106-
newtype FunctionPositionKind =
1107-
SelfKind() or
1108-
ReturnKind() or
1109-
PositionalKind()
1110-
1111-
signature Type inferCallTypeSig(AstNode n, FunctionPositionKind kind, TypePath path);
1112-
1113-
/**
1114-
* Given a predicate `inferCallType` for inferring the type of a call at a given
1115-
* position, this module exposes the predicate `check`, which wraps the input
1116-
* predicate and checks that types are only propagated into arguments when they
1117-
* are context-typed.
1118-
*/
1119-
module CheckContextTyping<inferCallTypeSig/3 inferCallType> {
1120-
pragma[nomagic]
1121-
private Type inferCallNonReturnType(
1122-
AstNode n, FunctionPositionKind kind, TypePath prefix, TypePath path
1123-
) {
1124-
result = inferCallType(n, kind, path) and
1125-
hasUnknownType(n) and
1126-
kind != ReturnKind() and
1127-
prefix = path.getAPrefix()
1128-
}
1129-
1130-
pragma[nomagic]
1131-
Type check(AstNode n, TypePath path) {
1132-
result = inferCallType(n, ReturnKind(), path)
1133-
or
1134-
exists(FunctionPositionKind kind, TypePath prefix |
1135-
result = inferCallNonReturnType(n, kind, prefix, path) and
1136-
hasUnknownTypeAt(n, prefix)
1137-
|
1138-
// Never propagate type information directly into the receiver, since its type
1139-
// must already have been known in order to resolve the call
1140-
if kind = SelfKind() then not prefix.isEmpty() else any()
1141-
)
1142-
}
1143-
}
11441110
}
11451111

11461112
/**
@@ -2836,22 +2802,20 @@ private module FunctionCallMatchingInput implements MatchingWithEnvironmentInput
28362802
}
28372803
}
28382804

2839-
private module FunctionCallMatching = MatchingWithEnvironment<FunctionCallMatchingInput>;
2840-
28412805
pragma[nomagic]
28422806
private Type inferFunctionCallType0(
28432807
FunctionCallMatchingInput::Access call, FunctionPosition pos, AstNode n, DerefChain derefChain,
28442808
BorrowKind borrow, TypePath path
28452809
) {
28462810
exists(TypePath path0 |
2847-
n = call.getNodeAt(pos) and
28482811
exists(string derefChainBorrow |
28492812
FunctionCallMatchingInput::decodeDerefChainBorrow(derefChainBorrow, derefChain, borrow)
28502813
|
2851-
result = FunctionCallMatching::inferAccessType(call, derefChainBorrow, pos, path0)
2852-
or
2814+
n = call.getNodeAt(pos) and
28532815
call.hasUnknownTypeAt(derefChainBorrow, pos, path0) and
28542816
result = TUnknownType()
2817+
or
2818+
result = inferCallTypeOut(call, pos, n, derefChainBorrow, path0)
28552819
)
28562820
|
28572821
if
@@ -2919,31 +2883,6 @@ private Type inferFunctionCallTypeSelf(
29192883
)
29202884
}
29212885

2922-
private Type inferFunctionCallTypePreCheck(
2923-
AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path
2924-
) {
2925-
exists(FunctionPosition pos |
2926-
result = inferFunctionCallTypeNonSelf(n, pos, path) and
2927-
if pos.isPosition()
2928-
then kind = ContextTyping::PositionalKind()
2929-
else kind = ContextTyping::ReturnKind()
2930-
)
2931-
or
2932-
exists(FunctionCallMatchingInput::Access a |
2933-
result = inferFunctionCallTypeSelf(a, n, DerefChain::nil(), path) and
2934-
if a.(AssocFunctionResolution::AssocFunctionCall).hasReceiver()
2935-
then kind = ContextTyping::SelfKind()
2936-
else kind = ContextTyping::PositionalKind()
2937-
)
2938-
}
2939-
2940-
/**
2941-
* Gets the type of `n` at `path`, where `n` is either a function call or an
2942-
* argument/receiver of a function call.
2943-
*/
2944-
private predicate inferFunctionCallType =
2945-
ContextTyping::CheckContextTyping<inferFunctionCallTypePreCheck/3>::check/2;
2946-
29472886
abstract private class Constructor extends Addressable {
29482887
final TypeParameter getTypeParameter(TypeParameterPosition ppos) {
29492888
typeParamMatchPosition(this.getTypeItem().getGenericParamList().getATypeParam(), result, ppos)
@@ -3102,15 +3041,8 @@ private module ConstructionMatchingInput implements MatchingInputSig {
31023041
private module ConstructionMatching = Matching<ConstructionMatchingInput>;
31033042

31043043
pragma[nomagic]
3105-
private Type inferConstructionTypePreCheck(
3106-
AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path
3107-
) {
3108-
exists(ConstructionMatchingInput::Access a, FunctionPosition pos |
3109-
n = a.getNodeAt(pos) and
3110-
if pos.isPosition()
3111-
then kind = ContextTyping::PositionalKind()
3112-
else kind = ContextTyping::ReturnKind()
3113-
|
3044+
private Type inferConstructionTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
3045+
exists(ConstructionMatchingInput::Access a | n = a.getNodeAt(pos) |
31143046
result = ConstructionMatching::inferAccessType(a, pos, path)
31153047
or
31163048
a.hasUnknownTypeAt(pos, path) and
@@ -3119,7 +3051,7 @@ private Type inferConstructionTypePreCheck(
31193051
}
31203052

31213053
private predicate inferConstructionType =
3122-
ContextTyping::CheckContextTyping<inferConstructionTypePreCheck/3>::check/2;
3054+
CheckContextTyping<inferConstructionTypePreCheck/3>::check/2;
31233055

31243056
/**
31253057
* A matching configuration for resolving types of operations like `a + b`.
@@ -3184,23 +3116,15 @@ private module OperationMatchingInput implements MatchingInputSig {
31843116
private module OperationMatching = Matching<OperationMatchingInput>;
31853117

31863118
pragma[nomagic]
3187-
private Type inferOperationTypePreCheck(
3188-
AstNode n, ContextTyping::FunctionPositionKind kind, TypePath path
3189-
) {
3190-
exists(OperationMatchingInput::Access a, FunctionPosition pos |
3119+
private Type inferOperationTypePreCheck(AstNode n, FunctionPosition pos, TypePath path) {
3120+
exists(OperationMatchingInput::Access a |
31913121
n = a.getNodeAt(pos) and
31923122
result = OperationMatching::inferAccessType(a, pos, path) and
3193-
if pos.asPosition() = 0
3194-
then kind = ContextTyping::SelfKind()
3195-
else
3196-
if pos.isPosition()
3197-
then kind = ContextTyping::PositionalKind()
3198-
else kind = ContextTyping::ReturnKind()
3123+
if pos.asPosition() = 0 then not path.isEmpty() else any()
31993124
)
32003125
}
32013126

3202-
private predicate inferOperationType =
3203-
ContextTyping::CheckContextTyping<inferOperationTypePreCheck/3>::check/2;
3127+
private predicate inferOperationType = CheckContextTyping<inferOperationTypePreCheck/3>::check/2;
32043128

32053129
pragma[nomagic]
32063130
private Type getFieldExprLookupType(FieldExpr fe, string name, DerefChain derefChain) {
@@ -3815,11 +3739,10 @@ private module Debug {
38153739
t = self.getTypeAt(path)
38163740
}
38173741

3818-
predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) {
3819-
n = getRelevantLocatable() and
3820-
t = inferFunctionCallType(n, path)
3821-
}
3822-
3742+
// predicate debugInferFunctionCallType(AstNode n, TypePath path, Type t) {
3743+
// n = getRelevantLocatable() and
3744+
// t = inferFunctionCallType(n, path)
3745+
// }
38233746
predicate debugInferConstructionType(AstNode n, TypePath path, Type t) {
38243747
n = getRelevantLocatable() and
38253748
t = inferConstructionType(n, path)

rust/ql/test/library-tests/type-inference/type-inference.expected

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10252,6 +10252,7 @@ inferType
1025210252
| main.rs:1412:17:1412:20 | self | TRef.TSlice | main.rs:1410:14:1410:23 | T |
1025310253
| main.rs:1412:17:1412:27 | self.get(...) | | {EXTERNAL LOCATION} | Option |
1025410254
| main.rs:1412:17:1412:27 | self.get(...) | T | {EXTERNAL LOCATION} | & |
10255+
| main.rs:1412:17:1412:27 | self.get(...) | T.TRef | main.rs:1410:14:1410:23 | T |
1025510256
| main.rs:1412:17:1412:36 | ... .unwrap() | | {EXTERNAL LOCATION} | & |
1025610257
| main.rs:1412:17:1412:36 | ... .unwrap() | TRef | main.rs:1410:14:1410:23 | T |
1025710258
| main.rs:1412:26:1412:26 | 0 | | {EXTERNAL LOCATION} | i32 |
@@ -11600,6 +11601,8 @@ inferType
1160011601
| main.rs:2221:18:2221:21 | true | | {EXTERNAL LOCATION} | bool |
1160111602
| main.rs:2223:9:2223:15 | S(...) | | main.rs:2107:5:2107:19 | S |
1160211603
| main.rs:2223:9:2223:15 | S(...) | T | {EXTERNAL LOCATION} | i64 |
11604+
| main.rs:2223:9:2223:15 | S(...) | T | main.rs:2107:5:2107:19 | S |
11605+
| main.rs:2223:9:2223:15 | S(...) | T.T | {EXTERNAL LOCATION} | i64 |
1160311606
| main.rs:2223:9:2223:31 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S |
1160411607
| main.rs:2223:9:2223:31 | ... .my_add(...) | T | {EXTERNAL LOCATION} | i64 |
1160511608
| main.rs:2223:9:2223:31 | ... .my_add(...) | T | main.rs:2107:5:2107:19 | S |
@@ -11618,6 +11621,8 @@ inferType
1161811621
| main.rs:2224:24:2224:27 | 3i64 | | {EXTERNAL LOCATION} | i64 |
1161911622
| main.rs:2225:9:2225:15 | S(...) | | main.rs:2107:5:2107:19 | S |
1162011623
| main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | i64 |
11624+
| main.rs:2225:9:2225:15 | S(...) | T | {EXTERNAL LOCATION} | & |
11625+
| main.rs:2225:9:2225:15 | S(...) | T.TRef | {EXTERNAL LOCATION} | i64 |
1162111626
| main.rs:2225:9:2225:29 | ... .my_add(...) | | main.rs:2107:5:2107:19 | S |
1162211627
| main.rs:2225:9:2225:29 | ... .my_add(...) | T | {EXTERNAL LOCATION} | i64 |
1162311628
| main.rs:2225:11:2225:14 | 1i64 | | {EXTERNAL LOCATION} | i64 |

0 commit comments

Comments
 (0)