From ed36445878345d7c08bd8b59232f07cb7e1f88ec Mon Sep 17 00:00:00 2001 From: coord_e Date: Mon, 24 Nov 2025 13:03:30 +0900 Subject: [PATCH] Implement #[thrust::extern_spec_fn] --- src/analyze/annot.rs | 4 ++++ src/analyze/crate_.rs | 12 +++++++++- src/analyze/local_def.rs | 52 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/src/analyze/annot.rs b/src/analyze/annot.rs index 91dd209..2dbb9ea 100644 --- a/src/analyze/annot.rs +++ b/src/analyze/annot.rs @@ -33,6 +33,10 @@ pub fn callable_path() -> [Symbol; 2] { [Symbol::intern("thrust"), Symbol::intern("callable")] } +pub fn extern_spec_fn_path() -> [Symbol; 2] { + [Symbol::intern("thrust"), Symbol::intern("extern_spec_fn")] +} + /// A [`annot::Resolver`] implementation for resolving function parameters. /// /// The parameter names and their sorts needs to be configured via diff --git a/src/analyze/crate_.rs b/src/analyze/crate_.rs index d6e343d..b1cd0af 100644 --- a/src/analyze/crate_.rs +++ b/src/analyze/crate_.rs @@ -46,6 +46,11 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.trusted.insert(local_def_id.to_def_id()); } + if analyzer.is_annotated_as_extern_spec_fn() { + assert!(analyzer.is_fully_annotated()); + self.trusted.insert(local_def_id.to_def_id()); + } + let sig = self .tcx .fn_sig(local_def_id) @@ -56,7 +61,12 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { self.ctx.register_deferred_def(local_def_id.to_def_id()); } else { let expected = analyzer.expected_ty(); - self.ctx.register_def(local_def_id.to_def_id(), expected); + let target_def_id = if analyzer.is_annotated_as_extern_spec_fn() { + analyzer.extern_spec_fn_target_def_id() + } else { + local_def_id.to_def_id() + }; + self.ctx.register_def(target_def_id, expected); } } diff --git a/src/analyze/local_def.rs b/src/analyze/local_def.rs index 50a4397..28dc38b 100644 --- a/src/analyze/local_def.rs +++ b/src/analyze/local_def.rs @@ -6,7 +6,7 @@ use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; use rustc_middle::mir::{self, BasicBlock, Body, Local}; use rustc_middle::ty::{self as mir_ty, TyCtxt, TypeAndMut}; -use rustc_span::def_id::LocalDefId; +use rustc_span::def_id::{DefId, LocalDefId}; use rustc_span::symbol::Ident; use crate::analyze; @@ -126,6 +126,16 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { .is_some() } + pub fn is_annotated_as_extern_spec_fn(&self) -> bool { + self.tcx + .get_attrs_by_path( + self.local_def_id.to_def_id(), + &analyze::annot::extern_spec_fn_path(), + ) + .next() + .is_some() + } + // TODO: unify this logic with extraction functions above pub fn is_fully_annotated(&self) -> bool { let has_require = self @@ -240,6 +250,46 @@ impl<'tcx, 'ctx> Analyzer<'tcx, 'ctx> { rty::RefinedType::unrefined(builder.build().into()) } + pub fn extern_spec_fn_target_def_id(&self) -> DefId { + struct ExtractDefId<'tcx> { + tcx: TyCtxt<'tcx>, + outer_def_id: LocalDefId, + inner_def_id: Option, + } + + impl<'tcx> rustc_hir::intravisit::Visitor<'tcx> for ExtractDefId<'tcx> { + type NestedFilter = rustc_middle::hir::nested_filter::OnlyBodies; + + fn nested_visit_map(&mut self) -> Self::Map { + self.tcx.hir() + } + + fn visit_qpath( + &mut self, + qpath: &rustc_hir::QPath<'tcx>, + hir_id: rustc_hir::HirId, + _span: rustc_span::Span, + ) { + let typeck_result = self.tcx.typeck(self.outer_def_id); + if let rustc_hir::def::Res::Def(_, def_id) = typeck_result.qpath_res(qpath, hir_id) + { + self.inner_def_id = Some(def_id); + } + } + } + + use rustc_hir::intravisit::Visitor as _; + let mut visitor = ExtractDefId { + tcx: self.tcx, + outer_def_id: self.local_def_id, + inner_def_id: None, + }; + if let rustc_hir::Node::Item(item) = self.tcx.hir_node_by_def_id(self.local_def_id) { + visitor.visit_item(item); + } + visitor.inner_def_id.expect("invalid extern_spec_fn") + } + fn is_mut_param(&self, param_idx: rty::FunctionParamIdx) -> bool { let param_local = analyze::local_of_function_param(param_idx); self.body.local_decls[param_local].mutability.is_mut()