Skip to content

Commit 42e5fd3

Browse files
committed
added typetree support for memcpy
1 parent bd70be1 commit 42e5fd3

File tree

17 files changed

+125
-21
lines changed

17 files changed

+125
-21
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -244,6 +244,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
244244
scratch_align,
245245
bx.const_usize(copy_bytes),
246246
MemFlags::empty(),
247+
None,
247248
);
248249
bx.lifetime_end(llscratch, scratch_size);
249250
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1118,11 +1119,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11181119
src_align: Align,
11191120
size: &'ll Value,
11201121
flags: MemFlags,
1122+
tt: Option<FncTree>,
11211123
) {
11221124
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11231125
let size = self.intcast(size, self.type_isize(), false);
11241126
let is_volatile = flags.contains(MemFlags::VOLATILE);
1125-
unsafe {
1127+
let memcpy = unsafe {
11261128
llvm::LLVMRustBuildMemCpy(
11271129
self.llbuilder,
11281130
dst,
@@ -1131,7 +1133,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11311133
src_align.bytes() as c_uint,
11321134
size,
11331135
is_volatile,
1134-
);
1136+
)
1137+
};
1138+
1139+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1140+
// a memcpy during autodiff, it needs to know the structure of the data being
1141+
// copied to properly track derivatives. For example, copying an array of floats
1142+
// vs. copying a struct with mixed types requires different derivative handling.
1143+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1144+
if let Some(tt) = tt {
1145+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11351146
}
11361147
}
11371148

compiler/rustc_codegen_llvm/src/va_arg.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -735,6 +735,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
735735
src_align,
736736
bx.const_u32(layout.layout.size().bytes() as u32),
737737
MemFlags::empty(),
738+
None,
738739
);
739740
tmp
740741
} else {

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1623,6 +1623,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
16231623
align,
16241624
bx.const_usize(copy_bytes),
16251625
MemFlags::empty(),
1626+
None,
16261627
);
16271628
// ...and then load it with the ABI type.
16281629
llval = load_cast(bx, cast, llscratch, scratch_align);

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
3030
if allow_overlap {
3131
bx.memmove(dst, align, src, align, size, flags);
3232
} else {
33-
bx.memcpy(dst, align, src, align, size, flags);
33+
bx.memcpy(dst, align, src, align, size, flags, None);
3434
}
3535
}
3636

compiler/rustc_codegen_ssa/src/mir/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
9090
let align = pointee_layout.align;
9191
let dst = dst_val.immediate();
9292
let src = src_val.immediate();
93-
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
93+
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
9494
}
9595
mir::StatementKind::FakeRead(..)
9696
| mir::StatementKind::Retag { .. }

compiler/rustc_codegen_ssa/src/traits/builder.rs

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -424,6 +424,7 @@ pub trait BuilderMethods<'a, 'tcx>:
424424
src_align: Align,
425425
size: Self::Value,
426426
flags: MemFlags,
427+
tt: Option<rustc_ast::expand::typetree::FncTree>,
427428
);
428429
fn memmove(
429430
&mut self,
@@ -480,7 +481,7 @@ pub trait BuilderMethods<'a, 'tcx>:
480481
temp.val.store_with_flags(self, dst.with_type(layout), flags);
481482
} else if !layout.is_zst() {
482483
let bytes = self.const_usize(layout.size.bytes());
483-
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
484+
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
484485
}
485486
}
486487

compiler/rustc_interface/src/tests.rs

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -764,7 +764,6 @@ fn test_unstable_options_tracking_hash() {
764764
tracked!(allow_features, Some(vec![String::from("lang_items")]));
765765
tracked!(always_encode_mir, true);
766766
tracked!(assume_incomplete_release, true);
767-
tracked!(autodiff, vec![AutoDiff::Enable]);
768767
tracked!(autodiff, vec![AutoDiff::Enable, AutoDiff::NoTT]);
769768
tracked!(binary_dep_depinfo, true);
770769
tracked!(box_noalias, false);

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2286,12 +2286,12 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22862286
let child = typetree_from_ty(tcx, inner_ty);
22872287
return TypeTree(vec![Type {
22882288
offset: -1,
2289-
size: 8, // TODO(KMJ-007): Get actual pointer size from target
2289+
size: tcx.data_layout.pointer_size().bytes_usize(),
22902290
kind: Kind::Pointer,
22912291
child,
22922292
}]);
22932293
}
22942294

2295-
// TODO(KMJ-007): Handle arrays, slices, structs, and other complex types
2295+
// FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
22962296
TypeTree::new()
22972297
}

tests/codegen-llvm/autodiff/typetree.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,4 +30,4 @@ fn main() {
3030
let output_ = d_simple(&x, &mut df_dx, 1.0);
3131
assert_eq!(output, output_);
3232
assert_eq!(2.0, df_dx);
33-
}
33+
}

0 commit comments

Comments
 (0)