From 5377a55ac93abbdbe433c39653262881126e5e17 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 15:43:23 +0100
Subject: [PATCH 01/14] feat(virtio-mem): add static allocation of hotpluggable memory

Allocate the memory that will be used for hotplugging. Initially, this
memory will be registered with KVM, but that will change later when we
add dynamic slot support.

Signed-off-by: Riccardo Mancini
---
 src/vmm/src/builder.rs       | 15 ++++++++++++++-
 src/vmm/src/resources.rs     | 12 ++++++++++++
 src/vmm/src/vstate/memory.rs | 10 ++++++++++
 src/vmm/src/vstate/vm.rs     | 13 +++++++++++++
 4 files changed, 49 insertions(+), 1 deletion(-)

diff --git a/src/vmm/src/builder.rs b/src/vmm/src/builder.rs
index cb901a78c63..5556c62e44f 100644
--- a/src/vmm/src/builder.rs
+++ b/src/vmm/src/builder.rs
@@ -32,7 +32,7 @@ use crate::device_manager::{
 use crate::devices::acpi::vmgenid::VmGenIdError;
 use crate::devices::virtio::balloon::Balloon;
 use crate::devices::virtio::block::device::Block;
-use crate::devices::virtio::mem::VirtioMem;
+use crate::devices::virtio::mem::{VIRTIO_MEM_GUEST_ADDRESS, VirtioMem};
 use crate::devices::virtio::net::Net;
 use crate::devices::virtio::rng::Entropy;
 use crate::devices::virtio::vsock::{Vsock, VsockUnixBackend};
@@ -44,6 +44,7 @@ use crate::persist::{MicrovmState, MicrovmStateError};
 use crate::resources::VmResources;
 use crate::seccomp::BpfThreadMap;
 use crate::snapshot::Persist;
+use crate::utils::mib_to_bytes;
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::MachineConfigError;
 use crate::vmm_config::memory_hotplug::MemoryHotplugConfig;
@@ -172,6 +173,18 @@ pub fn build_microvm_for_boot(
     let (mut vcpus, vcpus_exit_evt) = vm.create_vcpus(vm_resources.machine_config.vcpu_count)?;

     vm.register_dram_memory_regions(guest_memory)?;

+    // Allocate memory as soon as possible to make hotpluggable memory available to all consumers,
+    // before they clone the GuestMemoryMmap object
+    if let Some(memory_hotplug) = &vm_resources.memory_hotplug {
+        let hotplug_memory_region = vm_resources
+            .allocate_memory_region(
+                VIRTIO_MEM_GUEST_ADDRESS,
+                mib_to_bytes(memory_hotplug.total_size_mib),
+            )
+            .map_err(StartMicrovmError::GuestMemory)?;
+        vm.register_hotpluggable_memory_region(hotplug_memory_region)?;
+    }
+
     let mut device_manager = DeviceManager::new(
         event_manager,
         &vcpus_exit_evt,
diff --git a/src/vmm/src/resources.rs b/src/vmm/src/resources.rs
index 03d9f9b0c77..25fbc14e5db 100644
--- a/src/vmm/src/resources.rs
+++ b/src/vmm/src/resources.rs
@@ -536,6 +536,18 @@ impl VmResources {
             crate::arch::arch_memory_regions(mib_to_bytes(self.machine_config.mem_size_mib));
         self.allocate_memory_regions(&regions)
     }
+
+    /// Allocates a single guest memory region.
+    pub fn allocate_memory_region(
+        &self,
+        start: GuestAddress,
+        size: usize,
+    ) -> Result<GuestRegionMmap, MemoryError> {
+        Ok(self
+            .allocate_memory_regions(&[(start, size)])?
+ .pop() + .unwrap()) + } } impl From<&VmResources> for VmmConfig { diff --git a/src/vmm/src/vstate/memory.rs b/src/vmm/src/vstate/memory.rs index 1bf7cda6342..7fa349f6f4a 100644 --- a/src/vmm/src/vstate/memory.rs +++ b/src/vmm/src/vstate/memory.rs @@ -57,6 +57,8 @@ pub enum MemoryError { pub enum GuestRegionType { /// Guest DRAM Dram, + /// Hotpluggable memory + Hotpluggable, } /// An extension to GuestMemoryRegion that stores the type of region, and the KVM slot @@ -80,6 +82,14 @@ impl GuestRegionMmapExt { } } + pub(crate) fn hotpluggable_from_mmap_region(region: GuestRegionMmap, slot: u32) -> Self { + GuestRegionMmapExt { + inner: region, + region_type: GuestRegionType::Hotpluggable, + slot, + } + } + pub(crate) fn from_state( region: GuestRegionMmap, state: &GuestMemoryRegionState, diff --git a/src/vmm/src/vstate/vm.rs b/src/vmm/src/vstate/vm.rs index cc6afb722a2..67771473355 100644 --- a/src/vmm/src/vstate/vm.rs +++ b/src/vmm/src/vstate/vm.rs @@ -222,6 +222,19 @@ impl Vm { Ok(()) } + /// Register a new hotpluggable region to this [`Vm`]. + pub fn register_hotpluggable_memory_region( + &mut self, + region: GuestRegionMmap, + ) -> Result<(), VmError> { + let arcd_region = Arc::new(GuestRegionMmapExt::hotpluggable_from_mmap_region( + region, + self.allocate_slot_ids(1)?, + )); + + self._register_memory_region(arcd_region) + } + /// Register a list of new memory regions to this [`Vm`]. /// /// Note: regions and state.regions need to be in the same order. From 973793214d2a7112b1fb184242b4b80408be4994 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Wed, 24 Sep 2025 15:51:01 +0100 Subject: [PATCH 02/14] feat(virtio-mem): wire PATCH support Wire up PATCH requests with the virtio-mem device. All the validation is performed in the device, but the actual operation is not yet implemented. 
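As a rough illustration (not part of the patch itself), the new request
is expected to flow from the HTTP body to the VMM action roughly like
this; the types are the ones introduced by this change, while the helper
function is hypothetical:

    use vmm::rpc_interface::VmmAction;
    use vmm::vmm_config::memory_hotplug::MemoryHotplugSizeUpdate;

    fn parse_body(body: &[u8]) -> Result<VmmAction, serde_json::Error> {
        // A PATCH /hotplug/memory body such as {"requested_size_mib": 2048}
        // deserializes into the update struct (deny_unknown_fields rejects
        // any extra keys)...
        let update: MemoryHotplugSizeUpdate = serde_json::from_slice(body)?;
        // ...and is wrapped into the synchronous action that the runtime
        // controller dispatches to the device.
        Ok(VmmAction::UpdateMemoryHotplugSize(update))
    }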
Signed-off-by: Riccardo Mancini
---
 .../src/api_server/parsed_request.rs          |  5 +-
 .../src/api_server/request/hotplug/memory.rs  | 37 ++++++++++++++-
 src/vmm/src/devices/virtio/mem/device.rs      | 46 +++++++++++++++++++
 src/vmm/src/lib.rs                            |  9 ++++
 src/vmm/src/logger/metrics.rs                 |  6 +++
 src/vmm/src/rpc_interface.rs                  | 23 +++++++++-
 src/vmm/src/vmm_config/memory_hotplug.rs      |  8 ++++
 7 files changed, 130 insertions(+), 4 deletions(-)

diff --git a/src/firecracker/src/api_server/parsed_request.rs b/src/firecracker/src/api_server/parsed_request.rs
index 287742ede41..3d21695ce3e 100644
--- a/src/firecracker/src/api_server/parsed_request.rs
+++ b/src/firecracker/src/api_server/parsed_request.rs
@@ -28,7 +28,7 @@ use super::request::snapshot::{parse_patch_vm_state, parse_put_snapshot};
 use super::request::version::parse_get_version;
 use super::request::vsock::parse_put_vsock;
 use crate::api_server::request::hotplug::memory::{
-    parse_get_memory_hotplug, parse_put_memory_hotplug,
+    parse_get_memory_hotplug, parse_patch_memory_hotplug, parse_put_memory_hotplug,
 };
 use crate::api_server::request::serial::parse_put_serial;

@@ -119,6 +119,9 @@ impl TryFrom<&Request> for ParsedRequest {
                 parse_patch_net(body, path_tokens.next())
             }
             (Method::Patch, "vm", Some(body)) => parse_patch_vm_state(body),
+            (Method::Patch, "hotplug", Some(body)) if path_tokens.next() == Some("memory") => {
+                parse_patch_memory_hotplug(body)
+            }
             (Method::Patch, _, None) => method_to_error(Method::Patch),
             (method, unknown_uri, _) => Err(RequestError::InvalidPathMethod(
                 unknown_uri.to_string(),
diff --git a/src/firecracker/src/api_server/request/hotplug/memory.rs b/src/firecracker/src/api_server/request/hotplug/memory.rs
index 4bdeec73a6d..5ec514ca964 100644
--- a/src/firecracker/src/api_server/request/hotplug/memory.rs
+++ b/src/firecracker/src/api_server/request/hotplug/memory.rs
@@ -4,7 +4,7 @@
 use micro_http::Body;
 use vmm::logger::{IncMetric, METRICS};
 use vmm::rpc_interface::VmmAction;
-use vmm::vmm_config::memory_hotplug::MemoryHotplugConfig;
+use vmm::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugSizeUpdate};

 use crate::api_server::parsed_request::{ParsedRequest, RequestError};

@@ -23,11 +23,23 @@ pub(crate) fn parse_get_memory_hotplug() -> Result<ParsedRequest, RequestError>
     Ok(ParsedRequest::new_sync(VmmAction::GetMemoryHotplugStatus))
 }

+pub(crate) fn parse_patch_memory_hotplug(body: &Body) -> Result<ParsedRequest, RequestError> {
+    METRICS.patch_api_requests.hotplug_memory_count.inc();
+    let config =
+        serde_json::from_slice::<MemoryHotplugSizeUpdate>(body.raw()).inspect_err(|_| {
+            METRICS.patch_api_requests.hotplug_memory_fails.inc();
+        })?;
+    Ok(ParsedRequest::new_sync(VmmAction::UpdateMemoryHotplugSize(
+        config,
+    )))
+}
+
 #[cfg(test)]
 mod tests {
     use vmm::devices::virtio::mem::{
         VIRTIO_MEM_DEFAULT_BLOCK_SIZE_MIB, VIRTIO_MEM_DEFAULT_SLOT_SIZE_MIB,
     };
+    use vmm::vmm_config::memory_hotplug::MemoryHotplugSizeUpdate;

     use super::*;
     use crate::api_server::parsed_request::tests::vmm_action_from_request;
@@ -80,4 +92,27 @@ mod tests {
             VmmAction::GetMemoryHotplugStatus
         );
     }
+
+    #[test]
+    fn test_parse_patch_memory_hotplug_request() {
+        parse_patch_memory_hotplug(&Body::new("invalid_payload")).unwrap_err();
+
+        // PATCH with invalid fields.
+        let body = r#"{
+            "requested_size_mib": "bar"
+        }"#;
+        parse_patch_memory_hotplug(&Body::new(body)).unwrap_err();
+
+        // PATCH with valid input fields.
+        let body = r#"{
+            "requested_size_mib": 2048
+        }"#;
+        let expected_config = MemoryHotplugSizeUpdate {
+            requested_size_mib: 2048,
+        };
+        assert_eq!(
+            vmm_action_from_request(parse_patch_memory_hotplug(&Body::new(body)).unwrap()),
+            VmmAction::UpdateMemoryHotplugSize(expected_config)
+        );
+    }
 }
diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
index c8bcb6cbf53..b0855820cf0 100644
--- a/src/vmm/src/devices/virtio/mem/device.rs
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -43,6 +43,10 @@ pub enum VirtioMemError {
     EventFd(#[from] io::Error),
     /// Received error while sending an interrupt: {0}
     InterruptError(#[from] InterruptError),
+    /// Size {0} is invalid: it must be a multiple of the block size and no larger than the total size
+    InvalidSize(u64),
+    /// Device is not active
+    DeviceNotActive,
 }

 #[derive(Debug)]
@@ -200,6 +204,48 @@ impl VirtioMem {
     pub(crate) fn activate_event(&self) -> &EventFd {
         &self.activate_event
     }
+
+    /// Updates the requested size of the virtio-mem device.
+    pub fn update_requested_size(
+        &mut self,
+        requested_size_mib: usize,
+    ) -> Result<(), VirtioMemError> {
+        let requested_size = usize_to_u64(mib_to_bytes(requested_size_mib));
+        if !self.is_activated() {
+            return Err(VirtioMemError::DeviceNotActive);
+        }
+
+        if requested_size % self.config.block_size != 0 {
+            return Err(VirtioMemError::InvalidSize(requested_size));
+        }
+        if requested_size > self.config.region_size {
+            return Err(VirtioMemError::InvalidSize(requested_size));
+        }
+
+        // Increase the usable_region_size if it's not enough for the guest to plug new
+        // memory blocks.
+        // The device cannot decrease the usable_region_size unless the guest requests
+        // to reset it with an UNPLUG_ALL request.
+        if self.config.usable_region_size < requested_size {
+            self.config.usable_region_size =
+                requested_size.next_multiple_of(usize_to_u64(self.slot_size));
+            debug!(
+                "virtio-mem: Updated usable size to {} bytes",
+                self.config.usable_region_size
+            );
+        }
+
+        self.config.requested_size = requested_size;
+        debug!(
+            "virtio-mem: Updated requested size to {} bytes",
+            requested_size
+        );
+        // TODO(virtio-mem): trigger interrupt once we add handling for the requests
+        // self.interrupt_trigger()
+        //     .trigger(VirtioInterruptType::Config)
+        //     .map_err(VirtioMemError::InterruptError)
+        Ok(())
+    }
 }

 impl VirtioDevice for VirtioMem {
diff --git a/src/vmm/src/lib.rs b/src/vmm/src/lib.rs
index 0b6fee2e0a0..79e26c706a1 100644
--- a/src/vmm/src/lib.rs
+++ b/src/vmm/src/lib.rs
@@ -609,6 +609,15 @@ impl Vmm {
             .map_err(VmmError::FindDeviceError)
     }

+    /// Updates the requested size of the memory hotplug device.
+    pub fn update_memory_hotplug_size(&self, requested_size_mib: usize) -> Result<(), VmmError> {
+        self.device_manager
+            .try_with_virtio_device_with_id(VIRTIO_MEM_DEV_ID, |dev: &mut VirtioMem| {
+                dev.update_requested_size(requested_size_mib)
+            })
+            .map_err(VmmError::FindDeviceError)
+    }
+
     /// Signals Vmm to stop and exit.
     pub fn stop(&mut self, exit_code: FcExitCode) {
         // To avoid cycles, all teardown paths take the following route:
diff --git a/src/vmm/src/logger/metrics.rs b/src/vmm/src/logger/metrics.rs
index c983a5a9f16..060a751562a 100644
--- a/src/vmm/src/logger/metrics.rs
+++ b/src/vmm/src/logger/metrics.rs
@@ -479,6 +479,10 @@ pub struct PatchRequestsMetrics {
     pub mmds_count: SharedIncMetric,
     /// Number of failures in PATCHing an mmds.
     pub mmds_fails: SharedIncMetric,
+    /// Number of PATCHes to /hotplug/memory
+    pub hotplug_memory_count: SharedIncMetric,
+    /// Number of failed PATCHes to /hotplug/memory
+    pub hotplug_memory_fails: SharedIncMetric,
 }
 impl PatchRequestsMetrics {
     /// Const default construction.
@@ -492,6 +496,8 @@ impl PatchRequestsMetrics {
             machine_cfg_fails: SharedIncMetric::new(),
             mmds_count: SharedIncMetric::new(),
             mmds_fails: SharedIncMetric::new(),
+            hotplug_memory_count: SharedIncMetric::new(),
+            hotplug_memory_fails: SharedIncMetric::new(),
         }
     }
 }
diff --git a/src/vmm/src/rpc_interface.rs b/src/vmm/src/rpc_interface.rs
index fe3e9c296e7..d25ddb735a8 100644
--- a/src/vmm/src/rpc_interface.rs
+++ b/src/vmm/src/rpc_interface.rs
@@ -29,7 +29,9 @@ use crate::vmm_config::drive::{BlockDeviceConfig, BlockDeviceUpdateConfig, Drive
 use crate::vmm_config::entropy::{EntropyDeviceConfig, EntropyDeviceError};
 use crate::vmm_config::instance_info::InstanceInfo;
 use crate::vmm_config::machine_config::{MachineConfig, MachineConfigError, MachineConfigUpdate};
-use crate::vmm_config::memory_hotplug::{MemoryHotplugConfig, MemoryHotplugConfigError};
+use crate::vmm_config::memory_hotplug::{
+    MemoryHotplugConfig, MemoryHotplugConfigError, MemoryHotplugSizeUpdate,
+};
 use crate::vmm_config::metrics::{MetricsConfig, MetricsConfigError};
 use crate::vmm_config::mmds::{MmdsConfig, MmdsConfigError};
 use crate::vmm_config::net::{
@@ -113,6 +115,9 @@ pub enum VmmAction {
     /// Set the memory hotplug device using `MemoryHotplugConfig` as input. This action can only be
     /// called before the microVM has booted.
     SetMemoryHotplugDevice(MemoryHotplugConfig),
+    /// Updates the memory hotplug device using `MemoryHotplugSizeUpdate` as input. This action
+    /// can only be called after the microVM has booted.
+    UpdateMemoryHotplugSize(MemoryHotplugSizeUpdate),
     /// Launch the microVM. This action can only be called before the microVM has booted.
     StartMicroVm,
     /// Send CTRL+ALT+DEL to the microVM, using the i8042 keyboard function. If an AT-keyboard
@@ -152,6 +157,8 @@ pub enum VmmActionError {
     EntropyDevice(#[from] EntropyDeviceError),
     /// Memory hotplug config error: {0}
     MemoryHotplugConfig(#[from] MemoryHotplugConfigError),
+    /// Memory hotplug update error: {0}
+    MemoryHotplugUpdate(VmmError),
     /// Internal VMM error: {0}
     InternalVmm(#[from] VmmError),
     /// Load snapshot error: {0}
@@ -469,6 +476,7 @@ impl<'a> PrebootApiController<'a> {
             | UpdateBalloon(_)
             | UpdateBalloonStatistics(_)
             | UpdateBlockDevice(_)
+            | UpdateMemoryHotplugSize(_)
             | UpdateNetworkInterface(_) => Err(VmmActionError::OperationNotSupportedPreBoot),
             #[cfg(target_arch = "x86_64")]
             SendCtrlAltDel => Err(VmmActionError::OperationNotSupportedPreBoot),
@@ -709,7 +717,13 @@ impl RuntimeApiController {
                 .map_err(VmmActionError::BalloonUpdate),
             UpdateBlockDevice(new_cfg) => self.update_block_device(new_cfg),
             UpdateNetworkInterface(netif_update) => self.update_net_rate_limiters(netif_update),
-
+            UpdateMemoryHotplugSize(cfg) => self
+                .vmm
+                .lock()
+                .expect("Poisoned lock")
+                .update_memory_hotplug_size(cfg.requested_size_mib)
+                .map(|_| VmmData::Empty)
+                .map_err(VmmActionError::MemoryHotplugUpdate),
             // Operations not allowed post-boot.
             ConfigureBootSource(_)
             | ConfigureLogger(_)
@@ -1181,6 +1195,11 @@ mod tests {
         )));
         #[cfg(target_arch = "x86_64")]
         check_unsupported(preboot_request(VmmAction::SendCtrlAltDel));
+        check_unsupported(preboot_request(VmmAction::UpdateMemoryHotplugSize(
+            MemoryHotplugSizeUpdate {
+                requested_size_mib: 0,
+            },
+        )));
     }

     fn runtime_request(request: VmmAction) -> Result<VmmData, VmmActionError> {
diff --git a/src/vmm/src/vmm_config/memory_hotplug.rs b/src/vmm/src/vmm_config/memory_hotplug.rs
index d09141c1b66..85cf45ee5e8 100644
--- a/src/vmm/src/vmm_config/memory_hotplug.rs
+++ b/src/vmm/src/vmm_config/memory_hotplug.rs
@@ -86,6 +86,14 @@ impl MemoryHotplugConfig {
     }
 }

+/// An update to the requested size of the hotpluggable memory.
+#[derive(Clone, Debug, Default, PartialEq, Eq, Deserialize, Serialize)]
+#[serde(deny_unknown_fields)]
+pub struct MemoryHotplugSizeUpdate {
+    /// Requested size in MiB to resize the hotpluggable memory to.
+    pub requested_size_mib: usize,
+}
+
 #[cfg(test)]
 mod tests {
     use serde_json;

From f9ea479d2d0ba82e357427f3a933e6aa3c2c6f21 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 16:00:07 +0100
Subject: [PATCH 03/14] doc(virtio-mem): document PATCH API in swagger and docs

Add an entry for the PATCH API in Swagger and in the docs.

Signed-off-by: Riccardo Mancini
---
 docs/device-api.md                       |  1 +
 src/firecracker/swagger/firecracker.yaml | 29 ++++++++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/docs/device-api.md b/docs/device-api.md
index 01e470f1d64..0b8035651f8 100644
--- a/docs/device-api.md
+++ b/docs/device-api.md
@@ -104,6 +104,7 @@ specification:
 | `MemoryHotplugConfig`     | total_size_mib     | O | O | O | O | O | O | O | **R** |
 |                           | slot_size_mib      | O | O | O | O | O | O | O | **R** |
 |                           | block_size_mib     | O | O | O | O | O | O | O | **R** |
+| `MemoryHotplugSizeUpdate` | requested_size_mib | O | O | O | O | O | O | O | **R** |

 \* `Drive`'s `drive_id`, `is_root_device` and `partuuid` can be configured by either virtio-block
 or vhost-user-block devices.
diff --git a/src/firecracker/swagger/firecracker.yaml b/src/firecracker/swagger/firecracker.yaml
index c5011a79fd7..bddc4942c2a 100644
--- a/src/firecracker/swagger/firecracker.yaml
+++ b/src/firecracker/swagger/firecracker.yaml
@@ -549,6 +549,26 @@ paths:
           description: Internal server error
           schema:
             $ref: "#/definitions/Error"
+    patch:
+      summary: Updates the size of the hotpluggable memory region
+      operationId: patchMemoryHotplug
+      description:
+        Updates the size of the hotpluggable memory region. The guest will plug and unplug memory
+        blocks to reach the requested size.
+      parameters:
+        - name: body
+          in: body
+          description: Hotpluggable memory size update
+          required: true
+          schema:
+            $ref: "#/definitions/MemoryHotplugSizeUpdate"
+      responses:
+        204:
+          description: Hotpluggable memory size updated
+        default:
+          description: Internal server error
+          schema:
+            $ref: "#/definitions/Error"
     get:
       summary: Retrieves the status of the hotpluggable memory
       operationId: getMemoryHotplug
@@ -1422,6 +1442,15 @@ definitions:
       description:
         (Logical) Block size for the hotpluggable memory in MiB. This will determine the logical
        granularity of hot-plug memory for the guest. Refer to the device documentation on how to
        tune this value.
+  MemoryHotplugSizeUpdate:
+    type: object
+    description:
+      An update to the size of the hotpluggable memory region.
+    properties:
+      requested_size_mib:
+        type: integer
+        description: New target region size in MiB.
+ MemoryHotplugStatus: type: object description: From 5c1c0f5cee294f767cd19d0375516e443737e8c1 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Wed, 24 Sep 2025 18:03:00 +0100 Subject: [PATCH 04/14] test(virtio-mem): add API tests for PATCH Test that the new PATCH API behaves as expected. Also updates expected metrics and fixes memory monitor to account for hotplugging. Signed-off-by: Riccardo Mancini --- tests/host_tools/fcmetrics.py | 2 ++ tests/host_tools/memory.py | 8 ++++++-- tests/integration_tests/functional/test_api.py | 14 ++++++++++++-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/tests/host_tools/fcmetrics.py b/tests/host_tools/fcmetrics.py index 0dcff5eed00..6b110a70f96 100644 --- a/tests/host_tools/fcmetrics.py +++ b/tests/host_tools/fcmetrics.py @@ -202,6 +202,8 @@ def validate_fc_metrics(metrics): "machine_cfg_fails", "mmds_count", "mmds_fails", + "hotplug_memory_count", + "hotplug_memory_fails", ], "put_api_requests": [ "actions_count", diff --git a/tests/host_tools/memory.py b/tests/host_tools/memory.py index 134147724cd..c09dae0d206 100644 --- a/tests/host_tools/memory.py +++ b/tests/host_tools/memory.py @@ -99,7 +99,9 @@ def is_guest_mem_x86(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on x86_64 physical memory layout """ - return size in ( + # it could be bigger if hotplugging is enabled + # if it's bigger, it's likely not from FC because we don't have big allocations + return size >= guest_mem_bytes or size in ( # memory fits before the first gap guest_mem_bytes, # guest memory spans at least two regions & memory fits before the second gap @@ -121,7 +123,9 @@ def is_guest_mem_arch64(self, size, guest_mem_bytes): Checks if a region is a guest memory region based on ARM64 physical memory layout """ - return size in ( + # it could be bigger if hotplugging is enabled + # if it's bigger, it's likely not from FC because we don't have big allocations + return size >= guest_mem_bytes or size in ( # guest memory fits before the gap guest_mem_bytes, # guest memory fills the space before the gap diff --git a/tests/integration_tests/functional/test_api.py b/tests/integration_tests/functional/test_api.py index 81454559990..ac929941dca 100644 --- a/tests/integration_tests/functional/test_api.py +++ b/tests/integration_tests/functional/test_api.py @@ -981,13 +981,14 @@ def test_api_entropy(uvm_plain): test_microvm.api.entropy.put() -def test_api_memory_hotplug(uvm_plain): +def test_api_memory_hotplug(uvm_plain_6_1): """ Test hotplug related API commands. """ - test_microvm = uvm_plain + test_microvm = uvm_plain_6_1 test_microvm.spawn() test_microvm.basic_config() + test_microvm.add_net_iface() # Adding hotplug memory region should be OK. 
test_microvm.api.memory_hotplug.put(
@@ -1002,6 +1003,10 @@ def test_api_memory_hotplug(uvm_plain):
     with pytest.raises(AssertionError):
         test_microvm.api.memory_hotplug.get()

+    # Patch API should be rejected before boot
+    with pytest.raises(RuntimeError, match=NOT_SUPPORTED_BEFORE_START):
+        test_microvm.api.memory_hotplug.patch(requested_size_mib=512)
+
     # Start the microvm
     test_microvm.start()

@@ -1013,6 +1018,11 @@ def test_api_memory_hotplug(uvm_plain):
     status = test_microvm.api.memory_hotplug.get().json()
     assert status["total_size_mib"] == 1024

+    # Patch API should work after boot
+    test_microvm.api.memory_hotplug.patch(requested_size_mib=512)
+    status = test_microvm.api.memory_hotplug.get().json()
+    assert status["requested_size_mib"] == 512
+

 def test_api_balloon(uvm_nano):
     """

From d5f594af64c28942e3c71e5a81093e4a7ca25a68 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 18:22:27 +0100
Subject: [PATCH 05/14] feat(virtio-mem): add virtio request parsing and dummy response

Parse virtio requests over the queue and always ack them.

Following commits will add the state management inside the device.

Signed-off-by: Riccardo Mancini
---
 src/vmm/src/devices/virtio/mem/device.rs  | 168 ++++++++++++++++++++--
 src/vmm/src/devices/virtio/mem/metrics.rs |  24 ++++
 src/vmm/src/devices/virtio/mem/mod.rs     |   1 +
 src/vmm/src/devices/virtio/mem/request.rs | 119 +++++++++++++++
 4 files changed, 303 insertions(+), 9 deletions(-)
 create mode 100644 src/vmm/src/devices/virtio/mem/request.rs

diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
index b0855820cf0..aff896df45b 100644
--- a/src/vmm/src/devices/virtio/mem/device.rs
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -9,7 +9,7 @@ use std::sync::atomic::AtomicU32;
 use log::info;
 use serde::{Deserialize, Serialize};
 use vm_memory::{
-    Address, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize,
+    Address, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestMemoryRegion, GuestUsize,
 };
 use vmm_sys_util::eventfd::EventFd;

@@ -20,12 +20,15 @@ use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice};
 use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
 use crate::devices::virtio::generated::virtio_ids::VIRTIO_ID_MEM;
 use crate::devices::virtio::generated::virtio_mem::{
-    VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config,
+    self, VIRTIO_MEM_F_UNPLUGGED_INACCESSIBLE, virtio_mem_config,
 };
 use crate::devices::virtio::iov_deque::IovDequeError;
 use crate::devices::virtio::mem::metrics::METRICS;
+use crate::devices::virtio::mem::request::{BlockRangeState, Request, RequestedRange, Response};
 use crate::devices::virtio::mem::{VIRTIO_MEM_DEV_ID, VIRTIO_MEM_GUEST_ADDRESS};
-use crate::devices::virtio::queue::{FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue};
+use crate::devices::virtio::queue::{
+    DescriptorChain, FIRECRACKER_MAX_QUEUE_SIZE, InvalidAvailIdx, Queue, QueueError,
+};
 use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
 use crate::logger::{IncMetric, debug, error};
 use crate::utils::{bytes_to_mib, mib_to_bytes, u64_to_usize, usize_to_u64};
@@ -47,6 +50,24 @@ pub enum VirtioMemError {
     InvalidSize(u64),
     /// Device is not active
     DeviceNotActive,
+    /// Descriptor is write-only
+    UnexpectedWriteOnlyDescriptor,
+    /// Error writing virtio descriptor
+    DescriptorWriteFailed,
+    /// Error reading virtio descriptor
+    DescriptorReadFailed,
+    /// Unknown request type: {0}
+    UnknownRequestType(u32),
+ 
/// Descriptor chain is too short
+    DescriptorChainTooShort,
+    /// Descriptor is too small
+    DescriptorLengthTooSmall,
+    /// Descriptor is read-only
+    UnexpectedReadOnlyDescriptor,
+    /// Error popping from virtio queue: {0}
+    InvalidAvailIdx(#[from] InvalidAvailIdx),
+    /// Error adding used queue: {0}
+    QueueError(#[from] QueueError),
 }

 #[derive(Debug)]
@@ -170,8 +191,139 @@ impl VirtioMem {
             .map_err(VirtioMemError::InterruptError)
     }

+    fn guest_memory(&self) -> &GuestMemoryMmap {
+        &self.device_state.active_state().unwrap().mem
+    }
+
+    fn parse_request(
+        &self,
+        avail_desc: &DescriptorChain,
+    ) -> Result<(Request, GuestAddress, u16), VirtioMemError> {
+        // The head contains the request type which MUST be readable.
+        if avail_desc.is_write_only() {
+            return Err(VirtioMemError::UnexpectedWriteOnlyDescriptor);
+        }
+
+        if (avail_desc.len as usize) < size_of::<virtio_mem::virtio_mem_req>() {
+            return Err(VirtioMemError::DescriptorLengthTooSmall);
+        }
+
+        let request: virtio_mem::virtio_mem_req = self
+            .guest_memory()
+            .read_obj(avail_desc.addr)
+            .map_err(|_| VirtioMemError::DescriptorReadFailed)?;
+
+        let resp_desc = avail_desc
+            .next_descriptor()
+            .ok_or(VirtioMemError::DescriptorChainTooShort)?;
+
+        // The response MUST always be writable.
+        if !resp_desc.is_write_only() {
+            return Err(VirtioMemError::UnexpectedReadOnlyDescriptor);
+        }
+
+        if (resp_desc.len as usize) < std::mem::size_of::<virtio_mem::virtio_mem_resp>() {
+            return Err(VirtioMemError::DescriptorLengthTooSmall);
+        }
+
+        Ok((request.into(), resp_desc.addr, avail_desc.index))
+    }
+
+    fn write_response(
+        &mut self,
+        resp: Response,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        debug!("virtio-mem: Response: {:?}", resp);
+        self.guest_memory()
+            .write_obj(virtio_mem::virtio_mem_resp::from(resp), resp_addr)
+            .map_err(|_| VirtioMemError::DescriptorWriteFailed)
+            .map(|_| size_of::<virtio_mem::virtio_mem_resp>())?;
+        self.queues[MEM_QUEUE]
+            .add_used(
+                used_idx,
+                u32::try_from(std::mem::size_of::<virtio_mem::virtio_mem_resp>()).unwrap(),
+            )
+            .map_err(VirtioMemError::QueueError)
+    }
+
+    fn handle_plug_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.plug_count.inc();
+        let _metric = METRICS.plug_agg.record_latency_metrics();
+
+        // TODO: implement PLUG request
+        let response = Response::ack();
+        self.write_response(response, resp_addr, used_idx)
+    }
+
+    fn handle_unplug_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.unplug_count.inc();
+        let _metric = METRICS.unplug_agg.record_latency_metrics();
+
+        // TODO: implement UNPLUG request
+        let response = Response::ack();
+        self.write_response(response, resp_addr, used_idx)
+    }
+
+    fn handle_unplug_all_request(
+        &mut self,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.unplug_all_count.inc();
+        let _metric = METRICS.unplug_all_agg.record_latency_metrics();
+
+        // TODO: implement UNPLUG ALL request
+        let response = Response::ack();
+        self.write_response(response, resp_addr, used_idx)
+    }
+
+    fn handle_state_request(
+        &mut self,
+        range: &RequestedRange,
+        resp_addr: GuestAddress,
+        used_idx: u16,
+    ) -> Result<(), VirtioMemError> {
+        METRICS.state_count.inc();
+        let _metric = METRICS.state_agg.record_latency_metrics();
+
+        // TODO: implement STATE request
+        let response = Response::ack_with_state(BlockRangeState::Mixed);
+        self.write_response(response, resp_addr, used_idx)
+    }
+
     fn process_mem_queue(&mut self) -> Result<(),
VirtioMemError> { - info!("TODO: Received mem queue event, but it's not implemented."); + while let Some(desc) = self.queues[MEM_QUEUE].pop()? { + let index = desc.index; + + let (req, resp_addr, used_idx) = self.parse_request(&desc)?; + debug!("virtio-mem: Request: {:?}", req); + // Handle request and write response + match req { + Request::State(ref range) => self.handle_state_request(range, resp_addr, used_idx), + Request::Plug(ref range) => self.handle_plug_request(range, resp_addr, used_idx), + Request::Unplug(ref range) => { + self.handle_unplug_request(range, resp_addr, used_idx) + } + Request::UnplugAll => self.handle_unplug_all_request(resp_addr, used_idx), + Request::Unsupported(t) => Err(VirtioMemError::UnknownRequestType(t)), + }?; + } + + self.queues[MEM_QUEUE].advance_used_ring_idx(); + self.signal_used_queue()?; + Ok(()) } @@ -240,11 +392,9 @@ impl VirtioMem { "virtio-mem: Updated requested size to {} bytes", requested_size ); - // TODO(virtio-mem): trigger interrupt once we add handling for the requests - // self.interrupt_trigger() - // .trigger(VirtioInterruptType::Config) - // .map_err(VirtioMemError::InterruptError) - Ok(()) + self.interrupt_trigger() + .trigger(VirtioInterruptType::Config) + .map_err(VirtioMemError::InterruptError) } } diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs index 443e9a8b8f1..e9e97707782 100644 --- a/src/vmm/src/devices/virtio/mem/metrics.rs +++ b/src/vmm/src/devices/virtio/mem/metrics.rs @@ -45,6 +45,22 @@ pub(super) struct VirtioMemDeviceMetrics { pub queue_event_fails: SharedIncMetric, /// Number of queue events handled pub queue_event_count: SharedIncMetric, + /// Latency of Plug operations + pub plug_agg: LatencyAggregateMetrics, + /// Number of Plug operations + pub plug_count: SharedIncMetric, + /// Latency of Unplug operations + pub unplug_agg: LatencyAggregateMetrics, + /// Number of Unplug operations + pub unplug_count: SharedIncMetric, + /// Latency of UnplugAll operations + pub unplug_all_agg: LatencyAggregateMetrics, + /// Number of UnplugAll operations + pub unplug_all_count: SharedIncMetric, + /// Latency of State operations + pub state_agg: LatencyAggregateMetrics, + /// Number of State operations + pub state_count: SharedIncMetric, } impl VirtioMemDeviceMetrics { @@ -54,6 +70,14 @@ impl VirtioMemDeviceMetrics { activate_fails: SharedIncMetric::new(), queue_event_fails: SharedIncMetric::new(), queue_event_count: SharedIncMetric::new(), + plug_agg: LatencyAggregateMetrics::new(), + plug_count: SharedIncMetric::new(), + unplug_agg: LatencyAggregateMetrics::new(), + unplug_count: SharedIncMetric::new(), + unplug_all_agg: LatencyAggregateMetrics::new(), + unplug_all_count: SharedIncMetric::new(), + state_agg: LatencyAggregateMetrics::new(), + state_count: SharedIncMetric::new(), } } } diff --git a/src/vmm/src/devices/virtio/mem/mod.rs b/src/vmm/src/devices/virtio/mem/mod.rs index 1c9e98f98a6..5c76afc4c24 100644 --- a/src/vmm/src/devices/virtio/mem/mod.rs +++ b/src/vmm/src/devices/virtio/mem/mod.rs @@ -5,6 +5,7 @@ mod device; mod event_handler; pub mod metrics; pub mod persist; +mod request; use vm_memory::GuestAddress; diff --git a/src/vmm/src/devices/virtio/mem/request.rs b/src/vmm/src/devices/virtio/mem/request.rs new file mode 100644 index 00000000000..c68220bdaeb --- /dev/null +++ b/src/vmm/src/devices/virtio/mem/request.rs @@ -0,0 +1,119 @@ +// Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+use vm_memory::{ByteValued, GuestAddress};
+
+use crate::devices::virtio::generated::virtio_mem;
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct RequestedRange {
+    pub(crate) addr: GuestAddress,
+    pub(crate) nb_blocks: usize,
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub(crate) enum Request {
+    Plug(RequestedRange),
+    Unplug(RequestedRange),
+    UnplugAll,
+    State(RequestedRange),
+    Unsupported(u32),
+}
+
+// SAFETY: virtio_mem_req is a bindgen-generated plain-old-data struct made of
+// integer fields; any byte pattern is a valid value for it.
+unsafe impl ByteValued for virtio_mem::virtio_mem_req {}
+
+impl From<virtio_mem::virtio_mem_req> for Request {
+    fn from(req: virtio_mem::virtio_mem_req) -> Self {
+        match req.type_.into() {
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_PLUG => unsafe {
+                Request::Plug(RequestedRange {
+                    addr: GuestAddress(req.u.plug.addr),
+                    nb_blocks: req.u.plug.nb_blocks.into(),
+                })
+            },
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_UNPLUG => unsafe {
+                Request::Unplug(RequestedRange {
+                    addr: GuestAddress(req.u.unplug.addr),
+                    nb_blocks: req.u.unplug.nb_blocks.into(),
+                })
+            },
+            virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL => Request::UnplugAll,
+            // SAFETY: union type is checked in the match
+            virtio_mem::VIRTIO_MEM_REQ_STATE => unsafe {
+                Request::State(RequestedRange {
+                    addr: GuestAddress(req.u.state.addr),
+                    nb_blocks: req.u.state.nb_blocks.into(),
+                })
+            },
+            t => Request::Unsupported(t),
+        }
+    }
+}
+
+#[repr(u16)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+#[allow(clippy::cast_possible_truncation)]
+pub enum ResponseType {
+    Ack = virtio_mem::VIRTIO_MEM_RESP_ACK as u16,
+    Nack = virtio_mem::VIRTIO_MEM_RESP_NACK as u16,
+    Busy = virtio_mem::VIRTIO_MEM_RESP_BUSY as u16,
+    Error = virtio_mem::VIRTIO_MEM_RESP_ERROR as u16,
+}
+
+#[repr(u16)]
+#[derive(Debug, Clone, Copy, Eq, PartialEq)]
+#[allow(clippy::cast_possible_truncation)]
+pub enum BlockRangeState {
+    Plugged = virtio_mem::VIRTIO_MEM_STATE_PLUGGED as u16,
+    Unplugged = virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED as u16,
+    Mixed = virtio_mem::VIRTIO_MEM_STATE_MIXED as u16,
+}
+
+#[derive(Debug, Clone)]
+pub struct Response {
+    pub resp_type: ResponseType,
+    // Only for State requests
+    pub state: Option<BlockRangeState>,
+}
+
+impl Response {
+    pub(crate) fn error() -> Self {
+        Response {
+            resp_type: ResponseType::Error,
+            state: None,
+        }
+    }
+
+    pub(crate) fn ack() -> Self {
+        Response {
+            resp_type: ResponseType::Ack,
+            state: None,
+        }
+    }
+
+    pub(crate) fn ack_with_state(state: BlockRangeState) -> Self {
+        Response {
+            resp_type: ResponseType::Ack,
+            state: Some(state),
+        }
+    }
+}
+
+// SAFETY: Plain data structures
+unsafe impl ByteValued for virtio_mem::virtio_mem_resp {}
+
+impl From<Response> for virtio_mem::virtio_mem_resp {
+    fn from(resp: Response) -> Self {
+        let mut out = virtio_mem::virtio_mem_resp {
+            type_: resp.resp_type as u16,
+            ..Default::default()
+        };
+        if let Some(state) = resp.state {
+            out.u.state.state = state as u16;
+        }
+        out
+    }
+}

From f72864eb55b332669b1bc35800638d7600a6c1ce Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Wed, 24 Sep 2025 18:31:09 +0100
Subject: [PATCH 06/14] feat(virtio-mem): implement virtio requests

This commit adds block state management and implements the virtio
requests for the virtio-mem device.

Block state is tracked using a BitVec, each bit representing a single
block. Plug/Unplug requests are validated before being executed to
verify the range is valid (aligned and within range), and that all
blocks in range are unplugged/plugged, as per the virtio spec.
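To illustrate the bookkeeping (a simplified sketch, not the device code
itself), the per-block bitmap makes deciding the state of a requested
range a matter of counting set bits, mirroring the device's range_state
logic:

    use bitvec::vec::BitVec;

    // Hypothetical standalone helper; the device version operates on
    // self.plugged_blocks and returns a BlockRangeState instead.
    fn range_state(plugged: &BitVec, first_block: usize, nb_blocks: usize) -> &'static str {
        // Count how many blocks in the requested range are currently plugged.
        let ones = plugged[first_block..first_block + nb_blocks].count_ones();
        match ones {
            0 => "UNPLUGGED",
            n if n == nb_blocks => "PLUGGED",
            _ => "MIXED",
        }
    }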
UnplugAll is the only request where usable_region_size can be lowered.

This commit is missing the dynamic KVM slot management which will be
added later.

Signed-off-by: Riccardo Mancini
---
 Cargo.lock                                |  41 ++++
 src/vmm/Cargo.toml                        |   1 +
 src/vmm/src/devices/virtio/mem/device.rs  | 217 ++++++++++++++++++++--
 src/vmm/src/devices/virtio/mem/metrics.rs |  21 +++
 src/vmm/src/devices/virtio/mem/persist.rs |  19 +-
 5 files changed, 277 insertions(+), 22 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 0ea518f3a49..5a66c51b66e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -248,6 +248,19 @@ version = "2.9.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2261d10cca569e4643e526d8dc2e62e433cc8aba21ab764233731f8d369bf394"

+[[package]]
+name = "bitvec"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c"
+dependencies = [
+ "funty",
+ "radium",
+ "serde",
+ "tap",
+ "wyz",
+]
+
 [[package]]
 name = "bumpalo"
 version = "3.19.0"
@@ -648,6 +661,12 @@ version = "1.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"

+[[package]]
+name = "funty"
+version = "2.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
+
 [[package]]
 name = "gdbstub"
 version = "0.7.7"
@@ -1140,6 +1159,12 @@ version = "5.3.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"

+[[package]]
+name = "radium"
+version = "0.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
+
 [[package]]
 name = "rand"
 version = "0.9.2"
@@ -1414,6 +1439,12 @@ dependencies = [
 "unicode-ident",
]

+[[package]]
+name = "tap"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
+
 [[package]]
 name = "thiserror"
 version = "1.0.69"
@@ -1682,6 +1713,7 @@ dependencies = [
  "base64",
  "bincode",
  "bitflags 2.9.4",
+ "bitvec",
  "byteorder",
  "crc64",
  "criterion",
@@ -2008,6 +2040,15 @@ version = "0.46.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59"

+[[package]]
+name = "wyz"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed"
+dependencies = [
+ "tap",
+]
+
 [[package]]
 name = "zerocopy"
 version = "0.8.27"
diff --git a/src/vmm/Cargo.toml b/src/vmm/Cargo.toml
index d2601f4a305..aa6f9a2fc47 100644
--- a/src/vmm/Cargo.toml
+++ b/src/vmm/Cargo.toml
@@ -23,6 +23,7 @@ aws-lc-rs = { version = "1.14.0", features = ["bindgen"] }
 base64 = "0.22.1"
 bincode = { version = "2.0.1", features = ["serde"] }
 bitflags = "2.9.4"
+bitvec = { version = "1.0.1", features = ["atomic", "serde"] }
 byteorder = "1.5.0"
 crc64 = "2.0.0"
 derive_more = { version = "2.0.1", default-features = false, features = [
diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
index aff896df45b..2a837c44bcb 100644
--- a/src/vmm/src/devices/virtio/mem/device.rs
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -2,10 +2,11 @@
 // SPDX-License-Identifier:
Apache-2.0

 use std::io;
-use std::ops::Deref;
+use std::ops::{Deref, Range};
 use std::sync::Arc;
 use std::sync::atomic::AtomicU32;

+use bitvec::vec::BitVec;
 use log::info;
 use serde::{Deserialize, Serialize};
 use vm_memory::{
@@ -14,7 +15,6 @@ use vm_memory::{
 use vmm_sys_util::eventfd::EventFd;

 use super::{MEM_NUM_QUEUES, MEM_QUEUE};
-use crate::devices::DeviceError;
 use crate::devices::virtio::ActivateError;
 use crate::devices::virtio::device::{ActiveState, DeviceState, VirtioDevice};
 use crate::devices::virtio::generated::virtio_config::VIRTIO_F_VERSION_1;
@@ -33,7 +33,7 @@ use crate::devices::virtio::transport::{VirtioInterrupt, VirtioInterruptType};
 use crate::logger::{IncMetric, debug, error};
 use crate::utils::{bytes_to_mib, mib_to_bytes, u64_to_usize, usize_to_u64};
 use crate::vstate::interrupts::InterruptError;
-use crate::vstate::memory::{ByteValued, GuestMemoryMmap, GuestRegionMmap};
+use crate::vstate::memory::{ByteValued, GuestMemoryExtension, GuestMemoryMmap, GuestRegionMmap};
 use crate::vstate::vm::VmError;
 use crate::{Vm, impl_device_type};

@@ -68,6 +68,14 @@ pub enum VirtioMemError {
     InvalidAvailIdx(#[from] InvalidAvailIdx),
     /// Error adding used queue: {0}
     QueueError(#[from] QueueError),
+    /// Invalid requested range: {0:?}.
+    InvalidRange(RequestedRange),
+    /// The requested range cannot be plugged because it's {0:?}.
+    PlugRequestBlockStateInvalid(BlockRangeState),
+    /// Plug request rejected as plugged_size would be greater than requested_size
+    PlugRequestIsTooBig,
+    /// The requested range cannot be unplugged because it's {0:?}.
+    UnplugRequestBlockStateInvalid(BlockRangeState),
 }

 #[derive(Debug)]
@@ -85,6 +93,8 @@ pub struct VirtioMem {
     // Device specific fields
     pub(crate) config: virtio_mem_config,
     pub(crate) slot_size: usize,
+    // Bitmap to track which blocks are plugged
+    pub(crate) plugged_blocks: BitVec,
     vm: Arc<Vm>,
 }
@@ -118,8 +128,15 @@ impl VirtioMem {
         block_size: mib_to_bytes(block_size_mib) as u64,
         ..Default::default()
     };
+        let plugged_blocks = BitVec::repeat(false, total_size_mib / block_size_mib);

-        Self::from_state(vm, queues, config, mib_to_bytes(slot_size_mib))
+        Self::from_state(
+            vm,
+            queues,
+            config,
+            mib_to_bytes(slot_size_mib),
+            plugged_blocks,
+        )
     }

     pub fn from_state(
@@ -127,6 +144,7 @@ impl VirtioMem {
         queues: Vec<Queue>,
         config: virtio_mem_config,
         slot_size: usize,
+        plugged_blocks: BitVec,
     ) -> Result<Self, VirtioMemError> {
         let activate_event = EventFd::new(libc::EFD_NONBLOCK)?;
         let queue_events = (0..MEM_NUM_QUEUES)
@@ -143,6 +161,7 @@ impl VirtioMem {
             config,
             vm,
             slot_size,
+            plugged_blocks,
         })
     }

@@ -150,6 +169,10 @@ impl VirtioMem {
         VIRTIO_MEM_DEV_ID
     }

+    pub fn guest_address(&self) -> GuestAddress {
+        GuestAddress(self.config.addr)
+    }
+
     /// Gets the total hotpluggable size.
     pub fn total_size_mib(&self) -> usize {
         bytes_to_mib(u64_to_usize(self.config.region_size))
@@ -195,6 +218,24 @@ impl VirtioMem {
         &self.device_state.active_state().unwrap().mem
     }

+    fn nb_blocks_to_len(&self, nb_blocks: usize) -> usize {
+        nb_blocks * u64_to_usize(self.config.block_size)
+    }
+
+    /// Returns the state of all the blocks in the given range.
+    ///
+    /// Note: the range passed to this function must be within the device memory to avoid
+    /// out-of-bound panics.
+    fn range_state(&self, range: &RequestedRange) -> BlockRangeState {
+        let plugged_count = self.plugged_blocks[self.unchecked_block_range(range)].count_ones();
+
+        match plugged_count {
+            nb_blocks if nb_blocks == range.nb_blocks => BlockRangeState::Plugged,
+            0 => BlockRangeState::Unplugged,
+            _ => BlockRangeState::Mixed,
+        }
+    }
+
     fn parse_request(
         &self,
         avail_desc: &DescriptorChain,
@@ -248,6 +289,59 @@ impl VirtioMem {
             .map_err(VirtioMemError::QueueError)
     }

+    /// Checks that the range provided by the driver is within the usable memory region
+    fn validate_range(&self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        // Ensure the range is aligned
+        if !range
+            .addr
+            .raw_value()
+            .is_multiple_of(self.config.block_size)
+        {
+            return Err(VirtioMemError::InvalidRange(*range));
+        }
+
+        if range.nb_blocks == 0 {
+            return Err(VirtioMemError::InvalidRange(*range));
+        }
+
+        // Ensure the start addr is within the usable region
+        let start_off = range
+            .addr
+            .checked_offset_from(self.guest_address())
+            .filter(|&off| off < self.config.usable_region_size)
+            .ok_or(VirtioMemError::InvalidRange(*range))?;
+
+        // Ensure the end offset (exclusive) is within the usable region
+        let end_off = start_off
+            .checked_add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks)))
+            .filter(|&end_off| end_off <= self.config.usable_region_size)
+            .ok_or(VirtioMemError::InvalidRange(*range))?;
+
+        Ok(())
+    }
+
+    fn unchecked_block_range(&self, range: &RequestedRange) -> Range<usize> {
+        let start_block = u64_to_usize((range.addr.0 - self.config.addr) / self.config.block_size);
+
+        start_block..(start_block + range.nb_blocks)
+    }
+
+    fn process_plug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        self.validate_range(range)?;
+
+        if self.config.plugged_size + usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))
+            > self.config.requested_size
+        {
+            return Err(VirtioMemError::PlugRequestIsTooBig);
+        }
+
+        match self.range_state(range) {
+            // the range was validated
+            BlockRangeState::Unplugged => self.update_range(range, true),
+            state => Err(VirtioMemError::PlugRequestBlockStateInvalid(state)),
+        }
+    }
+
     fn handle_plug_request(
         &mut self,
         range: &RequestedRange,
@@ -257,11 +351,32 @@ impl VirtioMem {
         METRICS.plug_count.inc();
         let _metric = METRICS.plug_agg.record_latency_metrics();

-        // TODO: implement PLUG request
-        let response = Response::ack();
+        let response = match self.process_plug_request(range) {
+            Err(err) => {
+                METRICS.plug_fails.inc();
+                error!("virtio-mem: Failed to plug range: {}", err);
+                Response::error()
+            }
+            Ok(_) => {
+                METRICS
+                    .plug_bytes
+                    .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks)));
+                Response::ack()
+            }
+        };
         self.write_response(response, resp_addr, used_idx)
     }

+    fn process_unplug_request(&mut self, range: &RequestedRange) -> Result<(), VirtioMemError> {
+        self.validate_range(range)?;
+
+        match self.range_state(range) {
+            // the range was validated
+            BlockRangeState::Plugged => self.update_range(range, false),
+            state => Err(VirtioMemError::UnplugRequestBlockStateInvalid(state)),
+        }
+    }
+
     fn handle_unplug_request(
         &mut self,
         range: &RequestedRange,
@@ -270,9 +385,19 @@ impl VirtioMem {
     ) -> Result<(), VirtioMemError> {
         METRICS.unplug_count.inc();
         let _metric = METRICS.unplug_agg.record_latency_metrics();
-
-        // TODO: implement UNPLUG request
-        let response = Response::ack();
+        let response = match self.process_unplug_request(range) {
+            Err(err) => {
+                METRICS.unplug_fails.inc();
+                error!("virtio-mem: Failed to unplug range: {}", err);
Response::error() + } + Ok(_) => { + METRICS + .unplug_bytes + .add(usize_to_u64(self.nb_blocks_to_len(range.nb_blocks))); + Response::ack() + } + }; self.write_response(response, resp_addr, used_idx) } @@ -283,9 +408,21 @@ impl VirtioMem { ) -> Result<(), VirtioMemError> { METRICS.unplug_all_count.inc(); let _metric = METRICS.unplug_all_agg.record_latency_metrics(); - - // TODO: implement UNPLUG ALL request - let response = Response::ack(); + let range = RequestedRange { + addr: self.guest_address(), + nb_blocks: self.plugged_blocks.len(), + }; + let response = match self.update_range(&range, false) { + Err(err) => { + METRICS.unplug_all_fails.inc(); + error!("virtio-mem: Failed to unplug all: {}", err); + Response::error() + } + Ok(_) => { + self.config.usable_region_size = 0; + Response::ack() + } + }; self.write_response(response, resp_addr, used_idx) } @@ -297,9 +434,15 @@ impl VirtioMem { ) -> Result<(), VirtioMemError> { METRICS.state_count.inc(); let _metric = METRICS.state_agg.record_latency_metrics(); - - // TODO: implement STATE request - let response = Response::ack_with_state(BlockRangeState::Mixed); + let response = match self.validate_range(range) { + Err(err) => { + METRICS.state_fails.inc(); + error!("virtio-mem: Failed to retrieve state of range: {}", err); + Response::error() + } + // the range was validated + Ok(_) => Response::ack_with_state(self.range_state(range)), + }; self.write_response(response, resp_addr, used_idx) } @@ -357,6 +500,37 @@ impl VirtioMem { &self.activate_event } + /// Plugs/unplugs the given range + /// + /// Note: the range passed to this function must be within the device memory to avoid + /// out-of-bound panics. + fn update_range(&mut self, range: &RequestedRange, plug: bool) -> Result<(), VirtioMemError> { + // Update internal state + let block_range = self.unchecked_block_range(range); + let plugged_blocks_slice = &mut self.plugged_blocks[block_range]; + let plugged_before = plugged_blocks_slice.count_ones(); + plugged_blocks_slice.fill(plug); + let plugged_after = plugged_blocks_slice.count_ones(); + self.config.plugged_size -= usize_to_u64(self.nb_blocks_to_len(plugged_before)); + self.config.plugged_size += usize_to_u64(self.nb_blocks_to_len(plugged_after)); + + // If unplugging, discard the range + if !plug { + self.guest_memory() + .discard_range(range.addr, self.nb_blocks_to_len(range.nb_blocks)) + .inspect_err(|err| { + // Failure to discard is not fatal and is not reported to the driver. It only + // gets logged. + METRICS.unplug_discard_fails.inc(); + error!("virtio-mem: Failed to discard memory range: {}", err); + }); + } + + // TODO: update KVM slots to plug/unplug them + + Ok(()) + } + /// Updates the requested size of the virtio-mem device. 
pub fn update_requested_size( &mut self, @@ -547,7 +721,18 @@ mod tests { usable_region_size, ..Default::default() }; - let mem = VirtioMem::from_state(vm, queues, config, mib_to_bytes(slot_size_mib)).unwrap(); + let plugged_blocks = BitVec::repeat( + false, + mib_to_bytes(region_size_mib) / mib_to_bytes(block_size_mib), + ); + let mem = VirtioMem::from_state( + vm, + queues, + config, + mib_to_bytes(slot_size_mib), + plugged_blocks, + ) + .unwrap(); assert_eq!(mem.total_size_mib(), region_size_mib); assert_eq!(mem.block_size_mib(), block_size_mib); assert_eq!(mem.slot_size_mib(), slot_size_mib); diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs index e9e97707782..f2a6f58b92d 100644 --- a/src/vmm/src/devices/virtio/mem/metrics.rs +++ b/src/vmm/src/devices/virtio/mem/metrics.rs @@ -49,18 +49,32 @@ pub(super) struct VirtioMemDeviceMetrics { pub plug_agg: LatencyAggregateMetrics, /// Number of Plug operations pub plug_count: SharedIncMetric, + /// Number of plugged bytes + pub plug_bytes: SharedIncMetric, + /// Number of Plug operations failed + pub plug_fails: SharedIncMetric, /// Latency of Unplug operations pub unplug_agg: LatencyAggregateMetrics, /// Number of Unplug operations pub unplug_count: SharedIncMetric, + /// Number of unplugged bytes + pub unplug_bytes: SharedIncMetric, + /// Number of Unplug operations failed + pub unplug_fails: SharedIncMetric, + /// Number of discards failed for an Unplug or UnplugAll operation + pub unplug_discard_fails: SharedIncMetric, /// Latency of UnplugAll operations pub unplug_all_agg: LatencyAggregateMetrics, /// Number of UnplugAll operations pub unplug_all_count: SharedIncMetric, + /// Number of UnplugAll operations failed + pub unplug_all_fails: SharedIncMetric, /// Latency of State operations pub state_agg: LatencyAggregateMetrics, /// Number of State operations pub state_count: SharedIncMetric, + /// Number of State operations failed + pub state_fails: SharedIncMetric, } impl VirtioMemDeviceMetrics { @@ -72,12 +86,19 @@ impl VirtioMemDeviceMetrics { queue_event_count: SharedIncMetric::new(), plug_agg: LatencyAggregateMetrics::new(), plug_count: SharedIncMetric::new(), + plug_bytes: SharedIncMetric::new(), + plug_fails: SharedIncMetric::new(), unplug_agg: LatencyAggregateMetrics::new(), unplug_count: SharedIncMetric::new(), + unplug_bytes: SharedIncMetric::new(), + unplug_fails: SharedIncMetric::new(), + unplug_discard_fails: SharedIncMetric::new(), unplug_all_agg: LatencyAggregateMetrics::new(), unplug_all_count: SharedIncMetric::new(), + unplug_all_fails: SharedIncMetric::new(), state_agg: LatencyAggregateMetrics::new(), state_count: SharedIncMetric::new(), + state_fails: SharedIncMetric::new(), } } } diff --git a/src/vmm/src/devices/virtio/mem/persist.rs b/src/vmm/src/devices/virtio/mem/persist.rs index e48246de500..09f41680e32 100644 --- a/src/vmm/src/devices/virtio/mem/persist.rs +++ b/src/vmm/src/devices/virtio/mem/persist.rs @@ -5,6 +5,7 @@ use std::sync::Arc; +use bitvec::vec::BitVec; use serde::{Deserialize, Serialize}; use vm_memory::Address; @@ -17,6 +18,7 @@ use crate::devices::virtio::mem::{ use crate::devices::virtio::persist::{PersistError as VirtioStateError, VirtioDeviceState}; use crate::devices::virtio::queue::FIRECRACKER_MAX_QUEUE_SIZE; use crate::snapshot::Persist; +use crate::utils::usize_to_u64; use crate::vstate::memory::{GuestMemoryMmap, GuestRegionMmap}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -25,9 +27,9 @@ pub struct VirtioMemState { region_size: u64, 
block_size: u64, usable_region_size: u64, - plugged_size: u64, requested_size: u64, slot_size: usize, + plugged_blocks: BitVec, } #[derive(Debug)] @@ -60,7 +62,7 @@ impl Persist<'_> for VirtioMem { region_size: self.config.region_size, block_size: self.config.block_size, usable_region_size: self.config.usable_region_size, - plugged_size: self.config.plugged_size, + plugged_blocks: self.plugged_blocks.clone(), requested_size: self.config.requested_size, slot_size: self.slot_size, } @@ -82,13 +84,18 @@ impl Persist<'_> for VirtioMem { region_size: state.region_size, block_size: state.block_size, usable_region_size: state.usable_region_size, - plugged_size: state.plugged_size, + plugged_size: usize_to_u64(state.plugged_blocks.count_ones()) * state.block_size, requested_size: state.requested_size, ..Default::default() }; - let mut virtio_mem = - VirtioMem::from_state(constructor_args.vm, queues, config, state.slot_size)?; + let mut virtio_mem = VirtioMem::from_state( + constructor_args.vm, + queues, + config, + state.slot_size, + state.plugged_blocks.clone(), + )?; virtio_mem.set_avail_features(state.virtio_state.avail_features); virtio_mem.set_acked_features(state.virtio_state.acked_features); @@ -111,7 +118,7 @@ mod tests { assert_eq!(state.region_size, dev.config.region_size); assert_eq!(state.block_size, dev.config.block_size); assert_eq!(state.usable_region_size, dev.config.usable_region_size); - assert_eq!(state.plugged_size, dev.config.plugged_size); + assert_eq!(state.plugged_blocks, dev.plugged_blocks); assert_eq!(state.requested_size, dev.config.requested_size); assert_eq!(state.slot_size, dev.slot_size); } From c0a681b9d78d994c4d5163abb8c32df34653259a Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Thu, 25 Sep 2025 16:35:17 +0100 Subject: [PATCH 07/14] test(virtio-mem): add unit tests for virtio queue request handling Adds unit tests using VirtioTestHelper to verify correct functioning of the new device. 
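As a condensed sketch of the test pattern (using the helpers added in
this patch; the PLUG flow shown here is an assumed usage, not an excerpt
of the tests), each case queues a request descriptor plus a writable
response descriptor and asserts on the decoded response:

    // Assumes the device is activated and given a non-zero requested size
    // first; otherwise the PLUG request would be rejected.
    let mem_dev = default_virtio_mem();
    let guest_mem = mem_dev.vm.guest_memory().clone();
    let mut th = test_helper(mem_dev, &guest_mem);
    th.device().update_requested_size(512).unwrap();

    let req = Request::Plug(RequestedRange {
        addr: VIRTIO_MEM_GUEST_ADDRESS,
        nb_blocks: 1,
    });
    let resp = emulate_request(&mut th, &guest_mem, req);
    assert_eq!(resp.resp_type, ResponseType::Ack);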
Signed-off-by: Riccardo Mancini
---
 src/vmm/src/devices/virtio/mem/device.rs  | 504 +++++++++++++++++++++-
 src/vmm/src/devices/virtio/mem/metrics.rs |   7 +-
 src/vmm/src/devices/virtio/mem/request.rs |  92 +++-
 src/vmm/src/devices/virtio/test_utils.rs  |   5 +
 4 files changed, 599 insertions(+), 9 deletions(-)

diff --git a/src/vmm/src/devices/virtio/mem/device.rs b/src/vmm/src/devices/virtio/mem/device.rs
index 2a837c44bcb..458a1adc8fe 100644
--- a/src/vmm/src/devices/virtio/mem/device.rs
+++ b/src/vmm/src/devices/virtio/mem/device.rs
@@ -663,10 +663,34 @@ impl VirtioDevice for VirtioMem {
 #[cfg(test)]
 pub(crate) mod test_utils {
     use super::*;
+    use crate::devices::virtio::test_utils::test::VirtioTestDevice;
+    use crate::test_utils::single_region_mem;
+    use crate::vmm_config::machine_config::HugePageConfig;
+    use crate::vstate::memory;
     use crate::vstate::vm::tests::setup_vm_with_memory;

+    impl VirtioTestDevice for VirtioMem {
+        fn set_queues(&mut self, queues: Vec<Queue>) {
+            self.queues = queues;
+        }
+
+        fn num_queues() -> usize {
+            MEM_NUM_QUEUES
+        }
+    }
+
     pub(crate) fn default_virtio_mem() -> VirtioMem {
-        let (_, vm) = setup_vm_with_memory(0x1000);
+        let (_, mut vm) = setup_vm_with_memory(0x1000);
+        vm.register_hotpluggable_memory_region(
+            memory::anonymous(
+                std::iter::once((VIRTIO_MEM_GUEST_ADDRESS, mib_to_bytes(1024))),
+                false,
+                HugePageConfig::None,
+            )
+            .unwrap()
+            .pop()
+            .unwrap(),
+        )
+        .unwrap();
         let vm = Arc::new(vm);
         VirtioMem::new(vm, 1024, 2, 128).unwrap()
     }
@@ -676,11 +700,15 @@ pub(crate) mod test_utils {
 mod tests {
     use std::ptr::null_mut;

+    use serde_json::de;
+    use vm_memory::guest_memory;
     use vm_memory::mmap::MmapRegionBuilder;

     use super::*;
     use crate::devices::virtio::device::VirtioDevice;
     use crate::devices::virtio::mem::device::test_utils::default_virtio_mem;
+    use crate::devices::virtio::queue::VIRTQ_DESC_F_WRITE;
+    use crate::devices::virtio::test_utils::test::VirtioTestHelper;
     use crate::vstate::vm::tests::setup_vm_with_memory;

     #[test]
@@ -815,4 +843,478 @@ mod tests {
         }
         );
     }
+
+    #[allow(clippy::cast_possible_truncation)]
+    const REQ_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_req>() as u32;
+    #[allow(clippy::cast_possible_truncation)]
+    const RESP_SIZE: u32 = std::mem::size_of::<virtio_mem::virtio_mem_resp>() as u32;
+
+    fn test_helper<'a>(
+        mut dev: VirtioMem,
+        mem: &'a GuestMemoryMmap,
+    ) -> VirtioTestHelper<'a, VirtioMem> {
+        dev.set_acked_features(dev.avail_features);
+
+        let mut th = VirtioTestHelper::<VirtioMem>::new(mem, dev);
+        th.activate_device(mem);
+        th
+    }
+
+    fn emulate_request(
+        th: &mut VirtioTestHelper<VirtioMem>,
+        mem: &GuestMemoryMmap,
+        req: Request,
+    ) -> Response {
+        th.add_desc_chain(
+            MEM_QUEUE,
+            0,
+            &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)],
+        );
+        mem.write_obj(
+            virtio_mem::virtio_mem_req::from(req),
+            th.desc_address(MEM_QUEUE, 0),
+        )
+        .unwrap();
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+        mem.read_obj::<virtio_mem::virtio_mem_resp>(th.desc_address(MEM_QUEUE, 1))
+            .unwrap()
+            .into()
+    }
+
+    #[test]
+    fn test_event_fail_descriptor_chain_too_short() {
+        let mut mem_dev = default_virtio_mem();
+        let guest_mem = mem_dev.vm.guest_memory().clone();
+        let mut th = test_helper(mem_dev, &guest_mem);
+
+        let queue_event_count = METRICS.queue_event_count.count();
+        let queue_event_fails = METRICS.queue_event_fails.count();
+
+        th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0)]);
+        assert_eq!(th.emulate_for_msec(100).unwrap(), 1);
+
+        assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1);
+        assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1);
+    }
+
+    #[test]
+    fn 
test_event_fail_descriptor_length_too_small() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, 1, 0)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_unexpected_writeonly_descriptor() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, VIRTQ_DESC_F_WRITE)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_unexpected_readonly_descriptor() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain(MEM_QUEUE, 0, &[(0, REQ_SIZE, 0), (1, RESP_SIZE, 0)]); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_event_fail_response_descriptor_length_too_small() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, 1, VIRTQ_DESC_F_WRITE)], + ); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } + + #[test] + fn test_update_requested_size_device_not_active() { + let mut mem_dev = default_virtio_mem(); + let result = mem_dev.update_requested_size(512); + assert!(matches!(result, Err(VirtioMemError::DeviceNotActive))); + } + + #[test] + fn test_update_requested_size_invalid_size() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + // Size not multiple of block size + let result = th.device().update_requested_size(3); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + + // Size too large + let result = th.device().update_requested_size(2048); + assert!(matches!(result, Err(VirtioMemError::InvalidSize(_)))); + } + + #[test] + fn test_update_requested_size_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + th.device().update_requested_size(512).unwrap(); + assert_eq!(th.device().requested_size_mib(), 512); + } + + #[test] + fn test_plug_request_success() { + let mut mem_dev 
= default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails); + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes + (2 << 20)); + assert_eq!(METRICS.plug_fails.count(), plug_fails); + } + + #[test] + fn test_plug_request_too_big() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(2); + let addr = th.device().guest_address(); + + let plug_count = METRICS.plug_count.count(); + let plug_bytes = METRICS.plug_bytes.count(); + let plug_fails = METRICS.plug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.plug_count.count(), plug_count + 1); + assert_eq!(METRICS.plug_bytes.count(), plug_bytes); + assert_eq!(METRICS.plug_fails.count(), plug_fails + 1); + } + + #[test] + fn test_plug_request_already_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // First plug succeeds + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Second plug fails + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unplug_request_success() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + // First plug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 2); + + // Then unplug + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes + (2 << 20)); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails); + } + + #[test] + fn test_unplug_request_not_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); 
+ let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_count = METRICS.unplug_count.count(); + let unplug_bytes = METRICS.unplug_bytes.count(); + let unplug_fails = METRICS.unplug_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Unplug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.unplug_count.count(), unplug_count + 1); + assert_eq!(METRICS.unplug_bytes.count(), unplug_bytes); + assert_eq!(METRICS.unplug_fails.count(), unplug_fails + 1); + } + + #[test] + fn test_unplug_all_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let unplug_all_count = METRICS.unplug_all_count.count(); + let unplug_all_fails = METRICS.unplug_all_fails.count(); + + // Plug some blocks + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 2 }), + ); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 4); + + // Unplug all + let resp = emulate_request(&mut th, &guest_mem, Request::UnplugAll); + assert!(resp.is_ack()); + assert_eq!(th.device().plugged_size_mib(), 0); + + assert_eq!(METRICS.unplug_all_count.count(), unplug_all_count + 1); + assert_eq!(METRICS.unplug_all_fails.count(), unplug_all_fails); + } + + #[test] + fn test_state_request_unplugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Unplugged)); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails); + } + + #[test] + fn test_state_request_plugged() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // Plug first + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert_eq!(resp, Response::ack_with_state(BlockRangeState::Plugged)); + } + + #[test] + fn test_state_request_mixed() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + // Plug first block only + let resp = emulate_request( + &mut th, + &guest_mem, + Request::Plug(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_ack()); + + // Check state of 2 blocks (one plugged, one unplugged) + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 2 }), + ); + assert_eq!(resp, 
Response::ack_with_state(BlockRangeState::Mixed)); + } + + #[test] + fn test_invalid_range_unaligned() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address().unchecked_add(1); + + let state_count = METRICS.state_count.count(); + let state_fails = METRICS.state_fails.count(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 1 }), + ); + assert!(resp.is_error()); + + assert_eq!(METRICS.state_count.count(), state_count + 1); + assert_eq!(METRICS.state_fails.count(), state_fails + 1); + } + + #[test] + fn test_invalid_range_zero_blocks() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(1024); + let addr = th.device().guest_address(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { addr, nb_blocks: 0 }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_invalid_range_out_of_bounds() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + th.device().update_requested_size(4); + let addr = th.device().guest_address(); + + let resp = emulate_request( + &mut th, + &guest_mem, + Request::State(RequestedRange { + addr, + nb_blocks: 1024, + }), + ); + assert!(resp.is_error()); + } + + #[test] + fn test_unsupported_request() { + let mut mem_dev = default_virtio_mem(); + let guest_mem = mem_dev.vm.guest_memory().clone(); + let mut th = test_helper(mem_dev, &guest_mem); + + let queue_event_count = METRICS.queue_event_count.count(); + let queue_event_fails = METRICS.queue_event_fails.count(); + + th.add_desc_chain( + MEM_QUEUE, + 0, + &[(0, REQ_SIZE, 0), (1, RESP_SIZE, VIRTQ_DESC_F_WRITE)], + ); + guest_mem + .write_obj( + virtio_mem::virtio_mem_req::from(Request::Unsupported(999)), + th.desc_address(MEM_QUEUE, 0), + ) + .unwrap(); + assert_eq!(th.emulate_for_msec(100).unwrap(), 1); + + assert_eq!(METRICS.queue_event_count.count(), queue_event_count + 1); + assert_eq!(METRICS.queue_event_fails.count(), queue_event_fails + 1); + } } diff --git a/src/vmm/src/devices/virtio/mem/metrics.rs b/src/vmm/src/devices/virtio/mem/metrics.rs index f2a6f58b92d..d69255d44ec 100644 --- a/src/vmm/src/devices/virtio/mem/metrics.rs +++ b/src/vmm/src/devices/virtio/mem/metrics.rs @@ -111,13 +111,8 @@ pub mod tests { #[test] fn test_memory_hotplug_metrics() { let mem_metrics: VirtioMemDeviceMetrics = VirtioMemDeviceMetrics::new(); - let mem_metrics_local: String = serde_json::to_string(&mem_metrics).unwrap(); - // the 1st serialize flushes the metrics and resets values to 0 so that - // we can compare the values with local metrics. 
-        serde_json::to_string(&METRICS).unwrap();
-        let mem_metrics_global: String = serde_json::to_string(&METRICS).unwrap();
-        assert_eq!(mem_metrics_local, mem_metrics_global);
         mem_metrics.queue_event_count.inc();
         assert_eq!(mem_metrics.queue_event_count.count(), 1);
+        let _ = serde_json::to_string(&mem_metrics).unwrap();
     }
 }
diff --git a/src/vmm/src/devices/virtio/mem/request.rs b/src/vmm/src/devices/virtio/mem/request.rs
index c68220bdaeb..1cc0643392c 100644
--- a/src/vmm/src/devices/virtio/mem/request.rs
+++ b/src/vmm/src/devices/virtio/mem/request.rs
@@ -1,7 +1,7 @@
 // Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 // SPDX-License-Identifier: Apache-2.0
 
-use vm_memory::{ByteValued, GuestAddress};
+use vm_memory::{Address, ByteValued, GuestAddress};
 
 use crate::devices::virtio::generated::virtio_mem;
 
@@ -72,7 +72,7 @@ pub enum BlockRangeState {
     Mixed = virtio_mem::VIRTIO_MEM_STATE_MIXED as u16,
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Eq, PartialEq)]
 pub struct Response {
     pub resp_type: ResponseType,
     // Only for State requests
@@ -100,6 +100,14 @@ impl Response {
             state: Some(state),
         }
     }
+
+    pub(crate) fn is_ack(&self) -> bool {
+        self.resp_type == ResponseType::Ack
+    }
+
+    pub(crate) fn is_error(&self) -> bool {
+        self.resp_type == ResponseType::Error
+    }
 }
 
 // SAFETY: Plain data structures
@@ -117,3 +125,83 @@ impl From<Response> for virtio_mem::virtio_mem_resp {
         out
     }
 }
+
+#[cfg(test)]
+mod test_util {
+    use super::*;
+
+    // Implement the reverse conversions to use in test code.
+
+    impl From<Request> for virtio_mem::virtio_mem_req {
+        fn from(req: Request) -> virtio_mem::virtio_mem_req {
+            match req {
+                Request::Plug(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_PLUG.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        plug: virtio_mem::virtio_mem_req_plug {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::Unplug(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        unplug: virtio_mem::virtio_mem_req_unplug {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::UnplugAll => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_UNPLUG_ALL.try_into().unwrap(),
+                    ..Default::default()
+                },
+                Request::State(r) => virtio_mem::virtio_mem_req {
+                    type_: virtio_mem::VIRTIO_MEM_REQ_STATE.try_into().unwrap(),
+                    u: virtio_mem::virtio_mem_req__bindgen_ty_1 {
+                        state: virtio_mem::virtio_mem_req_state {
+                            addr: r.addr.raw_value(),
+                            nb_blocks: r.nb_blocks.try_into().unwrap(),
+                            ..Default::default()
+                        },
+                    },
+                    ..Default::default()
+                },
+                Request::Unsupported(t) => virtio_mem::virtio_mem_req {
+                    type_: t.try_into().unwrap(),
+                    ..Default::default()
+                },
+            }
+        }
+    }
+
+    impl From<virtio_mem::virtio_mem_resp> for Response {
+        fn from(resp: virtio_mem::virtio_mem_resp) -> Self {
+            Response {
+                resp_type: match resp.type_.into() {
+                    virtio_mem::VIRTIO_MEM_RESP_ACK => ResponseType::Ack,
+                    virtio_mem::VIRTIO_MEM_RESP_NACK => ResponseType::Nack,
+                    virtio_mem::VIRTIO_MEM_RESP_BUSY => ResponseType::Busy,
+                    virtio_mem::VIRTIO_MEM_RESP_ERROR => ResponseType::Error,
+                    t => panic!("Invalid response type: {:?}", t),
+                },
+                // There is no way to know whether this is present or not, as it depends on
+                // the request type. Callers should ignore this value if the request wasn't
+                // STATE.
+                // SAFETY: test code only. Uninitialized values are 0 and recognized as PLUGGED.
+                state: Some(unsafe {
+                    match resp.u.state.state.into() {
+                        virtio_mem::VIRTIO_MEM_STATE_PLUGGED => BlockRangeState::Plugged,
+                        virtio_mem::VIRTIO_MEM_STATE_UNPLUGGED => BlockRangeState::Unplugged,
+                        virtio_mem::VIRTIO_MEM_STATE_MIXED => BlockRangeState::Mixed,
+                        t => panic!("Invalid state: {:?}", t),
+                    }
+                }),
+            }
+        }
+    }
+}
diff --git a/src/vmm/src/devices/virtio/test_utils.rs b/src/vmm/src/devices/virtio/test_utils.rs
index 6f1489dd380..0c7978504e7 100644
--- a/src/vmm/src/devices/virtio/test_utils.rs
+++ b/src/vmm/src/devices/virtio/test_utils.rs
@@ -442,6 +442,11 @@ pub(crate) mod test {
         self.virtqueues.last().unwrap().end().raw_value()
     }
 
+    /// Get the address of a descriptor
+    pub fn desc_address(&self, queue: usize, index: usize) -> GuestAddress {
+        GuestAddress(self.virtqueues[queue].dtable[index].addr.get())
+    }
+
     /// Add a new Descriptor in one of the device's queues
     ///
     /// This function adds in one of the queues of the device a DescriptorChain at some offset

From d9c2bbe5c47708558bc14451a8b4528f26e7cd42 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Fri, 26 Sep 2025 17:35:29 +0100
Subject: [PATCH 08/14] test(metrics): add virtio-mem device metrics to validation

Add the virtio-mem device metrics to the integ test validation.

Signed-off-by: Riccardo Mancini
---
 tests/host_tools/fcmetrics.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tests/host_tools/fcmetrics.py b/tests/host_tools/fcmetrics.py
index 6b110a70f96..3b65901b5aa 100644
--- a/tests/host_tools/fcmetrics.py
+++ b/tests/host_tools/fcmetrics.py
@@ -302,6 +302,21 @@ def validate_fc_metrics(metrics):
             "activate_fails",
             "queue_event_fails",
             "queue_event_count",
+            "plug_count",
+            "plug_bytes",
+            "plug_fails",
+            {"plug_agg": latency_agg_metrics_fields},
+            "unplug_count",
+            "unplug_bytes",
+            "unplug_fails",
+            "unplug_discard_fails",
+            {"unplug_agg": latency_agg_metrics_fields},
+            "state_count",
+            "state_fails",
+            {"state_agg": latency_agg_metrics_fields},
+            "unplug_all_count",
+            "unplug_all_fails",
+            {"unplug_all_agg": latency_agg_metrics_fields},
         ],
     }

From b3fefeb22f5c338be0f422569d929607b2e69be4 Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Fri, 26 Sep 2025 15:25:25 +0100
Subject: [PATCH 09/14] fix(examples/uffd): unregister range on UFFD Remove event

If the handler receives a UFFD remove event, it currently stores the PFN
and will reply with a zero page whenever it receives a pagefault event for
that page. This works well with 4k pages, but the zeropage operation is not
supported on hugetlbfs pages.

In order to support hugepages, let's just unregister the range from UFFD
whenever we get a remove event. By doing so, the handler won't receive any
further notifications for the removed range, and the VM will get fresh zero
pages from the kernel.
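In other words, the handler's event loop ends up shaped roughly like this
(a condensed sketch; the real loop in on_demand_handler.rs handles more
cases):

    loop {
        match uffd_handler.read_event().expect("read uffd event") {
            Some(userfaultfd::Event::Pagefault { addr, .. }) => {
                // Populate the faulting page from the backing file, as before.
                uffd_handler.serve_pf(addr.cast(), uffd_handler.page_size);
            }
            Some(userfaultfd::Event::Remove { start, end }) => {
                // Drop the range from userfaultfd instead of remembering it;
                // later faults there get fresh zero pages from the kernel.
                uffd_handler.unregister_range(start, end)
            }
            Some(_) => panic!("Unexpected event on userfaultfd"),
            None => continue,
        }
    }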
Signed-off-by: Riccardo Mancini
---
 .../examples/uffd/on_demand_handler.rs      |  2 +-
 src/firecracker/examples/uffd/uffd_utils.rs | 35 +++++++------------
 2 files changed, 13 insertions(+), 24 deletions(-)

diff --git a/src/firecracker/examples/uffd/on_demand_handler.rs b/src/firecracker/examples/uffd/on_demand_handler.rs
index 3be958b3578..3101aa19253 100644
--- a/src/firecracker/examples/uffd/on_demand_handler.rs
+++ b/src/firecracker/examples/uffd/on_demand_handler.rs
@@ -87,7 +87,7 @@ fn main() {
                         }
                     }
                     userfaultfd::Event::Remove { start, end } => {
-                        uffd_handler.mark_range_removed(start as u64, end as u64)
+                        uffd_handler.unregister_range(start, end)
                     }
                     _ => panic!("Unexpected event on userfaultfd"),
                 }
diff --git a/src/firecracker/examples/uffd/uffd_utils.rs b/src/firecracker/examples/uffd/uffd_utils.rs
index 97c6150b65b..ab28f6f4d2e 100644
--- a/src/firecracker/examples/uffd/uffd_utils.rs
+++ b/src/firecracker/examples/uffd/uffd_utils.rs
@@ -9,7 +9,7 @@
     dead_code
 )]
 
-use std::collections::{HashMap, HashSet};
+use std::collections::HashMap;
 use std::ffi::c_void;
 use std::fs::File;
 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd};
@@ -54,7 +54,6 @@ pub struct UffdHandler {
     pub page_size: usize,
     backing_buffer: *const u8,
     uffd: Uffd,
-    removed_pages: HashSet<u64>,
 }
 
 impl UffdHandler {
@@ -125,7 +124,6 @@ impl UffdHandler {
             page_size,
             backing_buffer,
             uffd,
-            removed_pages: HashSet::new(),
         }
     }
 
@@ -133,24 +131,23 @@ impl UffdHandler {
         self.uffd.read_event()
     }
 
-    pub fn mark_range_removed(&mut self, start: u64, end: u64) {
-        let pfn_start = start / self.page_size as u64;
-        let pfn_end = end / self.page_size as u64;
-
-        for pfn in pfn_start..pfn_end {
-            self.removed_pages.insert(pfn);
-        }
+    pub fn unregister_range(&mut self, start: *mut c_void, end: *mut c_void) {
+        assert!(
+            (start as usize).is_multiple_of(self.page_size)
+                && (end as usize).is_multiple_of(self.page_size)
+                && end > start
+        );
+        // SAFETY: start and end are valid and provided by UFFD
+        let len = unsafe { end.offset_from_unsigned(start) };
+        self.uffd
+            .unregister(start, len)
+            .expect("range should be valid");
     }
 
     pub fn serve_pf(&mut self, addr: *mut u8, len: usize) -> bool {
         // Find the start of the page that the current faulting address belongs to.
         let dst = (addr as usize & !(self.page_size - 1)) as *mut libc::c_void;
         let fault_page_addr = dst as u64;
-        let fault_pfn = fault_page_addr / self.page_size as u64;
-
-        if self.removed_pages.contains(&fault_pfn) {
-            return self.zero_out(fault_page_addr);
-        }
 
         for region in self.mem_regions.iter() {
             if region.contains(fault_page_addr) {
@@ -193,14 +190,6 @@ impl UffdHandler {
 
         true
     }
-
-    fn zero_out(&mut self, addr: u64) -> bool {
-        match unsafe { self.uffd.zeropage(addr as *mut _, self.page_size, true) } {
-            Ok(_) => true,
-            Err(Error::ZeropageFailed(error)) if error as i32 == libc::EAGAIN => false,
-            r => panic!("Unexpected zeropage result: {:?}", r),
-        }
-    }
 }
 
 #[derive(Debug)]

From 3976cd6f4a1cc7c350b4fd9435a5f74ad1db5c5e Mon Sep 17 00:00:00 2001
From: Riccardo Mancini
Date: Fri, 26 Sep 2025 15:35:40 +0100
Subject: [PATCH 10/14] feat(test/balloon): include HugePages in RSS measurements

This moves the logic to measure RSS to framework.utils and adds logic to
also include huge pages in the measurement.

It also adds caching for the firecracker_pid, as well as a new property to
get the corresponding psutil.Process.
Signed-off-by: Riccardo Mancini --- tests/framework/microvm.py | 10 ++- tests/framework/utils.py | 14 ++++ .../functional/test_balloon.py | 82 +++++++------------ .../test_snapshot_restore_cross_kernel.py | 11 +-- 4 files changed, 57 insertions(+), 60 deletions(-) diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index fa9dea79b82..69f00e4b94b 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -23,10 +23,11 @@ from collections import namedtuple from dataclasses import dataclass from enum import Enum, auto -from functools import lru_cache +from functools import cached_property, lru_cache from pathlib import Path from typing import Optional +import psutil from tenacity import Retrying, retry, stop_after_attempt, wait_fixed import host_tools.cargo_build as build_tools @@ -472,7 +473,7 @@ def state(self): """Get the InstanceInfo property and return the state field.""" return self.api.describe.get().json()["state"] - @property + @cached_property def firecracker_pid(self): """Return Firecracker's PID @@ -491,6 +492,11 @@ def firecracker_pid(self): with attempt: return int(self.jailer.pid_file.read_text(encoding="ascii")) + @cached_property + def ps(self): + """Returns a handle to the psutil.Process for this VM""" + return psutil.Process(self.firecracker_pid) + @property def dimensions(self): """Gets a default set of cloudwatch dimensions describing the configuration of this microvm""" diff --git a/tests/framework/utils.py b/tests/framework/utils.py index 64bc9526e5c..d592daec84f 100644 --- a/tests/framework/utils.py +++ b/tests/framework/utils.py @@ -14,6 +14,7 @@ import typing from collections import defaultdict, namedtuple from contextlib import contextmanager +from pathlib import Path from typing import Dict import psutil @@ -129,6 +130,19 @@ def track_cpu_utilization( return cpu_utilization +def get_resident_memory(process: psutil.Process): + """Returns current memory utilization in KiB, including used HugeTLBFS""" + + proc_status = Path("/proc", str(process.pid), "status").read_text("utf-8") + for line in proc_status.splitlines(): + if line.startswith("HugetlbPages:"): # entry is in KiB + hugetlbfs_usage = int(line.split()[1]) + break + else: + assert False, f"HugetlbPages not found in {str(proc_status)}" + return hugetlbfs_usage + process.memory_info().rss // 1024 + + @contextmanager def chroot(path): """ diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py index f8960bedb6d..59c87358c42 100644 --- a/tests/integration_tests/functional/test_balloon.py +++ b/tests/integration_tests/functional/test_balloon.py @@ -9,12 +9,12 @@ import pytest import requests -from framework.utils import check_output, get_free_mem_ssh +from framework.utils import get_free_mem_ssh, get_resident_memory STATS_POLLING_INTERVAL_S = 1 -def get_stable_rss_mem_by_pid(pid, percentage_delta=1): +def get_stable_rss_mem(uvm, percentage_delta=1): """ Get the RSS memory that a guest uses, given the pid of the guest. @@ -22,22 +22,16 @@ def get_stable_rss_mem_by_pid(pid, percentage_delta=1): Or print a warning if this does not happen. 
""" - # All values are reported as KiB - - def get_rss_from_pmap(): - _, output, _ = check_output("pmap -X {}".format(pid)) - return int(output.split("\n")[-2].split()[1], 10) - first_rss = 0 second_rss = 0 for _ in range(5): - first_rss = get_rss_from_pmap() + first_rss = get_resident_memory(uvm.ps) time.sleep(1) - second_rss = get_rss_from_pmap() + second_rss = get_resident_memory(uvm.ps) abs_diff = abs(first_rss - second_rss) abs_delta = abs_diff / first_rss * 100 print( - f"RSS readings: old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}" + f"RSS readings (bytes): old: {first_rss} new: {second_rss} abs_diff: {abs_diff} abs_delta: {abs_delta}" ) if abs_delta < percentage_delta: return second_rss @@ -87,25 +81,24 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32): def _test_rss_memory_lower(test_microvm): """Check inflating the balloon makes guest use less rss memory.""" # Get the firecracker pid, and open an ssh connection. - firecracker_pid = test_microvm.firecracker_pid ssh_connection = test_microvm.ssh # Using deflate_on_oom, get the RSS as low as possible test_microvm.api.balloon.patch(amount_mib=200) # Get initial rss consumption. - init_rss = get_stable_rss_mem_by_pid(firecracker_pid) + init_rss = get_stable_rss_mem(test_microvm) # Get the balloon back to 0. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Dirty memory, then inflate balloon and get ballooned rss consumption. make_guest_dirty_memory(ssh_connection, amount_mib=32) test_microvm.api.balloon.patch(amount_mib=200) - balloon_rss = get_stable_rss_mem_by_pid(firecracker_pid) + balloon_rss = get_stable_rss_mem(test_microvm) # Check that the ballooning reclaimed the memory. assert balloon_rss - init_rss <= 15000 @@ -149,7 +142,6 @@ def test_inflate_reduces_free(uvm_plain_any): # Start the microvm test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Get the free memory before ballooning. available_mem_deflated = get_free_mem_ssh(test_microvm.ssh) @@ -157,7 +149,7 @@ def test_inflate_reduces_free(uvm_plain_any): # Inflate 64 MB == 16384 page balloon. test_microvm.api.balloon.patch(amount_mib=64) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the free memory after ballooning. available_mem_inflated = get_free_mem_ssh(test_microvm.ssh) @@ -195,19 +187,18 @@ def test_deflate_on_oom(uvm_plain_any, deflate_on_oom): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # We get an initial reading of the RSS, then calculate the amount # we need to inflate the balloon with by subtracting it from the # VM size and adding an offset of 50 MiB in order to make sure we # get a lower reading than the initial one. - initial_rss = get_stable_rss_mem_by_pid(firecracker_pid) + initial_rss = get_stable_rss_mem(test_microvm) inflate_size = 256 - (int(initial_rss / 1024) + 50) # Inflate the balloon test_microvm.api.balloon.patch(amount_mib=inflate_size) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Check that using memory leads to the balloon device automatically # deflate (or not). @@ -250,39 +241,38 @@ def test_reinflate_balloon(uvm_plain_any): # Start the microvm. 
test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # First inflate the balloon to free up the uncertain amount of memory # used by the kernel at boot and establish a baseline, then give back # the memory. test_microvm.api.balloon.patch(amount_mib=200) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get the guest to dirty memory. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon. test_microvm.api.balloon.patch(amount_mib=200) - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # Now deflate the balloon. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Now have the guest dirty memory again. make_guest_dirty_memory(test_microvm.ssh, amount_mib=32) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(test_microvm) # Now inflate the balloon again. test_microvm.api.balloon.patch(amount_mib=200) - fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fourth_reading = get_stable_rss_mem(test_microvm) # Check that the memory used is the same after regardless of the previous # inflate history of the balloon (with the third reading being allowed @@ -309,10 +299,9 @@ def test_size_reduction(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(test_microvm) # Have the guest drop its caches. test_microvm.ssh.run("sync; echo 3 > /proc/sys/vm/drop_caches") @@ -328,7 +317,7 @@ def test_size_reduction(uvm_plain_any): test_microvm.api.balloon.patch(amount_mib=inflate_size) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(test_microvm) # There should be a reduction of at least 10MB. assert first_reading - second_reading >= 10000 @@ -353,7 +342,6 @@ def test_stats(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Give Firecracker enough time to poll the stats at least once post-boot time.sleep(STATS_POLLING_INTERVAL_S * 2) @@ -371,7 +359,7 @@ def test_stats(uvm_plain_any): make_guest_dirty_memory(test_microvm.ssh, amount_mib=10) time.sleep(1) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Make sure that the stats catch the page faults. after_workload_stats = test_microvm.api.balloon_stats.get().json() @@ -380,7 +368,7 @@ def test_stats(uvm_plain_any): # Now inflate the balloon with 10MB of pages. test_microvm.api.balloon.patch(amount_mib=10) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. 
inflated_stats = test_microvm.api.balloon_stats.get().json() @@ -393,7 +381,7 @@ def test_stats(uvm_plain_any): # available memory. test_microvm.api.balloon.patch(amount_mib=0) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get another reading of the stats after the polling interval has passed. deflated_stats = test_microvm.api.balloon_stats.get().json() @@ -421,13 +409,12 @@ def test_stats_update(uvm_plain_any): # Start the microvm. test_microvm.start() - firecracker_pid = test_microvm.firecracker_pid # Dirty 30MB of pages. make_guest_dirty_memory(test_microvm.ssh, amount_mib=30) # This call will internally wait for rss to become stable. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(test_microvm) # Get an initial reading of the stats. initial_stats = test_microvm.api.balloon_stats.get().json() @@ -477,17 +464,14 @@ def test_balloon_snapshot(uvm_plain_any, microvm_factory): make_guest_dirty_memory(vm.ssh, amount_mib=60) time.sleep(1) - # Get the firecracker pid, and open an ssh connection. - firecracker_pid = vm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(vm) # Now inflate the balloon with 20MB of pages. vm.api.balloon.patch(amount_mib=20) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(vm) # There should be a reduction in RSS, but it's inconsistent. # We only test that the reduction happens. @@ -496,28 +480,25 @@ def test_balloon_snapshot(uvm_plain_any, microvm_factory): snapshot = vm.snapshot_full() microvm = microvm_factory.build_from_snapshot(snapshot) - # Get the firecracker from snapshot pid, and open an ssh connection. - firecracker_pid = microvm.firecracker_pid - # Wait out the polling interval, then get the updated stats. time.sleep(STATS_POLLING_INTERVAL_S * 2) stats_after_snap = microvm.api.balloon_stats.get().json() # Check memory usage. - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(microvm) # Dirty 60MB of pages. make_guest_dirty_memory(microvm.ssh, amount_mib=60) # Check memory usage. - fourth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fourth_reading = get_stable_rss_mem(microvm) assert fourth_reading > third_reading # Inflate the balloon with another 20MB of pages. microvm.api.balloon.patch(amount_mib=40) - fifth_reading = get_stable_rss_mem_by_pid(firecracker_pid) + fifth_reading = get_stable_rss_mem(microvm) # There should be a reduction in RSS, but it's inconsistent. # We only test that the reduction happens. @@ -557,15 +538,14 @@ def test_memory_scrub(uvm_plain_any): microvm.api.balloon.patch(amount_mib=60) # Get the firecracker pid, and open an ssh connection. - firecracker_pid = microvm.firecracker_pid # Wait for the inflate to complete. - _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(microvm) # Deflate the balloon completely. microvm.api.balloon.patch(amount_mib=0) # Wait for the deflate to complete. 
- _ = get_stable_rss_mem_by_pid(firecracker_pid) + _ = get_stable_rss_mem(microvm) microvm.ssh.check_output("/usr/local/bin/readmem {} {}".format(60, 1)) diff --git a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py index bfe5316d9e5..253502a2d1f 100644 --- a/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py +++ b/tests/integration_tests/functional/test_snapshot_restore_cross_kernel.py @@ -20,7 +20,7 @@ from framework.utils_cpu_templates import get_supported_cpu_templates from framework.utils_vsock import check_vsock_device from integration_tests.functional.test_balloon import ( - get_stable_rss_mem_by_pid, + get_stable_rss_mem, make_guest_dirty_memory, ) @@ -28,21 +28,18 @@ def _test_balloon(microvm): - # Get the firecracker pid. - firecracker_pid = microvm.firecracker_pid - # Check memory usage. - first_reading = get_stable_rss_mem_by_pid(firecracker_pid) + first_reading = get_stable_rss_mem(microvm) # Dirty 300MB of pages. make_guest_dirty_memory(microvm.ssh, amount_mib=300) # Check memory usage again. - second_reading = get_stable_rss_mem_by_pid(firecracker_pid) + second_reading = get_stable_rss_mem(microvm) assert second_reading > first_reading # Inflate the balloon. Get back 200MB. microvm.api.balloon.patch(amount_mib=200) - third_reading = get_stable_rss_mem_by_pid(firecracker_pid) + third_reading = get_stable_rss_mem(microvm) # Ensure that there is a reduction in RSS. assert second_reading > third_reading From fa2aa712ae72b01da2fe74f785d209d19cfbffc5 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Fri, 26 Sep 2025 15:39:55 +0100 Subject: [PATCH 11/14] refactor(test/balloon): move logic to get guest avail mem to framework Move the logic to get the MemAvailable from /proc/meminfo inside the guest to a new guest_stats module in the test framework. This provides a new class MeminfoGuest that can be used to retrieve this information (and more!). Signed-off-by: Riccardo Mancini --- tests/framework/guest_stats.py | 79 +++++++++++++++++++ tests/framework/utils.py | 19 ----- .../functional/test_balloon.py | 8 +- 3 files changed, 84 insertions(+), 22 deletions(-) create mode 100644 tests/framework/guest_stats.py diff --git a/tests/framework/guest_stats.py b/tests/framework/guest_stats.py new file mode 100644 index 00000000000..468d7167c44 --- /dev/null +++ b/tests/framework/guest_stats.py @@ -0,0 +1,79 @@ +# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Classes for querying guest stats inside microVMs. 
+""" + + +class ByteUnit: + """Represents a byte unit that can be converted to other units.""" + + value_bytes: int + + def __init__(self, value_bytes: int): + self.value_bytes = value_bytes + + @classmethod + def from_kib(cls, value_kib: int): + """Creates a ByteUnit from a value in KiB.""" + if value_kib < 0: + raise ValueError("value_kib must be non-negative") + return ByteUnit(value_kib * 1024) + + def bytes(self) -> int: + """Returns the value in B.""" + return self.value_bytes + + def kib(self) -> float: + """Returns the value in KiB as a decimal.""" + return self.value_bytes / 1024 + + def mib(self) -> float: + """Returns the value in MiB as a decimal.""" + return self.value_bytes / (1 << 20) + + def gib(self) -> float: + """Returns the value in GiB as a decimal.""" + return self.value_bytes / (1 << 30) + + +class Meminfo: + """Represents the contents of /proc/meminfo inside the guest""" + + mem_total: ByteUnit + mem_free: ByteUnit + mem_available: ByteUnit + buffers: ByteUnit + cached: ByteUnit + + def __init__(self): + self.mem_total = ByteUnit(0) + self.mem_free = ByteUnit(0) + self.mem_available = ByteUnit(0) + self.buffers = ByteUnit(0) + self.cached = ByteUnit(0) + + +class MeminfoGuest: + """Queries /proc/meminfo inside the guest""" + + def __init__(self, vm): + self.vm = vm + + def get(self) -> Meminfo: + """Returns the contents of /proc/meminfo inside the guest""" + meminfo = Meminfo() + for line in self.vm.ssh.check_output("cat /proc/meminfo").stdout.splitlines(): + parts = line.split() + if parts[0] == "MemTotal:": + meminfo.mem_total = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "MemFree:": + meminfo.mem_free = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "MemAvailable:": + meminfo.mem_available = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "Buffers:": + meminfo.buffers = ByteUnit.from_kib(int(parts[1])) + elif parts[0] == "Cached:": + meminfo.cached = ByteUnit.from_kib(int(parts[1])) + + return meminfo diff --git a/tests/framework/utils.py b/tests/framework/utils.py index d592daec84f..448b351fd86 100644 --- a/tests/framework/utils.py +++ b/tests/framework/utils.py @@ -254,25 +254,6 @@ def search_output_from_cmd(cmd: str, find_regex: typing.Pattern) -> typing.Match ) -def get_free_mem_ssh(ssh_connection): - """ - Get how much free memory in kB a guest sees, over ssh. 
- - :param ssh_connection: connection to the guest - :return: available mem column output of 'free' - """ - _, stdout, stderr = ssh_connection.run("cat /proc/meminfo | grep MemAvailable") - assert stderr == "" - - # Split "MemAvailable: 123456 kB" and validate it - meminfo_data = stdout.split() - if len(meminfo_data) == 3: - # Return the middle element in the array - return int(meminfo_data[1]) - - raise Exception("Available memory not found in `/proc/meminfo") - - def _format_output_message(proc, stdout, stderr): output_message = f"\n[{proc.pid}] Command:\n{proc.args}" # Append stdout/stderr to the output message diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py index 59c87358c42..19b1651c72a 100644 --- a/tests/integration_tests/functional/test_balloon.py +++ b/tests/integration_tests/functional/test_balloon.py @@ -9,7 +9,8 @@ import pytest import requests -from framework.utils import get_free_mem_ssh, get_resident_memory +from framework.guest_stats import MeminfoGuest +from framework.utils import get_resident_memory STATS_POLLING_INTERVAL_S = 1 @@ -142,9 +143,10 @@ def test_inflate_reduces_free(uvm_plain_any): # Start the microvm test_microvm.start() + meminfo = MeminfoGuest(test_microvm) # Get the free memory before ballooning. - available_mem_deflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_deflated = meminfo.get().mem_free.kib() # Inflate 64 MB == 16384 page balloon. test_microvm.api.balloon.patch(amount_mib=64) @@ -152,7 +154,7 @@ def test_inflate_reduces_free(uvm_plain_any): _ = get_stable_rss_mem(test_microvm) # Get the free memory after ballooning. - available_mem_inflated = get_free_mem_ssh(test_microvm.ssh) + available_mem_inflated = meminfo.get().mem_free.kib() # Assert that ballooning reclaimed about 64 MB of memory. assert available_mem_inflated <= available_mem_deflated - 85 * 64000 / 100 From 55da1377af45bfce7e94c9d9030208348b8b65c2 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Fri, 26 Sep 2025 15:44:05 +0100 Subject: [PATCH 12/14] test(virtio-mem): add functional integration tests for device Add integration tests for the new device: - check that the device is detected - check that hotplugging and unplugging works - check that memory can be used after hotplugging - check that memory is freed on hotunplug - check different config combinations - check different uvm types - check that contents are preserved across snapshot-restore Signed-off-by: Riccardo Mancini --- tests/framework/microvm.py | 28 ++ .../functional/test_memory_hp.py | 271 +++++++++++++++++- 2 files changed, 288 insertions(+), 11 deletions(-) diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index 69f00e4b94b..853de6fc4ef 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -1186,6 +1186,22 @@ def wait_for_ssh_up(self): # run commands. 
The actual connection retry loop happens in SSHConnection._init_connection _ = self.ssh_iface(0) + def hotplug_memory( + self, requested_size_mib: int, timeout: int = 60, poll: float = 0.1 + ): + """Send a hot(un)plug request and wait up to timeout seconds for completion polling every poll seconds""" + self.api.memory_hotplug.patch(requested_size_mib=requested_size_mib) + # Wait for the hotplug to complete + deadline = time.time() + timeout + while time.time() < deadline: + if ( + self.api.memory_hotplug.get().json()["plugged_size_mib"] + == requested_size_mib + ): + return + time.sleep(poll) + raise TimeoutError(f"Hotplug did not complete within {timeout} seconds") + class MicroVMFactory: """MicroVM factory""" @@ -1300,6 +1316,18 @@ def build_n_from_snapshot( last_snapshot.delete() current_snapshot.delete() + def clone_uvm(self, uvm, uffd_handler_name=None): + """ + Clone the given VM and start it. + """ + snapshot = uvm.snapshot_full() + restored_vm = self.build() + restored_vm.spawn() + restored_vm.restore_from_snapshot( + snapshot, resume=True, uffd_handler_name=uffd_handler_name + ) + return restored_vm + def kill(self): """Clean up all built VMs""" for vm in self.vms: diff --git a/tests/integration_tests/functional/test_memory_hp.py b/tests/integration_tests/functional/test_memory_hp.py index b2132d6c9ed..87b243919a6 100644 --- a/tests/integration_tests/functional/test_memory_hp.py +++ b/tests/integration_tests/functional/test_memory_hp.py @@ -3,21 +3,120 @@ """Tests for verifying the virtio-mem is working correctly""" +import pytest +from packaging import version +from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed + +from framework.guest_stats import MeminfoGuest +from framework.microvm import HugePagesConfig +from framework.utils import get_kernel_version, get_resident_memory + +MEMHP_BOOTARGS = "console=ttyS0 reboot=k panic=1 memhp_default_state=online_movable" +DEFAULT_CONFIG = {"total_size_mib": 1024, "slot_size_mib": 128, "block_size_mib": 2} + + +def uvm_booted_memhp( + uvm, rootfs, _microvm_factory, vhost_user, memhp_config, huge_pages, _uffd_handler +): + """Boots a VM with the given memory hotplugging config""" -def test_virtio_mem_detected(uvm_plain_6_1): - """ - Check that the guest kernel has enabled PV steal time. 
- """ - uvm = uvm_plain_6_1 uvm.spawn() uvm.memory_monitor = None - uvm.basic_config( - boot_args="console=ttyS0 reboot=k panic=1 memhp_default_state=online_movable" - ) + if vhost_user: + # We need to setup ssh keys manually because we did not specify rootfs + # in microvm_factory.build method + ssh_key = rootfs.with_suffix(".id_rsa") + uvm.ssh_key = ssh_key + uvm.basic_config( + boot_args=MEMHP_BOOTARGS, add_root_device=False, huge_pages=huge_pages + ) + uvm.add_vhost_user_drive( + "rootfs", rootfs, is_root_device=True, is_read_only=True + ) + else: + uvm.basic_config(boot_args=MEMHP_BOOTARGS, huge_pages=huge_pages) + + uvm.api.memory_hotplug.put(**memhp_config) uvm.add_net_iface() - uvm.api.memory_hotplug.put(total_size_mib=1024) uvm.start() + return uvm + + +def uvm_resumed_memhp( + uvm_plain, + rootfs, + microvm_factory, + vhost_user, + memhp_config, + huge_pages, + uffd_handler, +): + """Restores a VM with the given memory hotplugging config after booting and snapshotting""" + if vhost_user: + pytest.skip("vhost-user doesn't support snapshot/restore") + if huge_pages and huge_pages != HugePagesConfig.NONE and not uffd_handler: + pytest.skip("Hugepages requires a UFFD handler") + uvm = uvm_booted_memhp( + uvm_plain, rootfs, microvm_factory, vhost_user, memhp_config, huge_pages, None + ) + return microvm_factory.clone_uvm(uvm, uffd_handler_name=uffd_handler) + + +@pytest.fixture( + params=[ + (uvm_booted_memhp, False, HugePagesConfig.NONE, None), + (uvm_booted_memhp, False, HugePagesConfig.HUGETLBFS_2MB, None), + (uvm_booted_memhp, True, HugePagesConfig.NONE, None), + (uvm_resumed_memhp, False, HugePagesConfig.NONE, None), + (uvm_resumed_memhp, False, HugePagesConfig.NONE, "on_demand"), + (uvm_resumed_memhp, False, HugePagesConfig.HUGETLBFS_2MB, "on_demand"), + ], + ids=[ + "booted", + "booted-huge-pages", + "booted-vhost-user", + "resumed", + "resumed-uffd", + "resumed-uffd-huge-pages", + ], +) +def uvm_any_memhp(request, uvm_plain_6_1, rootfs, microvm_factory): + """Fixture that yields a booted or resumed VM with memory hotplugging""" + ctor, vhost_user, huge_pages, uffd_handler = request.param + yield ctor( + uvm_plain_6_1, + rootfs, + microvm_factory, + vhost_user, + DEFAULT_CONFIG, + huge_pages, + uffd_handler, + ) + + +def supports_hugetlbfs_discard(): + """Returns True if the kernel supports hugetlbfs discard""" + return version.parse(get_kernel_version()) >= version.parse("5.18.0") + + +def validate_metrics(uvm): + """Validates that there are no fails in the metrics""" + metrics_to_check = ["plug_fails", "unplug_fails", "unplug_all_fails", "state_fails"] + if supports_hugetlbfs_discard(): + metrics_to_check.append("unplug_discard_fails") + uvm.flush_metrics() + for metrics in uvm.get_all_metrics(): + for k in metrics_to_check: + assert ( + metrics["memory_hotplug"][k] == 0 + ), f"{k}={metrics[k]} is greater than zero" + +def check_device_detected(uvm): + """ + Check that the guest kernel has enabled virtio-mem. 
+ """ + hp_config = uvm.api.memory_hotplug.get().json() _, stdout, _ = uvm.ssh.check_output("dmesg | grep 'virtio_mem'") for line in stdout.splitlines(): _, key, value = line.strip().split(":") @@ -27,12 +126,162 @@ def test_virtio_mem_detected(uvm_plain_6_1): case "start address": assert value == (512 << 30), "start address doesn't match" case "region size": - assert value == 1024 << 20, "region size doesn't match" + assert ( + value == hp_config["total_size_mib"] << 20 + ), "region size doesn't match" case "device block size": - assert value == 2 << 20, "block size doesn't match" + assert ( + value == hp_config["block_size_mib"] << 20 + ), "block size doesn't match" case "plugged size": assert value == 0, "plugged size doesn't match" case "requested size": assert value == 0, "requested size doesn't match" case _: continue + + +def check_memory_usable(uvm): + """Allocates memory to verify it's usable (5% margin to avoid OOM-kill)""" + mem_available = MeminfoGuest(uvm).get().mem_available.mib() + # try to allocate 95% of available memory + amount_mib = int(mem_available * 95 / 100) + + _ = uvm.ssh.check_output(f"/usr/local/bin/fillmem {amount_mib}", timeout=10) + # verify the allocation was successful + _ = uvm.ssh.check_output("cat /tmp/fillmem_output.txt | grep successful") + + +def check_hotplug(uvm, requested_size_mib): + """Verifies memory can be hot(un)plugged""" + meminfo = MeminfoGuest(uvm) + mem_total_fixed = ( + meminfo.get().mem_total.mib() + - uvm.api.memory_hotplug.get().json()["plugged_size_mib"] + ) + uvm.hotplug_memory(requested_size_mib) + + # verify guest driver received the request + _, stdout, _ = uvm.ssh.check_output( + "dmesg | grep 'virtio_mem' | grep 'requested size' | tail -1" + ) + assert ( + int(stdout.strip().split(":")[-1].strip(), base=0) == requested_size_mib << 20 + ) + + for attempt in Retrying( + retry=retry_if_exception_type(AssertionError), + stop=stop_after_delay(5), + wait=wait_fixed(1), + reraise=True, + ): + with attempt: + # verify guest driver executed the request + mem_total_after = meminfo.get().mem_total.mib() + assert mem_total_after == mem_total_fixed + requested_size_mib + + +def check_hotunplug(uvm, requested_size_mib): + """Verifies memory can be hotunplugged and gets released""" + + rss_before = get_resident_memory(uvm.ps) + + check_hotplug(uvm, requested_size_mib) + + rss_after = get_resident_memory(uvm.ps) + + print(f"RSS before: {rss_before}, after: {rss_after}") + + huge_pages = HugePagesConfig(uvm.api.machine_config.get().json()["huge_pages"]) + if huge_pages == HugePagesConfig.HUGETLBFS_2MB and supports_hugetlbfs_discard(): + assert rss_after < rss_before, "RSS didn't decrease" + + +def test_virtio_mem_hotplug_hotunplug(uvm_any_memhp): + """ + Check that memory can be hotplugged into the VM. 
+ """ + uvm = uvm_any_memhp + check_device_detected(uvm) + + check_hotplug(uvm, 1024) + check_memory_usable(uvm) + + check_hotunplug(uvm, 0) + + # Check it works again + check_hotplug(uvm, 1024) + check_memory_usable(uvm) + + validate_metrics(uvm) + + +@pytest.mark.parametrize( + "memhp_config", + [ + {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 64}, + {"total_size_mib": 256, "slot_size_mib": 128, "block_size_mib": 128}, + {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 64}, + {"total_size_mib": 256, "slot_size_mib": 256, "block_size_mib": 256}, + ], + ids=["all_different", "slot_sized_block", "single_slot", "single_block"], +) +def test_virtio_mem_configs(uvm_plain_6_1, memhp_config): + """ + Check that the virtio mem device is working as expected for different configs + """ + uvm = uvm_booted_memhp(uvm_plain_6_1, None, None, False, memhp_config, None, None) + if not uvm.pci_enabled: + pytest.skip( + "Skip tests on MMIO transport to save time as we don't expect any difference." + ) + + check_device_detected(uvm) + + for size in range( + 0, memhp_config["total_size_mib"] + 1, memhp_config["block_size_mib"] + ): + check_hotplug(uvm, size) + + check_memory_usable(uvm) + + for size in range( + memhp_config["total_size_mib"] - memhp_config["block_size_mib"], + -1, + -memhp_config["block_size_mib"], + ): + check_hotunplug(uvm, size) + + validate_metrics(uvm) + + +def test_snapshot_restore_persistence(uvm_plain_6_1, microvm_factory): + """ + Check that hptplugged memory is persisted across snapshot/restore. + """ + if not uvm_plain_6_1.pci_enabled: + pytest.skip( + "Skip tests on MMIO transport to save time as we don't expect any difference." + ) + uvm = uvm_booted_memhp( + uvm_plain_6_1, None, microvm_factory, False, DEFAULT_CONFIG, None, None + ) + + uvm.hotplug_memory(1024) + + # Increase /dev/shm size as it defaults to half of the boot memory + uvm.ssh.check_output("mount -o remount,size=1024M -t tmpfs tmpfs /dev/shm") + + uvm.ssh.check_output("dd if=/dev/urandom of=/dev/shm/mem_hp_test bs=1M count=1024") + + _, checksum_before, _ = uvm.ssh.check_output("sha256sum /dev/shm/mem_hp_test") + + restored_vm = microvm_factory.clone_uvm(uvm) + + _, checksum_after, _ = restored_vm.ssh.check_output( + "sha256sum /dev/shm/mem_hp_test" + ) + + assert checksum_before == checksum_after, "Checksums didn't match" + + validate_metrics(restored_vm) From 6dabd5c076c68e38f497bdb9474d32c0a60560f0 Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Thu, 2 Oct 2025 17:17:29 +0100 Subject: [PATCH 13/14] chore(test/virtio-mem): move tests under performance Since these tests need to be run on an ag=1 host, move them under the "performance" folder. Signed-off-by: Riccardo Mancini --- .../test_hotplug_memory.py} | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) rename tests/integration_tests/{functional/test_memory_hp.py => performance/test_hotplug_memory.py} (98%) diff --git a/tests/integration_tests/functional/test_memory_hp.py b/tests/integration_tests/performance/test_hotplug_memory.py similarity index 98% rename from tests/integration_tests/functional/test_memory_hp.py rename to tests/integration_tests/performance/test_hotplug_memory.py index 87b243919a6..a598307d74e 100644 --- a/tests/integration_tests/functional/test_memory_hp.py +++ b/tests/integration_tests/performance/test_hotplug_memory.py @@ -1,7 +1,12 @@ # Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -"""Tests for verifying the virtio-mem is working correctly""" +""" +Tests for verifying the virtio-mem is working correctly + +This file also contains functional tests for virtio-mem because they need to be +run on an ag=1 host due to the use of HugePages. +""" import pytest from packaging import version From a56af1714e417c91bfcc3dc33314bc288f4112be Mon Sep 17 00:00:00 2001 From: Riccardo Mancini Date: Mon, 6 Oct 2025 16:29:57 +0100 Subject: [PATCH 14/14] test(virtio-mem): add rust integration tests These tests add unit test coverage to the builder.rs and vm.rs files which where previously untested in the memory hotplug case. Signed-off-by: Riccardo Mancini --- src/vmm/src/test_utils/mod.rs | 24 ++++----- src/vmm/tests/integration_tests.rs | 79 ++++++++++++++++++------------ 2 files changed, 60 insertions(+), 43 deletions(-) diff --git a/src/vmm/src/test_utils/mod.rs b/src/vmm/src/test_utils/mod.rs index 6fe66cdbadb..887acc54d38 100644 --- a/src/vmm/src/test_utils/mod.rs +++ b/src/vmm/src/test_utils/mod.rs @@ -16,6 +16,7 @@ use crate::vm_memory_vendored::GuestRegionCollection; use crate::vmm_config::boot_source::BootSourceConfig; use crate::vmm_config::instance_info::InstanceInfo; use crate::vmm_config::machine_config::HugePageConfig; +use crate::vmm_config::memory_hotplug::MemoryHotplugConfig; use crate::vstate::memory::{self, GuestMemoryMmap, GuestRegionMmap, GuestRegionMmapExt}; use crate::{EventManager, Vmm}; @@ -73,6 +74,7 @@ pub fn create_vmm( is_diff: bool, boot_microvm: bool, pci_enabled: bool, + memory_hotplug_enabled: bool, ) -> (Arc>, EventManager) { let mut event_manager = EventManager::new().unwrap(); let empty_seccomp_filters = get_empty_filters(); @@ -96,6 +98,14 @@ pub fn create_vmm( resources.pci_enabled = pci_enabled; + if memory_hotplug_enabled { + resources.memory_hotplug = Some(MemoryHotplugConfig { + total_size_mib: 1024, + block_size_mib: 2, + slot_size_mib: 128, + }); + } + let vmm = build_microvm_for_boot( &InstanceInfo::default(), &resources, @@ -112,23 +122,15 @@ pub fn create_vmm( } pub fn default_vmm(kernel_image: Option<&str>) -> (Arc>, EventManager) { - create_vmm(kernel_image, false, true, false) + create_vmm(kernel_image, false, true, false, false) } pub fn default_vmm_no_boot(kernel_image: Option<&str>) -> (Arc>, EventManager) { - create_vmm(kernel_image, false, false, false) -} - -pub fn default_vmm_pci_no_boot(kernel_image: Option<&str>) -> (Arc>, EventManager) { - create_vmm(kernel_image, false, false, true) + create_vmm(kernel_image, false, false, false, false) } pub fn dirty_tracking_vmm(kernel_image: Option<&str>) -> (Arc>, EventManager) { - create_vmm(kernel_image, true, true, false) -} - -pub fn default_vmm_pci(kernel_image: Option<&str>) -> (Arc>, EventManager) { - create_vmm(kernel_image, false, true, false) + create_vmm(kernel_image, true, true, false, false) } #[allow(clippy::undocumented_unsafe_blocks)] diff --git a/src/vmm/tests/integration_tests.rs b/src/vmm/tests/integration_tests.rs index 4abbedc4530..6a5e6a08a14 100644 --- a/src/vmm/tests/integration_tests.rs +++ b/src/vmm/tests/integration_tests.rs @@ -18,9 +18,7 @@ use vmm::rpc_interface::{ use vmm::seccomp::get_empty_filters; use vmm::snapshot::Snapshot; use vmm::test_utils::mock_resources::{MockVmResources, NOISY_KERNEL_IMAGE}; -use vmm::test_utils::{ - create_vmm, default_vmm, default_vmm_no_boot, default_vmm_pci, default_vmm_pci_no_boot, -}; +use vmm::test_utils::{create_vmm, default_vmm, default_vmm_no_boot}; use 
 use vmm::vmm_config::boot_source::BootSourceConfig;
 use vmm::vmm_config::drive::BlockDeviceConfig;
@@ -66,13 +64,12 @@ fn test_build_and_boot_microvm() {
         assert_eq!(format!("{:?}", vmm_ret.err()), "Some(MissingKernelConfig)");
     }
 
-    // Success case.
-    let (vmm, evmgr) = default_vmm(None);
-    check_booted_microvm(vmm, evmgr);
-
-    // microVM with PCI
-    let (vmm, evmgr) = default_vmm_pci(None);
-    check_booted_microvm(vmm, evmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
+            check_booted_microvm(vmm, evmgr);
+        }
+    }
 }
 
 #[allow(unused_mut, unused_variables)]
@@ -96,10 +93,12 @@ fn check_build_microvm(vmm: Arc<Mutex<Vmm>>, mut evmgr: EventManager) {
 
 #[test]
 fn test_build_microvm() {
-    let (vmm, evtmgr) = default_vmm_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
-    let (vmm, evtmgr) = default_vmm_pci_no_boot(None);
-    check_build_microvm(vmm, evtmgr);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            let (vmm, evmgr) = create_vmm(None, false, false, pci_enabled, memory_hotplug);
+            check_build_microvm(vmm, evmgr);
+        }
+    }
 }
 
 fn pause_resume_microvm(vmm: Arc<Mutex<Vmm>>) {
@@ -118,13 +117,14 @@ fn pause_resume_microvm(vmm: Arc<Mutex<Vmm>>) {
 
 #[test]
 fn test_pause_resume_microvm() {
-    // Tests that pausing and resuming a microVM work as expected.
-    let (vmm, _) = default_vmm(None);
+    for pci_enabled in [false, true] {
+        for memory_hotplug in [false, true] {
+            // Tests that pausing and resuming a microVM work as expected.
+            let (vmm, _) = create_vmm(None, false, true, pci_enabled, memory_hotplug);
 
-    pause_resume_microvm(vmm);
-
-    let (vmm, _) = default_vmm_pci(None);
-    pause_resume_microvm(vmm);
+            pause_resume_microvm(vmm);
+        }
+    }
 }
 
 #[test]
@@ -195,11 +195,21 @@ fn test_disallow_dump_cpu_config_without_pausing() {
     vmm.lock().unwrap().stop(FcExitCode::Ok);
 }
 
-fn verify_create_snapshot(is_diff: bool, pci_enabled: bool) -> (TempFile, TempFile) {
+fn verify_create_snapshot(
+    is_diff: bool,
+    pci_enabled: bool,
+    memory_hotplug: bool,
+) -> (TempFile, TempFile) {
     let snapshot_file = TempFile::new().unwrap();
     let memory_file = TempFile::new().unwrap();
 
-    let (vmm, _) = create_vmm(Some(NOISY_KERNEL_IMAGE), is_diff, true, pci_enabled);
+    let (vmm, _) = create_vmm(
+        Some(NOISY_KERNEL_IMAGE),
+        is_diff,
+        true,
+        pci_enabled,
+        memory_hotplug,
+    );
 
     let resources = VmResources {
         machine_config: MachineConfig {
             mem_size_mib: 1,
@@ -303,14 +313,19 @@ fn verify_load_snapshot(snapshot_file: TempFile, memory_file: TempFile) {
 
 #[test]
 fn test_create_and_load_snapshot() {
-    for (diff_snap, pci_enabled) in [(false, false), (false, true), (true, false), (true, true)] {
-        // Create snapshot.
-        let (snapshot_file, memory_file) = verify_create_snapshot(diff_snap, pci_enabled);
-        // Create a new microVm from snapshot. This only tests code-level logic; it verifies
-        // that a microVM can be built with no errors from given snapshot.
-        // It does _not_ verify that the guest is actually restored properly. We're using
-        // python integration tests for that.
-        verify_load_snapshot(snapshot_file, memory_file);
+    for diff_snap in [false, true] {
+        for pci_enabled in [false, true] {
+            for memory_hotplug in [false, true] {
+                // Create snapshot.
+                let (snapshot_file, memory_file) =
+                    verify_create_snapshot(diff_snap, pci_enabled, memory_hotplug);
+                // Create a new microVM from snapshot. This only tests code-level logic; it verifies
+                // that a microVM can be built with no errors from a given snapshot.
+                // It does _not_ verify that the guest is actually restored properly. We're using
+                // Python integration tests for that.
+                verify_load_snapshot(snapshot_file, memory_file);
+            }
+        }
     }
 }
@@ -338,7 +353,7 @@ fn check_snapshot(mut microvm_state: MicrovmState) {
 
 fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState {
     // Create a diff snapshot
-    let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled);
+    let (snapshot_file, _) = verify_create_snapshot(true, pci_enabled, false);
 
     // Deserialize the microVM state.
     snapshot_file.as_file().seek(SeekFrom::Start(0)).unwrap();
@@ -346,7 +361,7 @@ fn get_microvm_state_from_snapshot(pci_enabled: bool) -> MicrovmState {
 
 fn verify_load_snap_disallowed_after_boot_resources(res: VmmAction, res_name: &str) {
-    let (snapshot_file, memory_file) = verify_create_snapshot(false, false);
+    let (snapshot_file, memory_file) = verify_create_snapshot(false, false, false);
 
     let mut event_manager = EventManager::new().unwrap();
     let empty_seccomp_filters = get_empty_filters();