Skip to content

Commit 58b44ff

Browse files
Merge pull request #20125 from nicolas-guichard/push-pypzwzspzznu
Use inferred type in “extract type as type alias” assist and display inferred type placeholder `_` inlay hints
2 parents cf4b1fa + 9500624 commit 58b44ff

File tree

9 files changed

+299
-10
lines changed

9 files changed

+299
-10
lines changed

crates/hir-def/src/hir/type_ref.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,16 @@ impl TypeRef {
195195
TypeRef::Tuple(ThinVec::new())
196196
}
197197

198-
pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(&TypeRef)) {
198+
pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(TypeRefId, &TypeRef)) {
199199
go(this, f, map);
200200

201-
fn go(type_ref: TypeRefId, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) {
202-
let type_ref = &map[type_ref];
203-
f(type_ref);
201+
fn go(
202+
type_ref_id: TypeRefId,
203+
f: &mut impl FnMut(TypeRefId, &TypeRef),
204+
map: &ExpressionStore,
205+
) {
206+
let type_ref = &map[type_ref_id];
207+
f(type_ref_id, type_ref);
204208
match type_ref {
205209
TypeRef::Fn(fn_) => {
206210
fn_.params.iter().for_each(|&(_, param_type)| go(param_type, f, map))
@@ -224,7 +228,7 @@ impl TypeRef {
224228
};
225229
}
226230

227-
fn go_path(path: &Path, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) {
231+
fn go_path(path: &Path, f: &mut impl FnMut(TypeRefId, &TypeRef), map: &ExpressionStore) {
228232
if let Some(type_ref) = path.type_anchor() {
229233
go(type_ref, f, map);
230234
}

crates/hir-ty/src/infer.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use hir_def::{
4141
layout::Integer,
4242
resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs},
4343
signatures::{ConstSignature, StaticSignature},
44-
type_ref::{ConstRef, LifetimeRefId, TypeRefId},
44+
type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId},
4545
};
4646
use hir_expand::{mod_path::ModPath, name::Name};
4747
use indexmap::IndexSet;
@@ -60,6 +60,7 @@ use triomphe::Arc;
6060

6161
use crate::{
6262
ImplTraitId, IncorrectGenericsLenKind, PathLoweringDiagnostic, TargetFeatures,
63+
collect_type_inference_vars,
6364
db::{HirDatabase, InternedClosureId, InternedOpaqueTyId},
6465
infer::{
6566
coerce::{CoerceMany, DynamicCoerceMany},
@@ -497,6 +498,7 @@ pub struct InferenceResult<'db> {
497498
/// unresolved or missing subpatterns or subpatterns of mismatched types.
498499
pub(crate) type_of_pat: ArenaMap<PatId, Ty<'db>>,
499500
pub(crate) type_of_binding: ArenaMap<BindingId, Ty<'db>>,
501+
pub(crate) type_of_type_placeholder: ArenaMap<TypeRefId, Ty<'db>>,
500502
pub(crate) type_of_opaque: FxHashMap<InternedOpaqueTyId, Ty<'db>>,
501503
pub(crate) type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch<'db>>,
502504
/// Whether there are any type-mismatching errors in the result.
@@ -542,6 +544,7 @@ impl<'db> InferenceResult<'db> {
542544
type_of_expr: Default::default(),
543545
type_of_pat: Default::default(),
544546
type_of_binding: Default::default(),
547+
type_of_type_placeholder: Default::default(),
545548
type_of_opaque: Default::default(),
546549
type_mismatches: Default::default(),
547550
has_errors: Default::default(),
@@ -606,6 +609,12 @@ impl<'db> InferenceResult<'db> {
606609
_ => None,
607610
})
608611
}
612+
pub fn placeholder_types(&self) -> impl Iterator<Item = (TypeRefId, &Ty<'db>)> {
613+
self.type_of_type_placeholder.iter()
614+
}
615+
pub fn type_of_type_placeholder(&self, type_ref: TypeRefId) -> Option<Ty<'db>> {
616+
self.type_of_type_placeholder.get(type_ref).copied()
617+
}
609618
pub fn closure_info(&self, closure: InternedClosureId) -> &(Vec<CapturedItem<'db>>, FnTrait) {
610619
self.closure_info.get(&closure).unwrap()
611620
}
@@ -1014,6 +1023,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
10141023
type_of_expr,
10151024
type_of_pat,
10161025
type_of_binding,
1026+
type_of_type_placeholder,
10171027
type_of_opaque,
10181028
type_mismatches,
10191029
has_errors,
@@ -1046,6 +1056,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
10461056
*has_errors = *has_errors || ty.references_non_lt_error();
10471057
}
10481058
type_of_binding.shrink_to_fit();
1059+
for ty in type_of_type_placeholder.values_mut() {
1060+
*ty = table.resolve_completely(*ty);
1061+
*has_errors = *has_errors || ty.references_non_lt_error();
1062+
}
1063+
type_of_type_placeholder.shrink_to_fit();
10491064
type_of_opaque.shrink_to_fit();
10501065

10511066
*has_errors |= !type_mismatches.is_empty();
@@ -1285,6 +1300,10 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
12851300
self.result.type_of_pat.insert(pat, ty);
12861301
}
12871302

1303+
fn write_type_placeholder_ty(&mut self, type_ref: TypeRefId, ty: Ty<'db>) {
1304+
self.result.type_of_type_placeholder.insert(type_ref, ty);
1305+
}
1306+
12881307
fn write_binding_ty(&mut self, id: BindingId, ty: Ty<'db>) {
12891308
self.result.type_of_binding.insert(id, ty);
12901309
}
@@ -1333,7 +1352,27 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
13331352
) -> Ty<'db> {
13341353
let ty = self
13351354
.with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref));
1336-
self.process_user_written_ty(ty)
1355+
let ty = self.process_user_written_ty(ty);
1356+
1357+
// Record the association from placeholders' TypeRefId to type variables.
1358+
// We only record them if their number matches. This assumes TypeRef::walk and TypeVisitable process the items in the same order.
1359+
let type_variables = collect_type_inference_vars(&ty);
1360+
let mut placeholder_ids = vec![];
1361+
TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| {
1362+
if matches!(type_ref, TypeRef::Placeholder) {
1363+
placeholder_ids.push(type_ref_id);
1364+
}
1365+
});
1366+
1367+
if placeholder_ids.len() == type_variables.len() {
1368+
for (placeholder_id, type_variable) in
1369+
placeholder_ids.into_iter().zip(type_variables.into_iter())
1370+
{
1371+
self.write_type_placeholder_ty(placeholder_id, type_variable);
1372+
}
1373+
}
1374+
1375+
ty
13371376
}
13381377

13391378
pub(crate) fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> {

crates/hir-ty/src/lib.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,35 @@ where
569569
Vec::from_iter(collector.params)
570570
}
571571

572+
struct TypeInferenceVarCollector<'db> {
573+
type_inference_vars: Vec<Ty<'db>>,
574+
}
575+
576+
impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for TypeInferenceVarCollector<'db> {
577+
type Result = ();
578+
579+
fn visit_ty(&mut self, ty: Ty<'db>) -> Self::Result {
580+
use crate::rustc_type_ir::Flags;
581+
if ty.is_ty_var() {
582+
self.type_inference_vars.push(ty);
583+
} else if ty.flags().intersects(rustc_type_ir::TypeFlags::HAS_TY_INFER) {
584+
ty.super_visit_with(self);
585+
} else {
586+
// Fast path: don't visit inner types (e.g. generic arguments) when `flags` indicate
587+
// that there are no placeholders.
588+
}
589+
}
590+
}
591+
592+
pub fn collect_type_inference_vars<'db, T>(value: &T) -> Vec<Ty<'db>>
593+
where
594+
T: ?Sized + rustc_type_ir::TypeVisitable<DbInterner<'db>>,
595+
{
596+
let mut collector = TypeInferenceVarCollector { type_inference_vars: vec![] };
597+
value.visit_with(&mut collector);
598+
collector.type_inference_vars
599+
}
600+
572601
pub fn known_const_to_ast<'db>(
573602
konst: Const<'db>,
574603
db: &'db dyn HirDatabase,

crates/hir-ty/src/tests.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use hir_def::{
2323
item_scope::ItemScope,
2424
nameres::DefMap,
2525
src::HasSource,
26+
type_ref::TypeRefId,
2627
};
2728
use hir_expand::{FileRange, InFile, db::ExpandDatabase};
2829
use itertools::Itertools;
@@ -219,6 +220,24 @@ fn check_impl(
219220
}
220221
}
221222
}
223+
224+
for (type_ref, ty) in inference_result.placeholder_types() {
225+
let node = match type_node(&body_source_map, type_ref, &db) {
226+
Some(value) => value,
227+
None => continue,
228+
};
229+
let range = node.as_ref().original_file_range_rooted(&db);
230+
if let Some(expected) = types.remove(&range) {
231+
let actual = salsa::attach(&db, || {
232+
if display_source {
233+
ty.display_source_code(&db, def.module(&db), true).unwrap()
234+
} else {
235+
ty.display_test(&db, display_target).to_string()
236+
}
237+
});
238+
assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range);
239+
}
240+
}
222241
}
223242

224243
let mut buf = String::new();
@@ -275,6 +294,20 @@ fn pat_node(
275294
})
276295
}
277296

297+
fn type_node(
298+
body_source_map: &BodySourceMap,
299+
type_ref: TypeRefId,
300+
db: &TestDB,
301+
) -> Option<InFile<SyntaxNode>> {
302+
Some(match body_source_map.type_syntax(type_ref) {
303+
Ok(sp) => {
304+
let root = db.parse_or_expand(sp.file_id);
305+
sp.map(|ptr| ptr.to_node(&root).syntax().clone())
306+
}
307+
Err(SyntheticSyntax) => return None,
308+
})
309+
}
310+
278311
fn infer(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> String {
279312
infer_with_mismatches(ra_fixture, false)
280313
}

crates/hir-ty/src/tests/display_source_code.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,22 @@ fn test() {
246246
"#,
247247
);
248248
}
249+
250+
#[test]
251+
fn type_placeholder_type() {
252+
check_types_source_code(
253+
r#"
254+
struct S<T>(T);
255+
fn test() {
256+
let f: S<_> = S(3);
257+
//^ i32
258+
let f: [_; _] = [4_u32, 5, 6];
259+
//^ u32
260+
let f: (_, _, _) = (1_u32, 1_i32, false);
261+
//^ u32
262+
//^ i32
263+
//^ bool
264+
}
265+
"#,
266+
);
267+
}

crates/hir/src/source_analyzer.rs

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use hir_def::{
2121
lang_item::LangItem,
2222
nameres::MacroSubNs,
2323
resolver::{HasResolver, Resolver, TypeNs, ValueNs, resolver_for_scope},
24-
type_ref::{Mutability, TypeRefId},
24+
type_ref::{Mutability, TypeRef, TypeRefId},
2525
};
2626
use hir_expand::{
2727
HirFileId, InFile,
@@ -267,8 +267,11 @@ impl<'db> SourceAnalyzer<'db> {
267267
db: &'db dyn HirDatabase,
268268
ty: &ast::Type,
269269
) -> Option<Type<'db>> {
270+
let interner = DbInterner::new_with(db, None, None);
271+
270272
let type_ref = self.type_id(ty)?;
271-
let ty = TyLoweringContext::new(
273+
274+
let mut ty = TyLoweringContext::new(
272275
db,
273276
&self.resolver,
274277
self.store()?,
@@ -279,6 +282,31 @@ impl<'db> SourceAnalyzer<'db> {
279282
LifetimeElisionKind::Infer,
280283
)
281284
.lower_ty(type_ref);
285+
286+
// Try and substitute unknown types using InferenceResult
287+
if let Some(infer) = self.infer()
288+
&& let Some(store) = self.store()
289+
{
290+
let mut inferred_types = vec![];
291+
TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| {
292+
if matches!(type_ref, TypeRef::Placeholder) {
293+
inferred_types.push(infer.type_of_type_placeholder(type_ref_id));
294+
}
295+
});
296+
let mut inferred_types = inferred_types.into_iter();
297+
298+
let substituted_ty = hir_ty::next_solver::fold::fold_tys(interner, ty, |ty| {
299+
if ty.is_ty_error() { inferred_types.next().flatten().unwrap_or(ty) } else { ty }
300+
});
301+
302+
// Only used the result if the placeholder and unknown type counts matched
303+
let success =
304+
inferred_types.next().is_none() && !substituted_ty.references_non_lt_error();
305+
if success {
306+
ty = substituted_ty;
307+
}
308+
}
309+
282310
Some(Type::new_with_resolver(db, &self.resolver, ty))
283311
}
284312

crates/ide-assists/src/handlers/extract_type_alias.rs

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use either::Either;
2+
use hir::HirDisplay;
23
use ide_db::syntax_helpers::node_ext::walk_ty;
34
use syntax::{
45
ast::{self, AstNode, HasGenericArgs, HasGenericParams, HasName, edit::IndentLevel, make},
@@ -39,6 +40,15 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) ->
3940
);
4041
let target = ty.syntax().text_range();
4142

43+
let resolved_ty = ctx.sema.resolve_type(&ty)?;
44+
let resolved_ty = if !resolved_ty.contains_unknown() {
45+
let module = ctx.sema.scope(ty.syntax())?.module();
46+
let resolved_ty = resolved_ty.display_source_code(ctx.db(), module.into(), false).ok()?;
47+
make::ty(&resolved_ty)
48+
} else {
49+
ty.clone()
50+
};
51+
4252
acc.add(
4353
AssistId::refactor_extract("extract_type_alias"),
4454
"Extract type as type alias",
@@ -72,7 +82,7 @@ pub(crate) fn extract_type_alias(acc: &mut Assists, ctx: &AssistContext<'_>) ->
7282

7383
// Insert new alias
7484
let ty_alias =
75-
make::ty_alias(None, "Type", generic_params, None, None, Some((ty, None)))
85+
make::ty_alias(None, "Type", generic_params, None, None, Some((resolved_ty, None)))
7686
.clone_for_update();
7787

7888
if let Some(cap) = ctx.config.snippet_cap
@@ -391,4 +401,50 @@ where
391401
"#,
392402
);
393403
}
404+
405+
#[test]
406+
fn inferred_generic_type_parameter() {
407+
check_assist(
408+
extract_type_alias,
409+
r#"
410+
struct Wrap<T>(T);
411+
412+
fn main() {
413+
let wrap: $0Wrap<_>$0 = Wrap::<_>(3i32);
414+
}
415+
"#,
416+
r#"
417+
struct Wrap<T>(T);
418+
419+
type $0Type = Wrap<i32>;
420+
421+
fn main() {
422+
let wrap: Type = Wrap::<_>(3i32);
423+
}
424+
"#,
425+
)
426+
}
427+
428+
#[test]
429+
fn inferred_type() {
430+
check_assist(
431+
extract_type_alias,
432+
r#"
433+
struct Wrap<T>(T);
434+
435+
fn main() {
436+
let wrap: Wrap<$0_$0> = Wrap::<_>(3i32);
437+
}
438+
"#,
439+
r#"
440+
struct Wrap<T>(T);
441+
442+
type $0Type = i32;
443+
444+
fn main() {
445+
let wrap: Wrap<Type> = Wrap::<_>(3i32);
446+
}
447+
"#,
448+
)
449+
}
394450
}

0 commit comments

Comments
 (0)