Skip to content

Commit fa04ec0

Browse files
committed
[guest] improve WASM runtime memory management
- Reuse wasmtime Store and Instance across guest function calls instead of creating new one per call. - Add memory tracking for allocated parameters in guest function calls - Add memory tracking for allocated parameters in host function calls - Free allocated parameters after each guest function call (implies that host function return values are returned by ref, so caller must copy values if they want to keep them around as they will get freed after returning). - Component: Add post_return calls for proper WASM function cleanup - Fix ABI mismatch in parameter of guest_dispatch_function Signed-off-by: Ludvig Liljenberg <4257730+ludfjig@users.noreply.github.com>
1 parent 66b47e7 commit fa04ec0

File tree

5 files changed

+151
-32
lines changed

5 files changed

+151
-32
lines changed

src/hyperlight_wasm_macro/src/wasmguest.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,15 +212,17 @@ fn emit_wasm_function_call(
212212
let rwt = match result {
213213
None => {
214214
quote! {
215-
instance.get_typed_func::<(#(#pwts,)*), ()>(&mut *store, func_idx)?
216-
.call(&mut *store, (#(#pus,)*))?;
215+
let func = instance.get_typed_func::<(#(#pwts,)*), ()>(&mut *store, func_idx)?;
216+
func.call(&mut *store, (#(#pus,)*))?;
217+
func.post_return(&mut *store)?;
217218
}
218219
}
219220
_ => {
220221
let r = rtypes::emit_func_result(s, result);
221222
quote! {
222-
let #ret = instance.get_typed_func::<(#(#pwts,)*), ((#r,))>(&mut *store, func_idx)?
223-
.call(&mut *store, (#(#pus,)*))?.0;
223+
let func = instance.get_typed_func::<(#(#pwts,)*), ((#r,))>(&mut *store, func_idx)?;
224+
let #ret = func.call(&mut *store, (#(#pus,)*))?.0;
225+
func.post_return(&mut *store)?;
224226
}
225227
}
226228
};

src/wasm_runtime/src/hostfuncs.rs

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515
*/
1616

1717
use alloc::string::ToString;
18+
use alloc::vec;
1819
use alloc::vec::Vec;
1920

2021
use hyperlight_common::flatbuffer_wrappers::function_types::{
@@ -23,10 +24,23 @@ use hyperlight_common::flatbuffer_wrappers::function_types::{
2324
use hyperlight_common::flatbuffer_wrappers::guest_error::ErrorCode;
2425
use hyperlight_guest::error::{HyperlightGuestError, Result};
2526
use hyperlight_guest_bin::host_comm::call_host_function;
27+
use spin::Mutex;
2628
use wasmtime::{Caller, Engine, FuncType, Val, ValType};
2729

2830
use crate::marshal;
2931

32+
// Track memory addresses allocated by host function return values
33+
static HOSTFUNC_ALLOCATED_ADDRS: Mutex<Vec<i32>> = Mutex::new(Vec::new());
34+
35+
/// Retrieve and clear all addresses allocated by host function return values.
36+
/// Caller is responsible for making sure the addresses are freed
37+
pub(crate) fn take_hostfunc_allocated_addrs() -> Vec<i32> {
38+
let mut addrs = HOSTFUNC_ALLOCATED_ADDRS.lock();
39+
let mut result = Vec::new();
40+
core::mem::swap(&mut *addrs, &mut result);
41+
result
42+
}
43+
3044
pub(crate) type HostFunctionDefinition =
3145
hyperlight_common::flatbuffer_wrappers::host_function_definition::HostFunctionDefinition;
3246
pub(crate) type HostFunctionDetails =
@@ -79,9 +93,9 @@ pub(crate) fn hostfunc_type(d: &HostFunctionDefinition, e: &Engine) -> Result<Fu
7993
Ok(FuncType::new(e, params, results))
8094
}
8195

82-
pub(crate) fn call(
96+
pub(crate) fn call<T>(
8397
d: &HostFunctionDefinition,
84-
mut c: Caller<'_, ()>,
98+
mut c: Caller<'_, T>,
8599
ps: &[Val],
86100
rs: &mut [Val],
87101
) -> Result<()> {
@@ -110,7 +124,19 @@ pub(crate) fn call(
110124
return Ok(());
111125
}
112126

113-
rs[0] = marshal::hl_return_to_val(&mut c, |c, n| c.get_export(n), rv)?;
127+
let mut allocated_addrs = vec![];
128+
rs[0] = marshal::hl_return_to_val_with_tracking(
129+
&mut c,
130+
|c, n| c.get_export(n),
131+
rv,
132+
&mut allocated_addrs,
133+
)?;
134+
135+
// Track any allocations for later cleanup
136+
if !allocated_addrs.is_empty() {
137+
let mut global_addrs = HOSTFUNC_ALLOCATED_ADDRS.lock();
138+
global_addrs.extend(allocated_addrs);
139+
}
114140

115141
Ok(())
116142
}

src/wasm_runtime/src/marshal.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,21 @@ fn malloc<C: AsContextMut>(
4646
Ok(addr)
4747
}
4848

49+
fn free<C: AsContextMut>(
50+
ctx: &mut C,
51+
get_export: &impl Fn(&mut C, &str) -> Option<Extern>,
52+
addr: i32,
53+
) -> Result<()> {
54+
let free = get_export(&mut *ctx, "free")
55+
.and_then(Extern::into_func)
56+
.ok_or(HyperlightGuestError::new(
57+
ErrorCode::GuestError,
58+
"free function not exported".to_string(),
59+
))?;
60+
free.typed::<i32, ()>(&mut *ctx)?.call(&mut *ctx, addr)?;
61+
Ok(())
62+
}
63+
4964
fn write<C: AsContextMut>(
5065
ctx: &mut C,
5166
get_export: &impl Fn(&mut C, &str) -> Option<Extern>,
@@ -126,10 +141,11 @@ fn read_cstr<C: AsContextMut>(
126141
})
127142
}
128143

129-
pub fn hl_param_to_val<C: AsContextMut>(
144+
pub fn hl_param_to_val_with_tracking<C: AsContextMut>(
130145
mut ctx: C,
131146
get_export: impl Fn(&mut C, &str) -> Option<Extern>,
132147
param: &ParameterValue,
148+
allocated_addrs: &mut Vec<i32>,
133149
) -> Result<Val> {
134150
match param {
135151
ParameterValue::Int(i) => Ok(Val::I32(*i)),
@@ -144,17 +160,31 @@ pub fn hl_param_to_val<C: AsContextMut>(
144160
let nbytes = s.count_bytes() + 1; // include the NUL terminator
145161
let addr = malloc(&mut ctx, &get_export, nbytes)?;
146162
write(&mut ctx, &get_export, addr, s.as_bytes_with_nul())?;
163+
allocated_addrs.push(addr);
147164
Ok(Val::I32(addr))
148165
}
149166
ParameterValue::VecBytes(b) => {
150167
let addr = malloc(&mut ctx, &get_export, b.len())?;
151168
write(&mut ctx, &get_export, addr, b)?;
169+
allocated_addrs.push(addr);
152170
Ok(Val::I32(addr))
153171
// TODO: check that the next parameter is the correct length
154172
}
155173
}
156174
}
157175

176+
/// Helper function to free all tracked allocated addresses
177+
pub fn free_allocated_addrs<C: AsContextMut>(
178+
mut ctx: C,
179+
get_export: impl Fn(&mut C, &str) -> Option<Extern>,
180+
allocated_addrs: &[i32],
181+
) -> Result<()> {
182+
for &addr in allocated_addrs {
183+
free(&mut ctx, &get_export, addr)?;
184+
}
185+
Ok(())
186+
}
187+
158188
pub fn val_to_hl_result<C: AsContextMut>(
159189
mut ctx: C,
160190
get_export: impl Fn(&mut C, &str) -> Option<Extern>,
@@ -248,10 +278,11 @@ pub fn val_to_hl_param<'a, C: AsContextMut>(
248278
}
249279
}
250280

251-
pub fn hl_return_to_val<C: AsContextMut>(
281+
pub fn hl_return_to_val_with_tracking<C: AsContextMut>(
252282
ctx: &mut C,
253283
get_export: impl Fn(&mut C, &str) -> Option<Extern>,
254284
rv: ReturnValue,
285+
allocated_addrs: &mut Vec<i32>,
255286
) -> Result<Val> {
256287
match rv {
257288
ReturnValue::Int(i) => Ok(Val::I32(i)),
@@ -266,11 +297,13 @@ pub fn hl_return_to_val<C: AsContextMut>(
266297
let nbytes = s.count_bytes() + 1; // include the NUL terminator
267298
let addr = malloc(ctx, &get_export, nbytes)?;
268299
write(ctx, &get_export, addr, s.as_bytes_with_nul())?;
300+
allocated_addrs.push(addr);
269301
Ok(Val::I32(addr))
270302
}
271303
ReturnValue::VecBytes(b) => {
272304
let addr = malloc(ctx, &get_export, b.len())?;
273305
write(ctx, &get_export, addr, b.as_ref())?;
306+
allocated_addrs.push(addr);
274307
Ok(Val::I32(addr))
275308
// TODO: check that the next parameter is the correct length
276309
}

src/wasm_runtime/src/module.rs

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License.
1717
use alloc::string::ToString;
1818
use alloc::vec::Vec;
1919
use alloc::{format, vec};
20-
use core::ops::Deref;
20+
use core::ops::{Deref, DerefMut};
2121

2222
use hyperlight_common::flatbuffer_wrappers::function_call::FunctionCall;
2323
use hyperlight_common::flatbuffer_wrappers::function_types::{
@@ -32,59 +32,95 @@ use hyperlight_guest_bin::host_comm::print_output_with_host_print;
3232
use spin::Mutex;
3333
use wasmtime::{Config, Engine, Linker, Module, Store, Val};
3434

35+
use crate::hostfuncs::take_hostfunc_allocated_addrs;
3536
use crate::{hostfuncs, marshal, platform, wasip1};
3637

38+
// Set by transition to WasmSandbox (by init_wasm_runtime)
3739
static CUR_ENGINE: Mutex<Option<Engine>> = Mutex::new(None);
3840
static CUR_LINKER: Mutex<Option<Linker<()>>> = Mutex::new(None);
41+
// Set by transition to LoadedWasmSandbox (by load_wasm_module/load_wasm_module_phys)
3942
static CUR_MODULE: Mutex<Option<Module>> = Mutex::new(None);
43+
static CUR_STORE: Mutex<Option<Store<()>>> = Mutex::new(None);
44+
static CUR_INSTANCE: Mutex<Option<wasmtime::Instance>> = Mutex::new(None);
4045

4146
#[no_mangle]
42-
pub fn guest_dispatch_function(function_call: &FunctionCall) -> Result<Vec<u8>> {
43-
let engine = CUR_ENGINE.lock();
44-
let engine = engine.deref().as_ref().ok_or(HyperlightGuestError::new(
47+
pub fn guest_dispatch_function(function_call: FunctionCall) -> Result<Vec<u8>> {
48+
let mut store = CUR_STORE.lock();
49+
let store = store.deref_mut().as_mut().ok_or(HyperlightGuestError::new(
4550
ErrorCode::GuestError,
46-
"Wasm runtime is not initialized".to_string(),
51+
"No wasm store available".to_string(),
4752
))?;
48-
let linker = CUR_LINKER.lock();
49-
let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new(
53+
let instance = CUR_INSTANCE.lock();
54+
let instance = instance.deref().as_ref().ok_or(HyperlightGuestError::new(
5055
ErrorCode::GuestError,
51-
"impossible: wasm runtime has no valid linker".to_string(),
56+
"No wasm instance available".to_string(),
5257
))?;
53-
let module = CUR_MODULE.lock();
54-
let module = module.deref().as_ref().ok_or(HyperlightGuestError::new(
55-
ErrorCode::GuestError,
56-
"No wasm module loaded".to_string(),
57-
))?;
58-
let mut store = Store::new(engine, ());
59-
let instance = linker.instantiate(&mut store, module)?;
58+
6059
let func = instance
61-
.get_func(&mut store, &function_call.function_name)
60+
.get_func(&mut *store, &function_call.function_name)
6261
.ok_or(HyperlightGuestError::new(
6362
ErrorCode::GuestError,
6463
"Function not found".to_string(),
6564
))?;
65+
6666
let mut w_params = vec![];
67+
let mut allocated_addrs = vec![];
6768
for f_param in (function_call.parameters)
6869
.as_ref()
6970
.unwrap_or(&vec![])
7071
.iter()
7172
{
72-
w_params.push(marshal::hl_param_to_val(
73-
&mut store,
73+
w_params.push(marshal::hl_param_to_val_with_tracking(
74+
&mut *store,
7475
|ctx, name| instance.get_export(ctx, name),
7576
f_param,
77+
&mut allocated_addrs,
7678
)?);
7779
}
7880
let is_void = ReturnType::Void == function_call.expected_return_type;
7981
let n_results = if is_void { 0 } else { 1 };
8082
let mut results = vec![Val::I32(0); n_results];
81-
func.call(&mut store, &w_params, &mut results)?;
82-
marshal::val_to_hl_result(
83-
&mut store,
83+
func.call(&mut *store, &w_params, &mut results)?;
84+
let result = marshal::val_to_hl_result(
85+
&mut *store,
8486
|ctx, name| instance.get_export(ctx, name),
8587
function_call.expected_return_type,
8688
&results,
89+
);
90+
91+
// Free memory allocated during marshalling of hyperlight parameters into wasm parameters
92+
marshal::free_allocated_addrs(
93+
&mut *store,
94+
|ctx, name| instance.get_export(ctx, name),
95+
&allocated_addrs,
8796
)
97+
.map_err(|e| {
98+
HyperlightGuestError::new(
99+
ErrorCode::GuestError,
100+
format!("Failed to free memory allocated for params: {:?}", e),
101+
)
102+
})?;
103+
104+
// Free memory allocated by marshalling host function return values into wasm values
105+
let hostfunc_addrs = take_hostfunc_allocated_addrs();
106+
if !hostfunc_addrs.is_empty() {
107+
marshal::free_allocated_addrs(
108+
&mut *store,
109+
|ctx, name| instance.get_export(ctx, name),
110+
&hostfunc_addrs,
111+
)
112+
.map_err(|e| {
113+
HyperlightGuestError::new(
114+
ErrorCode::GuestError,
115+
format!(
116+
"Failed to free memory allocated for host function returns: {:?}",
117+
e
118+
),
119+
)
120+
})?;
121+
}
122+
123+
result
88124
}
89125

90126
fn init_wasm_runtime() -> Result<Vec<u8>> {
@@ -124,8 +160,19 @@ fn load_wasm_module(function_call: &FunctionCall) -> Result<Vec<u8>> {
124160
&function_call.parameters.as_ref().unwrap()[1],
125161
&*CUR_ENGINE.lock(),
126162
) {
163+
let linker = CUR_LINKER.lock();
164+
let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new(
165+
ErrorCode::GuestError,
166+
"impossible: wasm runtime has no valid linker".to_string(),
167+
))?;
168+
127169
let module = unsafe { Module::deserialize(engine, wasm_bytes)? };
170+
let mut store = Store::new(engine, ());
171+
let instance = linker.instantiate(&mut store, &module)?;
172+
128173
*CUR_MODULE.lock() = Some(module);
174+
*CUR_STORE.lock() = Some(store);
175+
*CUR_INSTANCE.lock() = Some(instance);
129176
Ok(get_flatbuffer_result::<i32>(0))
130177
} else {
131178
Err(HyperlightGuestError::new(
@@ -141,8 +188,19 @@ fn load_wasm_module_phys(function_call: &FunctionCall) -> Result<Vec<u8>> {
141188
&function_call.parameters.as_ref().unwrap()[1],
142189
&*CUR_ENGINE.lock(),
143190
) {
191+
let linker = CUR_LINKER.lock();
192+
let linker = linker.deref().as_ref().ok_or(HyperlightGuestError::new(
193+
ErrorCode::GuestError,
194+
"impossible: wasm runtime has no valid linker".to_string(),
195+
))?;
196+
144197
let module = unsafe { Module::deserialize_raw(engine, platform::map_buffer(*phys, *len))? };
198+
let mut store = Store::new(engine, ());
199+
let instance = linker.instantiate(&mut store, &module)?;
200+
145201
*CUR_MODULE.lock() = Some(module);
202+
*CUR_STORE.lock() = Some(store);
203+
*CUR_INSTANCE.lock() = Some(instance);
146204
Ok(get_flatbuffer_result::<()>(()))
147205
} else {
148206
Err(HyperlightGuestError::new(

src/wasm_runtime/src/platform.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ pub(crate) fn register_page_fault_handler() {
5959
// See AMD64 Architecture Programmer's Manual, Volume 2
6060
// §8.2 Vectors, p. 245
6161
// Table 8-1: Interrupt Vector Source and Cause
62-
handler::handlers[14].store(page_fault_handler as usize as u64, Ordering::Release);
62+
handler::HANDLERS[14].store(page_fault_handler as usize as u64, Ordering::Release);
6363
}
6464

6565
// Wasmtime Embedding Interface
@@ -155,7 +155,7 @@ pub extern "C" fn wasmtime_init_traps(handler: wasmtime_trap_handler_t) -> i32 {
155155
// See AMD64 Architecture Programmer's Manual, Volume 2
156156
// §8.2 Vectors, p. 245
157157
// Table 8-1: Interrupt Vector Source and Cause
158-
handler::handlers[6].store(wasmtime_trap_handler as usize as u64, Ordering::Release);
158+
handler::HANDLERS[6].store(wasmtime_trap_handler as usize as u64, Ordering::Release);
159159
// TODO: Add handlers for any other traps that wasmtime needs,
160160
// probably including at least some floating-point
161161
// exceptions

0 commit comments

Comments
 (0)