From 3fa4ea4da3bb616767f9e2c049253d6b505e8541 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Mon, 19 May 2025 21:31:18 +0200 Subject: [PATCH 1/2] Rust: Improve performance of type inference --- .../codeql/rust/internal/TypeInference.qll | 41 +++++++--- .../typeinference/internal/TypeInference.qll | 82 +++++++++++++------ 2 files changed, 88 insertions(+), 35 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index 278d9ebc3176..5ce64b52d681 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -213,13 +213,6 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath path1 = path2 ) or - n2 = - any(PrefixExpr pe | - pe.getOperatorName() = "*" and - pe.getExpr() = n1 and - path1 = TypePath::cons(TRefTypeParameter(), path2) - ) - or n1 = n2.(ParenExpr).getExpr() and path1 = path2 or @@ -239,12 +232,36 @@ private predicate typeEquality(AstNode n1, TypePath path1, AstNode n2, TypePath ) } +bindingset[path1] +private predicate typeEqualityLeft(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + typeEquality(n1, path1, n2, path2) + or + n2 = + any(PrefixExpr pe | + pe.getOperatorName() = "*" and + pe.getExpr() = n1 and + path1 = TypePath::consInverse(TRefTypeParameter(), path2) + ) +} + +bindingset[path2] +private predicate typeEqualityRight(AstNode n1, TypePath path1, AstNode n2, TypePath path2) { + typeEquality(n1, path1, n2, path2) + or + n2 = + any(PrefixExpr pe | + pe.getOperatorName() = "*" and + pe.getExpr() = n1 and + path1 = TypePath::cons(TRefTypeParameter(), path2) + ) +} + pragma[nomagic] private Type inferTypeEquality(AstNode n, TypePath path) { exists(AstNode n2, TypePath path2 | result = inferType(n2, path2) | - typeEquality(n, path, n2, path2) + typeEqualityRight(n, path, n2, path2) or - typeEquality(n2, path2, n, path) + typeEqualityLeft(n2, path2, n, path) ) } @@ -909,7 +926,7 @@ private Type inferRefExprType(Expr e, TypePath path) { e = re.getExpr() and exists(TypePath exprPath, TypePath refPath, Type exprType | result = inferType(re, exprPath) and - exprPath = TypePath::cons(TRefTypeParameter(), refPath) and + exprPath = TypePath::consInverse(TRefTypeParameter(), refPath) and exprType = inferType(e) | if exprType = TRefType() @@ -924,7 +941,7 @@ private Type inferRefExprType(Expr e, TypePath path) { pragma[nomagic] private Type inferTryExprType(TryExpr te, TypePath path) { exists(TypeParam tp | - result = inferType(te.getExpr(), TypePath::cons(TTypeParamTypeParameter(tp), path)) + result = inferType(te.getExpr(), TypePath::consInverse(TTypeParamTypeParameter(tp), path)) | tp = any(ResultEnum r).getGenericParamList().getGenericParam(0) or @@ -1000,7 +1017,7 @@ private module Cached { pragma[nomagic] Type getTypeAt(TypePath path) { exists(TypePath path0 | result = inferType(this, path0) | - path0 = TypePath::cons(TRefTypeParameter(), path) + path0 = TypePath::consInverse(TRefTypeParameter(), path) or not path0.isCons(TRefTypeParameter(), _) and not (path0.isEmpty() and result = TRefType()) and diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index 1bce43c436bd..fa475be575f7 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -181,18 +181,29 @@ module Make1 Input1> { /** Holds if this type path is empty. */ predicate isEmpty() { this = "" } + /** Gets the length of this path, assuming the length is at least 2. */ + bindingset[this] + pragma[inline_late] + private int length2() { + // Same as + // `result = strictcount(this.indexOf(".")) + 1` + // but performs better because it doesn't use an aggregate + result = this.regexpReplaceAll("[0-9]+", "").length() + 1 + } + /** Gets the length of this path. */ bindingset[this] pragma[inline_late] int length() { - this.isEmpty() and result = 0 - or - result = strictcount(this.indexOf(".")) + 1 + if this.isEmpty() + then result = 0 + else + if exists(TypeParameter::decode(this)) + then result = 1 + else result = this.length2() } /** Gets the path obtained by appending `suffix` onto this path. */ - bindingset[suffix, result] - bindingset[this, result] bindingset[this, suffix] TypePath append(TypePath suffix) { if this.isEmpty() @@ -202,22 +213,37 @@ module Make1 Input1> { then result = this else ( result = this + "." + suffix and - not result.length() > getTypePathLimit() + ( + not exists(getTypePathLimit()) + or + result.length2() <= getTypePathLimit() + ) + ) + } + + /** + * Gets the path obtained by appending `suffix` onto this path. + * + * Unlike `append`, this predicate has `result` in the binding set, + * so there is no need to check the length of `result`. + */ + bindingset[this, result] + TypePath appendInverse(TypePath suffix) { + if result.isEmpty() + then this.isEmpty() and suffix.isEmpty() + else + if this.isEmpty() + then suffix = result + else ( + result = this and suffix.isEmpty() + or + result = this + "." + suffix ) } /** Holds if this path starts with `tp`, followed by `suffix`. */ bindingset[this] - predicate isCons(TypeParameter tp, TypePath suffix) { - tp = TypeParameter::decode(this) and - suffix.isEmpty() - or - exists(int first | - first = min(this.indexOf(".")) and - suffix = this.suffix(first + 1) and - tp = TypeParameter::decode(this.prefix(first)) - ) - } + predicate isCons(TypeParameter tp, TypePath suffix) { this = TypePath::consInverse(tp, suffix) } } /** Provides predicates for constructing `TypePath`s. */ @@ -232,9 +258,17 @@ module Make1 Input1> { * Gets the type path obtained by appending the singleton type path `tp` * onto `suffix`. */ - bindingset[result] bindingset[suffix] TypePath cons(TypeParameter tp, TypePath suffix) { result = singleton(tp).append(suffix) } + + /** + * Gets the type path obtained by appending the singleton type path `tp` + * onto `suffix`. + */ + bindingset[result] + TypePath consInverse(TypeParameter tp, TypePath suffix) { + result = singleton(tp).appendInverse(suffix) + } } /** @@ -556,7 +590,7 @@ module Make1 Input1> { TypeMention tm1, TypeMention tm2, TypeParameter tp, TypePath path, Type t ) { exists(TypePath prefix | - tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.append(path)) + tm2.resolveTypeAt(prefix) = tp and t = tm1.resolveTypeAt(prefix.appendInverse(path)) ) } @@ -899,7 +933,7 @@ module Make1 Input1> { exists(AccessPosition apos, DeclarationPosition dpos, TypePath pathToTypeParam | tp = target.getDeclaredType(dpos, pathToTypeParam) and accessDeclarationPositionMatch(apos, dpos) and - adjustedAccessType(a, apos, target, pathToTypeParam.append(path), t) + adjustedAccessType(a, apos, target, pathToTypeParam.appendInverse(path), t) ) } @@ -998,7 +1032,9 @@ module Make1 Input1> { RelevantAccess() { this = MkRelevantAccess(a, apos, path) } - Type getTypeAt(TypePath suffix) { a.getInferredType(apos, path.append(suffix)) = result } + Type getTypeAt(TypePath suffix) { + a.getInferredType(apos, path.appendInverse(suffix)) = result + } /** Holds if this relevant access has the type `type` and should satisfy `constraint`. */ predicate hasTypeConstraint(Type type, Type constraint) { @@ -1077,7 +1113,7 @@ module Make1 Input1> { t0 = abs.getATypeParameter() and exists(TypePath path3, TypePath suffix | sub.resolveTypeAt(path3) = t0 and - at.getTypeAt(path3.append(suffix)) = t and + at.getTypeAt(path3.appendInverse(suffix)) = t and path = prefix0.append(suffix) ) ) @@ -1149,7 +1185,7 @@ module Make1 Input1> { not exists(getTypeArgument(a, target, tp, _)) and target = a.getTarget() and exists(AccessPosition apos, DeclarationPosition dpos, Type base, TypePath pathToTypeParam | - accessBaseType(a, apos, base, pathToTypeParam.append(path), t) and + accessBaseType(a, apos, base, pathToTypeParam.appendInverse(path), t) and declarationBaseType(target, dpos, base, pathToTypeParam, tp) and accessDeclarationPositionMatch(apos, dpos) ) @@ -1217,7 +1253,7 @@ module Make1 Input1> { typeParameterConstraintHasTypeParameter(target, dpos, pathToTp2, _, constraint, pathToTp, tp) and AccessConstraint::satisfiesConstraintTypeMention(a, apos, pathToTp2, constraint, - pathToTp.append(path), t) + pathToTp.appendInverse(path), t) ) } From 13861b81a850a4c32f3377d148ed3f133df0e799 Mon Sep 17 00:00:00 2001 From: Tom Hvitved Date: Wed, 21 May 2025 14:01:48 +0200 Subject: [PATCH 2/2] Address review comments --- .../codeql/rust/internal/TypeInference.qll | 11 ++--- .../typeinference/internal/TypeInference.qll | 43 ++++++++----------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/rust/ql/lib/codeql/rust/internal/TypeInference.qll b/rust/ql/lib/codeql/rust/internal/TypeInference.qll index 5ce64b52d681..9bbd540f9a9d 100644 --- a/rust/ql/lib/codeql/rust/internal/TypeInference.qll +++ b/rust/ql/lib/codeql/rust/internal/TypeInference.qll @@ -240,7 +240,7 @@ private predicate typeEqualityLeft(AstNode n1, TypePath path1, AstNode n2, TypeP any(PrefixExpr pe | pe.getOperatorName() = "*" and pe.getExpr() = n1 and - path1 = TypePath::consInverse(TRefTypeParameter(), path2) + path1.isCons(TRefTypeParameter(), path2) ) } @@ -926,7 +926,7 @@ private Type inferRefExprType(Expr e, TypePath path) { e = re.getExpr() and exists(TypePath exprPath, TypePath refPath, Type exprType | result = inferType(re, exprPath) and - exprPath = TypePath::consInverse(TRefTypeParameter(), refPath) and + exprPath.isCons(TRefTypeParameter(), refPath) and exprType = inferType(e) | if exprType = TRefType() @@ -940,8 +940,9 @@ private Type inferRefExprType(Expr e, TypePath path) { pragma[nomagic] private Type inferTryExprType(TryExpr te, TypePath path) { - exists(TypeParam tp | - result = inferType(te.getExpr(), TypePath::consInverse(TTypeParamTypeParameter(tp), path)) + exists(TypeParam tp, TypePath path0 | + result = inferType(te.getExpr(), path0) and + path0.isCons(TTypeParamTypeParameter(tp), path) | tp = any(ResultEnum r).getGenericParamList().getGenericParam(0) or @@ -1017,7 +1018,7 @@ private module Cached { pragma[nomagic] Type getTypeAt(TypePath path) { exists(TypePath path0 | result = inferType(this, path0) | - path0 = TypePath::consInverse(TRefTypeParameter(), path) + path0.isCons(TRefTypeParameter(), path) or not path0.isCons(TRefTypeParameter(), _) and not (path0.isEmpty() and result = TRefType()) and diff --git a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll index fa475be575f7..4414bc74c0bd 100644 --- a/shared/typeinference/codeql/typeinference/internal/TypeInference.qll +++ b/shared/typeinference/codeql/typeinference/internal/TypeInference.qll @@ -184,7 +184,7 @@ module Make1 Input1> { /** Gets the length of this path, assuming the length is at least 2. */ bindingset[this] pragma[inline_late] - private int length2() { + private int lengthAtLeast2() { // Same as // `result = strictcount(this.indexOf(".")) + 1` // but performs better because it doesn't use an aggregate @@ -200,7 +200,7 @@ module Make1 Input1> { else if exists(TypeParameter::decode(this)) then result = 1 - else result = this.length2() + else result = this.lengthAtLeast2() } /** Gets the path obtained by appending `suffix` onto this path. */ @@ -216,7 +216,7 @@ module Make1 Input1> { ( not exists(getTypePathLimit()) or - result.length2() <= getTypePathLimit() + result.lengthAtLeast2() <= getTypePathLimit() ) ) } @@ -228,22 +228,26 @@ module Make1 Input1> { * so there is no need to check the length of `result`. */ bindingset[this, result] - TypePath appendInverse(TypePath suffix) { - if result.isEmpty() - then this.isEmpty() and suffix.isEmpty() - else - if this.isEmpty() - then suffix = result - else ( - result = this and suffix.isEmpty() - or - result = this + "." + suffix - ) + TypePath appendInverse(TypePath suffix) { suffix = result.stripPrefix(this) } + + /** Gets the path obtained by removing `prefix` from this path. */ + bindingset[this, prefix] + TypePath stripPrefix(TypePath prefix) { + if prefix.isEmpty() + then result = this + else ( + this = prefix and + result.isEmpty() + or + this = prefix + "." + result + ) } /** Holds if this path starts with `tp`, followed by `suffix`. */ bindingset[this] - predicate isCons(TypeParameter tp, TypePath suffix) { this = TypePath::consInverse(tp, suffix) } + predicate isCons(TypeParameter tp, TypePath suffix) { + suffix = this.stripPrefix(TypePath::singleton(tp)) + } } /** Provides predicates for constructing `TypePath`s. */ @@ -260,15 +264,6 @@ module Make1 Input1> { */ bindingset[suffix] TypePath cons(TypeParameter tp, TypePath suffix) { result = singleton(tp).append(suffix) } - - /** - * Gets the type path obtained by appending the singleton type path `tp` - * onto `suffix`. - */ - bindingset[result] - TypePath consInverse(TypeParameter tp, TypePath suffix) { - result = singleton(tp).appendInverse(suffix) - } } /**