4 changes: 4 additions & 0 deletions crates/cuda_builder/src/lib.rs
@@ -761,6 +761,10 @@ fn invoke_rustc(builder: &CudaBuilder) -> Result<PathBuf, CudaBuilderError> {
llvm_args.push("--override-libm".to_string());
}

if builder.use_constant_memory_space {
llvm_args.push("--use-constant-memory-space".to_string());
}

if let Some(path) = &builder.final_module_path {
llvm_args.push("--final-module-path".to_string());
llvm_args.push(path.to_str().unwrap().to_string());
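On the consumer side, this flag is driven from build.rs. Below is a minimal sketch of opting in or out, assuming the usual CudaBuilder::new(..).copy_to(..).build() flow from cuda_builder and the `.use_constant_memory_space(..)` setter that the diagnostics below refer to; the GPU crate path and PTX output path are illustrative.

// build.rs of the host crate (illustrative paths)
use cuda_builder::CudaBuilder;

fn main() {
    CudaBuilder::new("kernels") // path to the GPU crate
        // Forwarded to rustc_codegen_nvvm as `--use-constant-memory-space`;
        // pass `false` to keep automatic constant memory placement disabled.
        .use_constant_memory_space(true)
        .copy_to("resources/kernels.ptx") // where the built PTX is copied
        .build()
        .unwrap();
}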
105 changes: 96 additions & 9 deletions crates/rustc_codegen_nvvm/src/context.rs
@@ -44,6 +44,9 @@ use tracing::{debug, trace};
/// <https://docs.nvidia.com/cuda/archive/12.8.1/pdf/CUDA_C_Best_Practices_Guide.pdf>
const CONSTANT_MEMORY_SIZE_LIMIT_BYTES: u64 = 64 * 1024;

/// Threshold at which to warn that constant memory usage is approaching the limit (80% of the limit)
const CONSTANT_MEMORY_WARNING_THRESHOLD_BYTES: u64 = (CONSTANT_MEMORY_SIZE_LIMIT_BYTES * 80) / 100;

pub(crate) struct CodegenCx<'ll, 'tcx> {
pub tcx: TyCtxt<'tcx>,

@@ -104,6 +107,9 @@ pub(crate) struct CodegenCx<'ll, 'tcx> {
pub codegen_args: CodegenArgs,
// the value of the last call instruction. Needed for return type remapping.
pub last_call_llfn: Cell<Option<&'ll Value>>,

/// Tracks cumulative constant memory usage in bytes for compile-time diagnostics
constant_memory_usage: Cell<u64>,
}

impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
@@ -174,6 +180,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
dbg_cx,
codegen_args: CodegenArgs::from_session(tcx.sess()),
last_call_llfn: Cell::new(None),
constant_memory_usage: Cell::new(0),
};
cx.build_intrinsics_map();
cx
@@ -281,16 +288,96 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
// static and many small ones, you might want the small ones to all be
// in constant memory or just the big one depending on your workload.
let layout = self.layout_of(ty);
if layout.size.bytes() > CONSTANT_MEMORY_SIZE_LIMIT_BYTES {
self.tcx.sess.dcx().warn(format!(
"static `{instance}` exceeds the constant memory limit; placing in global memory (performance may be reduced)"
));
// Place instance in global memory if it is too big for constant memory.
AddressSpace(1)
} else {
// Place instance in constant memory if it fits.
AddressSpace(4)
let size_bytes = layout.size.bytes();
let current_usage = self.constant_memory_usage.get();
let new_usage = current_usage + size_bytes;

// Check if this single static is too large for constant memory
if size_bytes > CONSTANT_MEMORY_SIZE_LIMIT_BYTES {
let def_id = instance.def_id();
let span = self.tcx.def_span(def_id);
let mut diag = self.tcx.sess.dcx().struct_span_warn(
span,
format!(
"static `{instance}` is {size_bytes} bytes, exceeds the constant memory limit of {} bytes",
CONSTANT_MEMORY_SIZE_LIMIT_BYTES
),
);
diag.span_label(span, "static exceeds constant memory limit");
diag.note("placing in global memory (performance may be reduced)");
diag.help("use `#[cuda_std::address_space(global)]` to explicitly place this static in global memory");
diag.emit();
return AddressSpace(1);
}

// Check if adding this static would exceed the cumulative limit
if new_usage > CONSTANT_MEMORY_SIZE_LIMIT_BYTES {
let def_id = instance.def_id();
let span = self.tcx.def_span(def_id);
let mut diag = self.tcx.sess.dcx().struct_span_err(
span,
format!(
"cannot place static `{instance}` ({size_bytes} bytes) in constant memory: \
cumulative constant memory usage would be {new_usage} bytes, exceeding the {} byte limit",
CONSTANT_MEMORY_SIZE_LIMIT_BYTES
),
);
diag.span_label(
span,
format!(
"this static would cause total usage to exceed {} bytes",
CONSTANT_MEMORY_SIZE_LIMIT_BYTES
),
);
diag.note(format!(
"current constant memory usage: {current_usage} bytes"
));
diag.note(format!("static size: {size_bytes} bytes"));
diag.note(format!("would result in: {new_usage} bytes total"));

diag.help("move this or other statics to global memory using `#[cuda_std::address_space(global)]`");
diag.help("reduce the total size of static data");
diag.help("disable automatic constant memory placement by setting `.use_constant_memory_space(false)` on `CudaBuilder` in build.rs");

diag.emit();
self.tcx.sess.dcx().abort_if_errors();
unreachable!()
}

// Successfully placed in constant memory: update cumulative usage
self.constant_memory_usage.set(new_usage);

// Warn once when this placement crosses the 80% threshold
if new_usage > CONSTANT_MEMORY_WARNING_THRESHOLD_BYTES
&& current_usage <= CONSTANT_MEMORY_WARNING_THRESHOLD_BYTES
{
let def_id = instance.def_id();
let span = self.tcx.def_span(def_id);
let usage_percent =
(new_usage as f64 / CONSTANT_MEMORY_SIZE_LIMIT_BYTES as f64) * 100.0;
let mut diag = self.tcx.sess.dcx().struct_span_warn(
span,
format!(
"constant memory usage is approaching the limit: {new_usage} / {} bytes ({usage_percent:.1}% used)",
CONSTANT_MEMORY_SIZE_LIMIT_BYTES
),
);
diag.span_label(
span,
"this placement brought you over 80% of constant memory capacity",
);
diag.note(format!(
"only {} bytes of constant memory remain",
CONSTANT_MEMORY_SIZE_LIMIT_BYTES - new_usage
));
diag.help("to prevent constant memory overflow, consider moving some statics to global memory using `#[cuda_std::address_space(global)]`");
diag.emit();
}

trace!(
"Placing static `{instance}` ({size_bytes} bytes) in constant memory. Total usage: {new_usage} bytes"
);
AddressSpace(4)
}
} else {
AddressSpace::ZERO
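For scale, the hard cap above is 64 KiB (65,536 bytes), and the warning threshold computed from it by integer arithmetic is (65,536 * 80) / 100 = 52,428 bytes of cumulative usage. On the GPU side, the diagnostics point to `#[cuda_std::address_space(global)]` as the opt-out. Below is a minimal kernel-crate sketch of both placements, assuming the attribute is spelled exactly as in the help text; the kernel body and the usual no_std boilerplate of a GPU crate are illustrative.

// GPU-side crate compiled through cuda_builder (illustrative sketch)
use cuda_std::prelude::*;

// 1 KiB table: small enough to be placed automatically in constant memory
// (address space 4) while cumulative usage stays under the 64 KiB cap.
static SMALL_LUT: [f32; 256] = [1.0; 256];

// 128 KiB table: opting out keeps it in global memory (address space 1),
// so it neither trips the per-static size warning nor counts against the
// cumulative constant memory budget.
#[cuda_std::address_space(global)]
static BIG_LUT: [f32; 32 * 1024] = [0.5; 32 * 1024];

#[kernel]
pub unsafe fn scale(out: *mut f32, idx: usize) {
    *out = SMALL_LUT[idx % 256] * BIG_LUT[idx % (32 * 1024)];
}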