Skip to content

Commit

Permalink
Add a test for field default value body as defining usage of TAIT
Browse files Browse the repository at this point in the history
  • Loading branch information
ShoyuVanilla committed Jan 27, 2025
1 parent c7463fe commit ec89b7d
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 27 deletions.
6 changes: 1 addition & 5 deletions src/tools/rust-analyzer/crates/hir-def/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1399,11 +1399,7 @@ impl HasModule for DefWithBodyId {
DefWithBodyId::ConstId(it) => it.module(db),
DefWithBodyId::VariantId(it) => it.module(db),
DefWithBodyId::InTypeConstId(it) => it.lookup(db).owner.module(db),
DefWithBodyId::FieldId(it) => match it.parent {
VariantId::EnumVariantId(it) => it.module(db),
VariantId::StructId(it) => it.module(db),
VariantId::UnionId(it) => it.module(db),
},
DefWithBodyId::FieldId(it) => it.module(db),
}
}
}
Expand Down
112 changes: 91 additions & 21 deletions src/tools/rust-analyzer/crates/hir-ty/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@ use std::env;
use std::sync::LazyLock;

use base_db::SourceDatabaseFileInputExt as _;
use either::Either;
use expect_test::Expect;
use hir_def::{
db::DefDatabase,
expr_store::{Body, BodySourceMap},
hir::{ExprId, Pat, PatId},
item_scope::ItemScope,
nameres::DefMap,
src::HasSource,
AssocItemId, DefWithBodyId, HasModule, LocalModuleId, Lookup, ModuleDefId, SyntheticSyntax,
src::{HasChildSource, HasSource},
AdtId, AssocItemId, DefWithBodyId, FieldId, HasModule, LocalModuleId, Lookup, ModuleDefId,
SyntheticSyntax,
};
use hir_expand::{db::ExpandDatabase, FileRange, InFile};
use itertools::Itertools;
use rustc_hash::FxHashMap;
use span::TextSize;
use stdx::format_to;
use syntax::{
ast::{self, AstNode, HasName},
Expand Down Expand Up @@ -132,14 +135,40 @@ fn check_impl(
None => continue,
};
let def_map = module.def_map(&db);
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
visit_module(&db, &def_map, module.local_id, &mut |it| match it {
ModuleDefId::FunctionId(it) => defs.push(it.into()),
ModuleDefId::EnumVariantId(it) => {
defs.push(it.into());
let variant_id = it.into();
let vd = db.variant_data(variant_id);
defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| {
if fd.has_default {
let field = FieldId { parent: variant_id, local_id, has_default: true };
Some(DefWithBodyId::FieldId(field))
} else {
None
}
}));
}
ModuleDefId::ConstId(it) => defs.push(it.into()),
ModuleDefId::StaticId(it) => defs.push(it.into()),
ModuleDefId::AdtId(it) => {
let variant_id = match it {
AdtId::StructId(it) => it.into(),
AdtId::UnionId(it) => it.into(),
AdtId::EnumId(_) => return,
};
let vd = db.variant_data(variant_id);
defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| {
if fd.has_default {
let field = FieldId { parent: variant_id, local_id, has_default: true };
Some(DefWithBodyId::FieldId(field))
} else {
None
}
}));
}
_ => {}
});
}
defs.sort_by_key(|def| match def {
Expand All @@ -160,12 +189,20 @@ fn check_impl(
loc.source(&db).value.syntax().text_range().start()
}
DefWithBodyId::InTypeConstId(it) => it.source(&db).syntax().text_range().start(),
DefWithBodyId::FieldId(_) => unreachable!(),
DefWithBodyId::FieldId(it) => {
let cs = it.parent.child_source(&db);
match cs.value.get(it.local_id) {
Some(Either::Left(it)) => it.syntax().text_range().start(),
Some(Either::Right(it)) => it.syntax().text_range().end(),
None => TextSize::new(u32::MAX),
}
}
});
let mut unexpected_type_mismatches = String::new();
for def in defs {
let (body, body_source_map) = db.body_with_source_map(def);
let inference_result = db.infer(def);
dbg!(&inference_result);

for (pat, mut ty) in inference_result.type_of_pat.iter() {
if let Pat::Bind { id, .. } = body.pats[pat] {
Expand Down Expand Up @@ -389,14 +426,40 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
let def_map = module.def_map(&db);

let mut defs: Vec<DefWithBodyId> = Vec::new();
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
visit_module(&db, &def_map, module.local_id, &mut |it| match it {
ModuleDefId::FunctionId(it) => defs.push(it.into()),
ModuleDefId::EnumVariantId(it) => {
defs.push(it.into());
let variant_id = it.into();
let vd = db.variant_data(variant_id);
defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| {
if fd.has_default {
let field = FieldId { parent: variant_id, local_id, has_default: true };
Some(DefWithBodyId::FieldId(field))
} else {
None
}
}));
}
ModuleDefId::ConstId(it) => defs.push(it.into()),
ModuleDefId::StaticId(it) => defs.push(it.into()),
ModuleDefId::AdtId(it) => {
let variant_id = match it {
AdtId::StructId(it) => it.into(),
AdtId::UnionId(it) => it.into(),
AdtId::EnumId(_) => return,
};
let vd = db.variant_data(variant_id);
defs.extend(vd.fields().iter().filter_map(|(local_id, fd)| {
if fd.has_default {
let field = FieldId { parent: variant_id, local_id, has_default: true };
Some(DefWithBodyId::FieldId(field))
} else {
None
}
}));
}
_ => {}
});
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
Expand All @@ -416,7 +479,14 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
loc.source(&db).value.syntax().text_range().start()
}
DefWithBodyId::InTypeConstId(it) => it.source(&db).syntax().text_range().start(),
DefWithBodyId::FieldId(_) => unreachable!(),
DefWithBodyId::FieldId(it) => {
let cs = it.parent.child_source(&db);
match cs.value.get(it.local_id) {
Some(Either::Left(it)) => it.syntax().text_range().start(),
Some(Either::Right(it)) => it.syntax().text_range().end(),
None => TextSize::new(u32::MAX),
}
}
});
for def in defs {
let (body, source_map) = db.body_with_source_map(def);
Expand Down Expand Up @@ -477,7 +547,7 @@ pub(crate) fn visit_module(
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
ModuleDefId::AdtId(AdtId::EnumId(it)) => {
db.enum_data(it).variants.iter().for_each(|&(it, _)| {
let body = db.body(it.into());
cb(it.into());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,53 @@ static ALIAS: i32 = {
217..218 '5': i32
205..211: expected impl Trait + ?Sized, got Struct
"#]],
)
);
}

#[test]
fn defining_type_alias_impl_trait_from_default_fields() {
check_no_mismatches(
r#"
trait Trait {}
struct Struct;
impl Trait for Struct {}
type AliasTy = impl Trait;
struct Foo {
foo: AliasTy = {
let x: AliasTy = Struct;
x
},
}
"#,
);

check_infer_with_mismatches(
r#"
trait Trait {}
struct Struct;
impl Trait for Struct {}
type AliasTy = impl Trait;
struct Foo {
foo: i32 = {
let x: AliasTy = Struct;
5
},
}
"#,
expect![[r#"
114..164 '{ ... }': i32
128..129 'x': impl Trait + ?Sized
141..147 'Struct': Struct
157..158 '5': i32
141..147: expected impl Trait + ?Sized, got Struct
"#]],
);
}

0 comments on commit ec89b7d

Please sign in to comment.