Skip to content

Rust: Type inference uses defaults for type parameters #19756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ abstract class Type extends TType {
/** Gets the `i`th type parameter of this type, if any. */
abstract TypeParameter getTypeParameter(int i);

/** Gets the default type for the `i`th type parameter, if any. */
TypeMention getTypeParameterDefault(int i) { none() }

/** Gets a type parameter of this type. */
final TypeParameter getATypeParameter() { result = this.getTypeParameter(_) }

Expand Down Expand Up @@ -87,6 +90,10 @@ class StructType extends StructOrEnumType, TStruct {
result = TTypeParamTypeParameter(struct.getGenericParamList().getTypeParam(i))
}

override TypeMention getTypeParameterDefault(int i) {
result = struct.getGenericParamList().getTypeParam(i).getDefaultType()
}

override string toString() { result = struct.getName().getText() }

override Location getLocation() { result = struct.getLocation() }
Expand All @@ -108,6 +115,10 @@ class EnumType extends StructOrEnumType, TEnum {
result = TTypeParamTypeParameter(enum.getGenericParamList().getTypeParam(i))
}

override TypeMention getTypeParameterDefault(int i) {
result = enum.getGenericParamList().getTypeParam(i).getDefaultType()
}

override string toString() { result = enum.getName().getText() }

override Location getLocation() { result = enum.getLocation() }
Expand All @@ -133,6 +144,10 @@ class TraitType extends Type, TTrait {
any(AssociatedTypeTypeParameter param | param.getTrait() = trait and param.getIndex() = i)
}

override TypeMention getTypeParameterDefault(int i) {
result = trait.getGenericParamList().getTypeParam(i).getDefaultType()
}

override string toString() { result = trait.toString() }

override Location getLocation() { result = trait.getLocation() }
Expand Down
5 changes: 5 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -1468,6 +1468,11 @@ private module Debug {
result = resolveMethodCallTarget(mce)
}

predicate debugTypeMention(TypeMention tm, TypePath path, Type type) {
tm = getRelevantLocatable() and
tm.resolveTypeAt(path) = type
}

pragma[nomagic]
private int countTypes(AstNode n, TypePath path, Type t) {
t = inferType(n, path) and
Expand Down
5 changes: 5 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ class PathTypeReprMention extends TypeMention instanceof PathTypeRepr {
override TypeMention getTypeArgument(int i) {
result = path.getSegment().getGenericArgList().getTypeArg(i)
or
// If a type argument is not given in the path, then we use the default for
// the type parameter if one exists for the type.
not exists(path.getSegment().getGenericArgList().getTypeArg(i)) and
result = this.resolveType().getTypeParameterDefault(i)
or
// `Self` paths inside `impl` blocks have implicit type arguments that are
// the type parameters of the `impl` block. For example, in
//
Expand Down
27 changes: 23 additions & 4 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mod field_access {
}

#[derive(Debug)]
struct GenericThing<A> {
struct GenericThing<A = bool> {
a: A,
}

Expand All @@ -27,6 +27,11 @@ mod field_access {
println!("{:?}", x.a); // $ fieldof=MyThing
}

fn default_field_access(x: GenericThing) {
let a = x.a; // $ fieldof=GenericThing type=a:bool
println!("{:?}", a);
}

fn generic_field_access() {
// Explicit type argument
let x = GenericThing::<S> { a: S }; // $ type=x:A.S
Expand Down Expand Up @@ -472,16 +477,16 @@ mod type_parameter_bounds {
println!("{:?}", s); // $ type=s:S1
}

trait Pair<P1, P2> {
trait Pair<P1 = bool, P2 = i64> {
fn fst(self) -> P1;

fn snd(self) -> P2;
}

fn call_trait_per_bound_with_type_1<T: Pair<S1, S2>>(x: T, y: T) {
// The type in the type parameter bound determines the return type.
let s1 = x.fst(); // $ method=fst
let s2 = y.snd(); // $ method=snd
let s1 = x.fst(); // $ method=fst type=s1:S1
let s2 = y.snd(); // $ method=snd type=s2:S2
println!("{:?}, {:?}", s1, s2);
}

Expand All @@ -491,6 +496,20 @@ mod type_parameter_bounds {
let s2 = y.snd(); // $ method=snd
println!("{:?}, {:?}", s1, s2);
}

fn call_trait_per_bound_with_type_3<T: Pair>(x: T, y: T) {
// The type in the type parameter bound determines the return type.
let s1 = x.fst(); // $ method=fst type=s1:bool
let s2 = y.snd(); // $ method=snd type=s2:i64
println!("{:?}, {:?}", s1, s2);
}

fn call_trait_per_bound_with_type_4<T: Pair<u8>>(x: T, y: T) {
// The type in the type parameter bound determines the return type.
let s1 = x.fst(); // $ method=fst type=s1:u8
let s2 = y.snd(); // $ method=snd type=s2:i64
println!("{:?}, {:?}", s1, s2);
}
}

mod function_trait_bounds {
Expand Down
Loading