Skip to content

Commit d2de1b9

Browse files
committed
fix: improve WASM runtime memory management
- Add memory tracking for allocated parameters in guest function calls - Implement automatic cleanup of allocated memory after function calls - Component: Add post_return calls for proper WASM function cleanup - Fix memory leaks in parameter marshaling This ensures guest function parameters are properly freed after use. Signed-off-by: Ludvig Liljenberg <4257730+ludfjig@users.noreply.github.com>
1 parent ae9bacd commit d2de1b9

File tree

4 files changed

+95
-27
lines changed

4 files changed

+95
-27
lines changed

src/hyperlight_wasm_macro/src/wasmguest.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,17 @@ fn emit_wasm_function_call(
212212
let rwt = match result {
213213
None => {
214214
quote! {
215-
instance.get_typed_func::<(#(#pwts,)*), ()>(&mut *store, func_idx)?
216-
.call(&mut *store, (#(#pus,)*))?;
215+
let func = instance.get_typed_func::<(#(#pwts,)*), ()>(&mut *store, func_idx)?;
216+
func.call(&mut *store, (#(#pus,)*))?;
217+
func.post_return(&mut *store)?;
217218
}
218219
}
219220
_ => {
220221
let r = rtypes::emit_func_result(s, result);
221222
quote! {
222-
let #ret = instance.get_typed_func::<(#(#pwts,)*), ((#r,))>(&mut *store, func_idx)?
223-
.call(&mut *store, (#(#pus,)*))?.0;
223+
let func = instance.get_typed_func::<(#(#pwts,)*), ((#r,))>(&mut *store, func_idx)?;
224+
let #ret = func.call(&mut *store, (#(#pus,)*))?.0;
225+
func.post_return(&mut *store)?;
224226
}
225227
}
226228
};

src/wasm_runtime/src/marshal.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ fn malloc<C: AsContextMut>(
4646
Ok(addr)
4747
}
4848

49+
fn free<C: AsContextMut>(
50+
ctx: &mut C,
51+
get_export: &impl Fn(&mut C, &str) -> Option<Extern>,
52+
addr: i32,
53+
) -> Result<()> {
54+
let free = get_export(&mut *ctx, "free")
55+
.and_then(Extern::into_func)
56+
.ok_or(HyperlightGuestError::new(
57+
ErrorCode::GuestError,
58+
"free function not exported".to_string(),
59+
))?;
60+
free.typed::<i32, ()>(&mut *ctx)?.call(&mut *ctx, addr)?;
61+
Ok(())
62+
}
63+
4964
fn write<C: AsContextMut>(
5065
ctx: &mut C,
5166
get_export: &impl Fn(&mut C, &str) -> Option<Extern>,
@@ -126,10 +141,11 @@ fn read_cstr<C: AsContextMut>(
126141
})
127142
}
128143

129-
pub fn hl_param_to_val<C: AsContextMut>(
144+
pub fn hl_param_to_val_with_tracking<C: AsContextMut>(
130145
mut ctx: C,
131146
get_export: impl Fn(&mut C, &str) -> Option<Extern>,
132147
param: &ParameterValue,
148+
allocated_addrs: &mut Vec<i32>,
133149
) -> Result<Val> {
134150
match param {
135151
ParameterValue::Int(i) => Ok(Val::I32(*i)),
@@ -144,17 +160,31 @@ pub fn hl_param_to_val<C: AsContextMut>(
144160
let nbytes = s.count_bytes() + 1; // include the NUL terminator
145161
let addr = malloc(&mut ctx, &get_export, nbytes)?;
146162
write(&mut ctx, &get_export, addr, s.as_bytes_with_nul())?;
163+
allocated_addrs.push(addr); // Track for later cleanup
147164
Ok(Val::I32(addr))
148165
}
149166
ParameterValue::VecBytes(b) => {
150167
let addr = malloc(&mut ctx, &get_export, b.len())?;
151168
write(&mut ctx, &get_export, addr, b)?;
169+
allocated_addrs.push(addr); // Track for later cleanup
152170
Ok(Val::I32(addr))
153171
// TODO: check that the next parameter is the correct length
154172
}
155173
}
156174
}
157175

176+
/// Helper function to free all tracked allocated addresses
177+
pub fn free_allocated_addrs<C: AsContextMut>(
178+
mut ctx: C,
179+
get_export: impl Fn(&mut C, &str) -> Option<Extern>,
180+
allocated_addrs: &[i32],
181+
) -> Result<()> {
182+
for &addr in allocated_addrs {
183+
free(&mut ctx, &get_export, addr)?;
184+
}
185+
Ok(())
186+
}
187+
158188
pub fn val_to_hl_result<C: AsContextMut>(
159189
mut ctx: C,
160190
get_export: impl Fn(&mut C, &str) -> Option<Extern>,

src/wasm_runtime/src/module.rs

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717
use alloc::string::ToString;
1818
use alloc::vec::Vec;
1919
use alloc::{format, vec};
20-
use core::ops::Deref;
20+
use core::ops::{Deref, DerefMut};
2121

2222
use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall;
2323
use hyperlight_common::flatbuffer_wrappers::function_types::{
@@ -34,57 +34,71 @@ use wasmtime::{Config, Engine, Linker, Module, Store, Val};
3434

3535
use crate::{hostfuncs, marshal, platform, wasip1};
3636

37+
// Set by transition to WasmSandbox (by init_wasm_runtime)
3738
static CUR_ENGINE: Mutex<Option<Engine>> = Mutex::new(None);
3839
static CUR_LINKER: Mutex<Option<Linker<()>>> = Mutex::new(None);
40+
// Set by transition to LoadedWasmSandbox (by load_wasm_module/load_wasm_module_phys)
3941
static CUR_MODULE: Mutex<Option<Module>> = Mutex::new(None);
42+
static CUR_STORE: Mutex<Option<Store<()>>> = Mutex::new(None);
43+
static CUR_INSTANCE: Mutex<Option<wasmtime::Instance>> = Mutex::new(None);
4044

4145
#[no_mangle]
4246
pub fn guest_dispatch_function(function_call: &FunctionCall) -> Result<Vec<u8>> {
43-
let engine = CUR_ENGINE.lock();
44-
let engine = engine.deref().as_ref().ok_or(HyperlightGuestError::new(
47+
let mut store = CUR_STORE.lock();
48+
let store = store.deref_mut().as_mut().ok_or(HyperlightGuestError::new(
4549
ErrorCode::GuestError,
46-
"Wasm runtime is not initialized".to_string(),
50+
"No wasm store available".to_string(),
4751
))?;
48-
let linker = CUR_LINKER.lock();
49-
let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new(
52+
let instance = CUR_INSTANCE.lock();
53+
let instance = instance.deref().as_ref().ok_or(HyperlightGuestError::new(
5054
ErrorCode::GuestError,
51-
"impossible: wasm runtime has no valid linker".to_string(),
55+
"No wasm instance available".to_string(),
5256
))?;
53-
let module = CUR_MODULE.lock();
54-
let module = module.deref().as_ref().ok_or(HyperlightGuestError::new(
55-
ErrorCode::GuestError,
56-
"No wasm module loaded".to_string(),
57-
))?;
58-
let mut store = Store::new(engine, ());
59-
let instance = linker.instantiate(&mut store, module)?;
57+
6058
let func = instance
61-
.get_func(&mut store, &function_call.function_name)
59+
.get_func(&mut *store, &function_call.function_name)
6260
.ok_or(HyperlightGuestError::new(
6361
ErrorCode::GuestError,
6462
"Function not found".to_string(),
6563
))?;
64+
6665
let mut w_params = vec![];
66+
let mut allocated_addrs = vec![];
6767
for f_param in (function_call.parameters)
6868
.as_ref()
6969
.unwrap_or(&vec![])
7070
.iter()
7171
{
72-
w_params.push(marshal::hl_param_to_val(
73-
&mut store,
72+
w_params.push(marshal::hl_param_to_val_with_tracking(
73+
&mut *store,
7474
|ctx, name| instance.get_export(ctx, name),
7575
f_param,
76+
&mut allocated_addrs,
7677
)?);
7778
}
7879
let is_void = ReturnType::Void == function_call.expected_return_type;
7980
let n_results = if is_void { 0 } else { 1 };
8081
let mut results = vec![Val::I32(0); n_results];
81-
func.call(&mut store, &w_params, &mut results)?;
82-
marshal::val_to_hl_result(
83-
&mut store,
82+
func.call(&mut *store, &w_params, &mut results)?;
83+
let result = marshal::val_to_hl_result(
84+
&mut *store,
8485
|ctx, name| instance.get_export(ctx, name),
8586
function_call.expected_return_type,
8687
&results,
88+
);
89+
90+
marshal::free_allocated_addrs(
91+
&mut *store,
92+
|ctx, name| instance.get_export(ctx, name),
93+
&allocated_addrs,
8794
)
95+
.map_err(|e| {
96+
HyperlightGuestError::new(
97+
ErrorCode::GuestError,
98+
format!("Failed to free memory allocated for params: {:?}", e),
99+
)
100+
})?;
101+
result
88102
}
89103

90104
fn init_wasm_runtime() -> Result<Vec<u8>> {
@@ -124,8 +138,19 @@ fn load_wasm_module(function_call: &FunctionCall) -> Result<Vec<u8>> {
124138
&function_call.parameters.as_ref().unwrap()[1],
125139
&*CUR_ENGINE.lock(),
126140
) {
141+
let linker = CUR_LINKER.lock();
142+
let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new(
143+
ErrorCode::GuestError,
144+
"impossible: wasm runtime has no valid linker".to_string(),
145+
))?;
146+
127147
let module = unsafe { Module::deserialize(engine, wasm_bytes)? };
148+
let mut store = Store::new(engine, ());
149+
let instance = linker.instantiate(&mut store, &module)?;
150+
128151
*CUR_MODULE.lock() = Some(module);
152+
*CUR_STORE.lock() = Some(store);
153+
*CUR_INSTANCE.lock() = Some(instance);
129154
Ok(get_flatbuffer_result::<i32>(0))
130155
} else {
131156
Err(HyperlightGuestError::new(
@@ -141,8 +166,19 @@ fn load_wasm_module_phys(function_call: &FunctionCall) -> Result<Vec<u8>> {
141166
&function_call.parameters.as_ref().unwrap()[1],
142167
&*CUR_ENGINE.lock(),
143168
) {
169+
let linker = CUR_LINKER.lock();
170+
let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new(
171+
ErrorCode::GuestError,
172+
"impossible: wasm runtime has no valid linker".to_string(),
173+
))?;
174+
144175
let module = unsafe { Module::deserialize_raw(engine, platform::map_buffer(*phys, *len))? };
176+
let mut store = Store::new(engine, ());
177+
let instance = linker.instantiate(&mut store, &module)?;
178+
145179
*CUR_MODULE.lock() = Some(module);
180+
*CUR_STORE.lock() = Some(store);
181+
*CUR_INSTANCE.lock() = Some(instance);
146182
Ok(get_flatbuffer_result::<()>(()))
147183
} else {
148184
Err(HyperlightGuestError::new(

src/wasm_runtime/src/platform.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub(crate) fn register_page_fault_handler() {
5959
// See AMD64 Architecture Programmer's Manual, Volume 2
6060
// §8.2 Vectors, p. 245
6161
// Table 8-1: Interrupt Vector Source and Cause
62-
handler::handlers[14].store(page_fault_handler as usize as u64, Ordering::Release);
62+
handler::HANDLERS[14].store(page_fault_handler as usize as u64, Ordering::Release);
6363
}
6464

6565
// Wasmtime Embedding Interface
@@ -155,7 +155,7 @@ pub extern "C" fn wasmtime_init_traps(handler: wasmtime_trap_handler_t) -> i32 {
155155
// See AMD64 Architecture Programmer's Manual, Volume 2
156156
// §8.2 Vectors, p. 245
157157
// Table 8-1: Interrupt Vector Source and Cause
158-
handler::handlers[6].store(wasmtime_trap_handler as usize as u64, Ordering::Release);
158+
handler::HANDLERS[6].store(wasmtime_trap_handler as usize as u64, Ordering::Release);
159159
// TODO: Add handlers for any other traps that wasmtime needs,
160160
// probably including at least some floating-point
161161
// exceptions

0 commit comments

Comments
 (0)