Skip to content

Commit 363ee2c

Browse files
committed
fixup! added typetree support for memcpy
1 parent 2c80fe5 commit 363ee2c

File tree

14 files changed

+115
-17
lines changed

14 files changed

+115
-17
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
244244
scratch_align,
245245
bx.const_usize(copy_bytes),
246246
MemFlags::empty(),
247+
None,
247248
);
248249
bx.lifetime_end(llscratch, scratch_size);
249250
}

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
22
use std::ops::Deref;
33
use std::{iter, ptr};
44

5+
use rustc_ast::expand::typetree::FncTree;
56
pub(crate) mod autodiff;
67
pub(crate) mod gpu_offload;
78

@@ -1118,11 +1119,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11181119
src_align: Align,
11191120
size: &'ll Value,
11201121
flags: MemFlags,
1122+
tt: Option<FncTree>,
11211123
) {
11221124
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
11231125
let size = self.intcast(size, self.type_isize(), false);
11241126
let is_volatile = flags.contains(MemFlags::VOLATILE);
1125-
unsafe {
1127+
let memcpy = unsafe {
11261128
llvm::LLVMRustBuildMemCpy(
11271129
self.llbuilder,
11281130
dst,
@@ -1131,7 +1133,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
11311133
src_align.bytes() as c_uint,
11321134
size,
11331135
is_volatile,
1134-
);
1136+
)
1137+
};
1138+
1139+
// TypeTree metadata for memcpy is especially important: when Enzyme encounters
1140+
// a memcpy during autodiff, it needs to know the structure of the data being
1141+
// copied to properly track derivatives. For example, copying an array of floats
1142+
// vs. copying a struct with mixed types requires different derivative handling.
1143+
// The TypeTree tells Enzyme exactly what memory layout to expect.
1144+
if let Some(tt) = tt {
1145+
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
11351146
}
11361147
}
11371148

compiler/rustc_codegen_llvm/src/va_arg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,7 @@ fn copy_to_temporary_if_more_aligned<'ll, 'tcx>(
735735
src_align,
736736
bx.const_u32(layout.layout.size().bytes() as u32),
737737
MemFlags::empty(),
738+
None,
738739
);
739740
tmp
740741
} else {

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
16231623
align,
16241624
bx.const_usize(copy_bytes),
16251625
MemFlags::empty(),
1626+
None,
16261627
);
16271628
// ...and then load it with the ABI type.
16281629
llval = load_cast(bx, cast, llscratch, scratch_align);

compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ fn copy_intrinsic<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
3030
if allow_overlap {
3131
bx.memmove(dst, align, src, align, size, flags);
3232
} else {
33-
bx.memcpy(dst, align, src, align, size, flags);
33+
bx.memcpy(dst, align, src, align, size, flags, None);
3434
}
3535
}
3636

compiler/rustc_codegen_ssa/src/mir/statement.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
9090
let align = pointee_layout.align;
9191
let dst = dst_val.immediate();
9292
let src = src_val.immediate();
93-
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty());
93+
bx.memcpy(dst, align, src, align, bytes, crate::MemFlags::empty(), None);
9494
}
9595
mir::StatementKind::FakeRead(..)
9696
| mir::StatementKind::Retag { .. }

compiler/rustc_codegen_ssa/src/traits/builder.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,7 @@ pub trait BuilderMethods<'a, 'tcx>:
424424
src_align: Align,
425425
size: Self::Value,
426426
flags: MemFlags,
427+
tt: Option<rustc_ast::expand::typetree::FncTree>,
427428
);
428429
fn memmove(
429430
&mut self,
@@ -480,7 +481,7 @@ pub trait BuilderMethods<'a, 'tcx>:
480481
temp.val.store_with_flags(self, dst.with_type(layout), flags);
481482
} else if !layout.is_zst() {
482483
let bytes = self.const_usize(layout.size.bytes());
483-
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags);
484+
self.memcpy(dst.llval, dst.align, src.llval, src.align, bytes, flags, None);
484485
}
485486
}
486487

tests/codegen-llvm/autodiff/typetree.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ fn main() {
3030
let output_ = d_simple(&x, &mut df_dx, 1.0);
3131
assert_eq!(output, output_);
3232
assert_eq!(2.0, df_dx);
33-
}
33+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
; Check that enzyme_type attributes are present in the LLVM IR function definition
2+
; This verifies our TypeTree system correctly attaches metadata for Enzyme
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"
5+
6+
; Check that the differentiated function also has proper enzyme_type attributes
7+
CHECK: @diffetest_memcpy({{.*}}"enzyme_type"="{[]:Pointer}"{{.*}}"enzyme_type"="{[]:Pointer}"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
CHECK: test_memcpy - {[-1]:Float@double} |{[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double, [-1,32]:Float@double, [-1,40]:Float@double, [-1,48]:Float@double, [-1,56]:Float@double}:{}
2+
3+
CHECK-DAG: ptr %{{[0-9]+}}: {[-1]:Pointer, [-1,0]:Float@double, [-1,8]:Float@double, [-1,16]:Float@double, [-1,24]:Float@double, [-1,32]:Float@double, [-1,40]:Float@double, [-1,48]:Float@double, [-1,56]:Float@double}
4+
5+
CHECK-DAG: load double{{.*}}: {[-1]:Float@double}
6+
7+
CHECK-DAG: fmul double{{.*}}: {[-1]:Float@double}
8+
9+
CHECK-DAG: fadd double{{.*}}: {[-1]:Float@double}

0 commit comments

Comments
 (0)