diff --git a/docs/source/getting-started/example/index.md b/docs/source/getting-started/example/index.md
index 686e7465..f6ff52f5 100644
--- a/docs/source/getting-started/example/index.md
+++ b/docs/source/getting-started/example/index.md
@@ -4,6 +4,7 @@
:maxdepth: 2
nfs_conn.md
dram_conn.md
+mooncake_conn.md
disaggregated_prefill/index.md
:::
diff --git a/docs/source/getting-started/example/mooncake_conn.md b/docs/source/getting-started/example/mooncake_conn.md
new file mode 100644
index 00000000..34561fc3
--- /dev/null
+++ b/docs/source/getting-started/example/mooncake_conn.md
@@ -0,0 +1,180 @@
+# Mooncake Connector
+
+This document provides a usage example and configuration guide for the **Mooncake Connector**. This connector enables offloading of KV cache from GPU HBM to CPU Mooncake, helping reduce memory pressure and support larger models or batch sizes.
+
+## Performance
+
+| tokens | mooncake-first | mooncake-second | default |
+| ------ | ------------------ | ------------------ | ------------------ |
+| 2k | 1.9231491860002279 | 0.8265988459810615 | 0.5419427898712457 |
+| 4k | 3.9460434830747544 | 1.5273493870627135 | 0.991630249004811 |
+| 8k | 7.577957597002387 | 2.7632693520281464 | 2.0716467570047827 |
+| 16k | 16.823639799049126 | 5.515289016952738 | 4.742832682048902 |
+| 32k | 81.98759594326839 | 14.217441103421152 | 12.310140203218907 |
+
+Use mooncake fig && default:
+
+
+
+
+## Features
+
+The Monncake connector supports the following functionalities:
+
+- `dump`: Offload KV cache blocks from HBM to Mooncake.
+- `load`: Load KV cache blocks from Mooncake back to HBM.
+- `lookup`: Look up KV blocks stored in Mooncake by block hash.
+- `wait`: Ensure that all copy streams between CPU and GPU have completed.
+
+## Configuration
+
+### Start Mooncake Services
+
+1. Follow the [Mooncake official guide](https://github.com/kvcache-ai/Mooncake/blob/v0.3.4/doc/en/build.md) to build Mooncake.
+
+> **[Warning]**: Currently, this connector only supports Mooncake v0.3.4, and the updated version is being adapted.
+
+2. Start Mooncake Store Service
+
+ Please change the IP addresses and ports in the following guide according to your env.
+
+```bash
+# Unset HTTP proxies
+unset http_proxy https_proxy no_proxy HTTP_PROXY HTTPS_PROXY NO_PROXY
+# Navigate to the metadata server directory, http server for example.
+cd $MOONCAKE_ROOT_DIR/mooncake-transfer-engine/example/http-metadata-server
+# Start Metadata Service
+go run . --addr=0.0.0.0:23790
+# Start Master Service
+mooncake_master --port 50001
+```
+- Replace `$MOONCAKE_ROOT_DIR` with your Mooncake source root path.
+- Make sure to unset any HTTP proxies to prevent networking issues.
+- Use appropriate port based on your environment.
+
+
+
+### Required Parameters
+
+To use the Mooncake connector, you need to configure the `connector_config` dictionary in your model's launch configuration.
+
+- `local_hostname`:
+ The IP address of the current node used to communicate with the metadata server.
+- `metadata_server`:
+ The metadata server of the mooncake transfer engine.
+- `master_server_address`:
+ The IP address and the port of the master daemon process of MooncakeStore.
+- `protocol` *(optional)*:
+ If not provided, it defaults to **tcp**.
+- `device_name` *(optional)*:
+ The device to be used for data transmission, it is required when “protocol” is set to “rdma”. If multiple NIC devices are used, they can be separated by commas such as “erdma_0,erdma_1”. Please note that there are no spaces between them.
+- `global_segment_size`*(optional)*:
+ The size of each global segment in bytes. `DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200` **3.125 GiB**
+- `local_buffer_size`*(optional)*:
+ The size of the local buffer in bytes. `DEFAULT_LOCAL_BUFFER_SIZE = 1073741824` **1.0 GiB**
+
+
+### Example:
+
+```python
+kv_connector_extra_config={
+ "ucm_connector_name": "UcmMooncakeStore",
+ "ucm_connector_config":{
+ "local_hostname": "127.0.0.1",
+ "metadata_server": "http://127.0.0.1:23790/metadata",
+ "protocol": "tcp",
+ "device_name": "",
+ "master_server_address": "127.0.0.1:50001"
+ }
+ }
+```
+
+## Launching Inference
+
+### Offline Inference
+
+To start **offline inference** with the Mooncake connector,modify the script `examples/offline_inference.py` to include the `kv_connector_extra_config` for Mooncake connector usage:
+
+```python
+# In examples/offline_inference.py
+ktc = KVTransferConfig(
+ ...
+ kv_connector_extra_config={
+ "ucm_connector_name": "UcmMooncakeStore",
+ "ucm_connector_config":{
+ "local_hostname": "127.0.0.1",
+ "metadata_server": "http://127.0.0.1:23790/metadata",
+ "protocol": "tcp",
+ "device_name": "",
+ "master_server_address": "127.0.0.1:50001"
+ }
+ }
+)
+```
+
+Then run the script as follows:
+
+```bash
+cd examples/
+python offline_inference.py
+```
+
+### Online Inference
+
+For **online inference** , vLLM with our connector can also be deployed as a server that implements the OpenAI API protocol.
+
+First, specify the python hash seed by:
+```bash
+export PYTHONHASHSEED=123456
+```
+
+Run the following command to start the vLLM server with the Qwen/Qwen2.5-14B-Instruct model:
+
+```bash
+vllm serve /home/models/Qwen2.5-14B-Instruct \
+--max-model-len 20000 \
+--tensor-parallel-size 2 \
+--gpu_memory_utilization 0.87 \
+--trust-remote-code \
+--port 7800 \
+--kv-transfer-config \
+'{
+ "kv_connector": "UnifiedCacheConnectorV1",
+ "kv_connector_module_path": "unifiedcache.integration.vllm.uc_connector",
+ "kv_role": "kv_both",
+ "kv_connector_extra_config": {
+ "ucm_connector_name": "UcmMooncakeStore",
+ "ucm_connector_config":{
+ "local_hostname": "127.0.0.1",
+ "metadata_server": "http://127.0.0.1:23790/metadata",
+ "protocol": "tcp",
+ "device_name": "",
+ "master_server_address": "127.0.0.1:50001"
+ }
+ }
+ }
+}'
+```
+
+If you see log as below:
+
+```bash
+INFO: Started server process [321290]
+INFO: Waiting for application startup.
+INFO: Application startup complete.
+```
+
+Congratulations, you have successfully started the vLLM server with Mooncake Connector!
+
+After successfully started the vLLM server,You can interact with the API as following:
+
+```bash
+curl http://localhost:7800/v1/completions \
+ -H "Content-Type: application/json" \
+ -d '{
+ "model": "/home/models/Qwen2.5-14B-Instruct",
+ "prompt": "Shanghai is a",
+ "max_tokens": 7,
+ "temperature": 0
+ }'
+```
diff --git a/docs/source/images/mooncake_performance.png b/docs/source/images/mooncake_performance.png
new file mode 100644
index 00000000..6fd3b43f
Binary files /dev/null and b/docs/source/images/mooncake_performance.png differ
diff --git a/test/test_mooncake.py b/test/test_mooncake.py
new file mode 100644
index 00000000..37bc7972
--- /dev/null
+++ b/test/test_mooncake.py
@@ -0,0 +1,155 @@
+import hashlib
+import uuid
+
+import torch
+
+from unifiedcache.logger import init_logger
+from unifiedcache.ucm_connector.base import Task
+from unifiedcache.ucm_connector.ucm_mooncake import UcmMooncakeStore
+
+logger = init_logger(__name__)
+
+mooncake_dict_config = {
+ "local_hostname": "127.0.0.1",
+ "metadata_server": "http://127.0.0.1:23790/metadata",
+ "protocol": "tcp",
+ "device_name": "",
+ "master_server_address": "127.0.0.1:50001",
+}
+
+
+def tensor_hash(tensor: torch.Tensor) -> str:
+ """Calculate the hash value of the tensor."""
+ tensor_bytes = tensor.clone().detach().cpu().numpy().tobytes()
+ hash_object = hashlib.blake2b(tensor_bytes)
+ hash_hex = hash_object.hexdigest()
+ return str(int(hash_hex[:16], 16))
+
+
+def test_lookup_not_found():
+ """Test that lookup returns False for non-existent block IDs."""
+ store = UcmMooncakeStore(mooncake_dict_config)
+ block_ids = [uuid.uuid4().hex for _ in range(10)]
+ masks = store.lookup(block_ids)
+ assert all(mask is False for mask in masks)
+
+
+def test_lookup_found():
+ """Test that lookup returns True for existing block IDs after dumping data."""
+ src_block_data = [
+ torch.randint(0, 1000, (1, 100), dtype=torch.int) for _ in range(5)
+ ]
+ block_ids = [tensor_hash(data) for data in src_block_data]
+ offset = [0] * len(block_ids)
+
+ store = UcmMooncakeStore(mooncake_dict_config)
+ task: Task = store.dump(
+ block_ids=block_ids, offset=offset, src_tensor=src_block_data
+ )
+ ret = store.wait(task)
+ assert ret == 0
+ masks = store.lookup(block_ids)
+ assert all(mask is True for mask in masks)
+
+
+def test_dump_once():
+ """Test dumping data once and verifying it exists in the store."""
+ src_block_data = [
+ torch.randint(0, 1000, (1, 100), dtype=torch.int) for _ in range(5)
+ ]
+ block_ids = [tensor_hash(data) for data in src_block_data]
+ offset = [0] * len(block_ids)
+
+ store = UcmMooncakeStore(mooncake_dict_config)
+ task: Task = store.dump(
+ block_ids=block_ids, offset=offset, src_tensor=src_block_data
+ )
+ ret = store.wait(task)
+ assert ret == 0
+ masks = store.lookup(block_ids)
+ assert all(mask is True for mask in masks)
+
+
+def test_dump_repeated():
+ """Test that repeated dumping of the same data doesn't cause errors."""
+ src_block_data = [
+ torch.randint(0, 1000, (1, 100), dtype=torch.int) for _ in range(5)
+ ]
+ block_ids = [tensor_hash(data) for data in src_block_data]
+ offset = [0] * len(block_ids)
+
+ store = UcmMooncakeStore(mooncake_dict_config)
+ task: Task = store.dump(
+ block_ids=block_ids, offset=offset, src_tensor=src_block_data
+ )
+ ret = store.wait(task)
+ assert ret == 0
+ masks = store.lookup(block_ids)
+ assert all(mask is True for mask in masks)
+
+ task: Task = store.dump(
+ block_ids=block_ids, offset=offset, src_tensor=src_block_data
+ )
+ ret = store.wait(task)
+ assert ret == 0
+
+
+def test_load_existing_data():
+ """Test loading data that was previously dumped into the store."""
+ src_block_data = [
+ torch.randint(0, 1000, (1, 100), dtype=torch.int) for _ in range(5)
+ ]
+ dst_block_data = [
+ torch.empty(data.shape, dtype=data.dtype) for data in src_block_data
+ ]
+ block_ids = [tensor_hash(data) for data in src_block_data]
+ offset = [0] * len(block_ids)
+
+ store = UcmMooncakeStore(mooncake_dict_config)
+ task: Task = store.dump(
+ block_ids=block_ids, offset=offset, src_tensor=src_block_data
+ )
+ ret = store.wait(task)
+ assert ret == 0
+
+ masks = store.lookup(block_ids)
+ assert all(mask is True for mask in masks)
+
+ task: Task = store.load(
+ block_ids=block_ids, offset=offset, dst_tensor=dst_block_data
+ )
+ ret = store.wait(task)
+ assert ret == 0
+ assert all(
+ [
+ torch.equal(src_block_data[i], dst_block_data[i]) is True
+ for i in range(len(src_block_data))
+ ]
+ )
+
+
+def test_load_non_existent_data():
+ """Test loading data that doesn't exist in the store verifies the destination remains unchanged."""
+ src_block_data = [
+ torch.randint(0, 1000, (1, 100), dtype=torch.int) for _ in range(5)
+ ]
+ dst_block_data = [
+ torch.empty(data.shape, dtype=data.dtype) for data in src_block_data
+ ]
+ block_ids = [tensor_hash(data) for data in src_block_data]
+ offset = [0] * len(block_ids)
+ store = UcmMooncakeStore(mooncake_dict_config)
+ masks = store.lookup(block_ids)
+ assert all(mask is False for mask in masks)
+
+ task: Task = store.load(
+ block_ids=block_ids, offset=offset, dst_tensor=dst_block_data
+ )
+ ret = store.wait(task)
+ assert ret != 0
+ assert all(
+ [
+ torch.equal(src_block_data[i], dst_block_data[i]) is False
+ for i in range(len(src_block_data))
+ ]
+ )
diff --git a/unifiedcache/ucm_connector/factory.py b/unifiedcache/ucm_connector/factory.py
index 814fb2ff..970594a3 100644
--- a/unifiedcache/ucm_connector/factory.py
+++ b/unifiedcache/ucm_connector/factory.py
@@ -63,3 +63,6 @@ def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBase:
UcmConnectorFactory.register_connector(
"UcmNfsStore", "unifiedcache.ucm_connector.ucm_nfs_store", "UcmNfsStore"
)
+UcmConnectorFactory.register_connector(
+ "UcmMooncakeStore", "unifiedcache.ucm_connector.ucm_mooncake", "UcmMooncakeStore"
+)
diff --git a/unifiedcache/ucm_connector/ucm_mooncake.py b/unifiedcache/ucm_connector/ucm_mooncake.py
new file mode 100644
index 00000000..34942826
--- /dev/null
+++ b/unifiedcache/ucm_connector/ucm_mooncake.py
@@ -0,0 +1,333 @@
+import asyncio
+import json
+import os
+import threading
+from concurrent.futures import Future, TimeoutError
+from dataclasses import dataclass
+from typing import Dict, List
+
+import torch
+from safetensors.torch import load as safetensors_load
+from safetensors.torch import save as safetensors_save
+
+from unifiedcache.logger import init_logger
+from unifiedcache.ucm_connector import Task, UcmKVStoreBase
+
+TIMEOUT_S_THR: int = 60 * 60
+DEFAULT_GLOBAL_SEGMENT_SIZE: int = 3355443200 # 3.125 GiB
+DEFAULT_LOCAL_BUFFER_SIZE: int = 1073741824 # 1.0 GiB
+
+logger = init_logger(__name__)
+
+
+# TODO To keep it consistent with the vllm source code(vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py), the source code is fully reused here. The code here will be deleted after vllm is implemented.
+@dataclass
+class MooncakeStoreConfig:
+ local_hostname: str
+ metadata_server: str
+ global_segment_size: int
+ local_buffer_size: int
+ protocol: str
+ device_name: str
+ master_server_address: str
+
+ @staticmethod
+ def load_from_dict(config: Dict = {}) -> "MooncakeStoreConfig":
+ """Load the config from dict."""
+ return MooncakeStoreConfig(
+ local_hostname=config.get("local_hostname"),
+ metadata_server=config.get("metadata_server"),
+ global_segment_size=config.get(
+ "global_segment_size", DEFAULT_GLOBAL_SEGMENT_SIZE
+ ),
+ local_buffer_size=config.get(
+ "local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE
+ ),
+ protocol=config.get("protocol", "tcp"),
+ device_name=config.get("device_name", ""),
+ master_server_address=config.get("master_server_address"),
+ )
+
+
+@dataclass
+class MooncakeTask(Task):
+ """A task class for Mooncake operations with a task identifier."""
+
+ task_id: int = -1
+
+
+class UcmMooncakeStore(UcmKVStoreBase):
+ """
+ A wrapper class for MooncakeDistributedStore that implements the UcmKVStoreBase interface.
+ Provides key-value store functionality for vLLM using Mooncake as the backend.
+ """
+
+ def __init__(self, config: Dict = {}):
+ """Initialize the Mooncake store with configuration."""
+ super().__init__(config)
+ try:
+ from mooncake.store import MooncakeDistributedStore
+ except ImportError as e:
+ raise ImportError(
+ "Please install mooncake by following the instructions at "
+ "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
+ "to run vLLM with MooncakeConnector."
+ ) from e
+
+ try:
+ self.store = MooncakeDistributedStore()
+
+ mooncake_config = MooncakeStoreConfig.load_from_dict(config)
+ logger.info("Mooncake Configuration loaded from dict successfully.")
+
+ self.store.setup(
+ mooncake_config.local_hostname,
+ mooncake_config.metadata_server,
+ mooncake_config.global_segment_size,
+ mooncake_config.local_buffer_size,
+ mooncake_config.protocol,
+ mooncake_config.device_name,
+ mooncake_config.master_server_address,
+ )
+
+ except ValueError as e:
+ logger.error("Configuration loading failed: %s", e)
+ raise
+ except TypeError:
+ logger.warning("Lack of configuration, please check the dict params .")
+
+ except Exception as exc:
+ logger.error("An error occurred while loading the configuration: %s", exc)
+ raise
+
+ # Task management variables
+ self.task_id: int = 0
+ self.tasks: Dict[int, Future] = {}
+
+ # Threading and synchronization variables
+ self.loop = asyncio.new_event_loop()
+ self.lock = threading.Lock()
+ self._shutting_down = threading.Event()
+
+ # Start the event loop thread
+ self.thread = threading.Thread(target=self._run_event_loop, daemon=True)
+ self.thread.start()
+
+ def __del__(self):
+ """Release resources on garbage collection."""
+ try:
+ self.shutdown()
+ except Exception:
+ pass
+
+ def _run_event_loop(self):
+ """Run the asyncio event loop in a separate thread."""
+ asyncio.set_event_loop(self.loop)
+ self.loop.run_forever()
+
+ def create(self, block_ids: List[str]) -> int:
+ """
+ create kv cache space in storafe (not implemented for Mooncake).
+
+ Args:
+ block_ids (List[str]): vLLM block hash.
+ Returns:
+ Always returns 0 as this operation is not supported by Mooncake
+ """
+ # Mooncake only has get and put interfaces, this operation is not supported
+ return 0
+
+ def lookup(self, block_ids: List[str]) -> List[bool]:
+ """
+ Get number of blocks that can be loaded from the
+ external KV cache.
+ Mooncake integration uses hash = block_id + offset (default offset=0 if not provided).
+ Args:
+ block_ids (List[str]): vLLM block hash.
+
+ Returns:
+ hit block mask, True -> hit
+ """
+ if self._shutting_down.is_set():
+ raise RuntimeError("UcmMooncakeStore is shutting down.")
+
+ mask = [
+ True if self.store.is_exist(f"{block_key}_0") == 1 else False
+ for block_key in block_ids
+ ]
+ return mask
+
+ def prefetch(self, block_ids: List[str]) -> None:
+ """
+ prefetch kv cache to high speed cache according to block_ids (not implemented for Mooncake).
+
+ Args:
+ block_ids (List[str]): vLLM block hash.
+ """
+ # Mooncake only has get and put interfaces, this operation is not supported
+ pass
+
+ def load(
+ self, block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor]
+ ) -> Task:
+ """
+ load kv cache to device.
+ Mooncake integration uses hash = block_id + offset (default offset=0 if not provided).
+
+ Args:
+ block_ids (List[str]): vLLM block hash.
+ offset(List[int]): tp > 1 scene
+ dst_tensor: List[torch.Tensor]: device tensor addr.
+ Returns:
+ task(Task).
+ """
+ if self._shutting_down.is_set():
+ raise RuntimeError("UcmMooncakeStore is shutting down.")
+
+ coro = self._load_impl(block_ids, offset, dst_tensor)
+ future = asyncio.run_coroutine_threadsafe(coro, self.loop)
+ with self.lock:
+ self.task_id += 1
+ self.tasks[self.task_id] = future
+ return MooncakeTask(task_id=self.task_id)
+
+ async def _load_impl(
+ self, block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor]
+ ) -> int:
+ """Internal implementation of loading KV cache from Mooncake Store."""
+ assert len(block_ids) == len(
+ dst_tensor
+ ), "block_ids and dst_tensor have different lengths, please check!"
+ for i in range(len(block_ids)):
+ try:
+ block_hash = f"{block_ids[i]}_{offset[i]}"
+ data = self.store.get(block_hash)
+ except TypeError as err:
+ logger.error("Failed to get value from Mooncake Store: %s", err)
+ raise TypeError("Mooncake Store Get Type Error.") from err
+
+ if data:
+ loaded_tensors = safetensors_load(data)
+ tensor_cpu = loaded_tensors["tensor"]
+ assert dst_tensor[i].shape == tensor_cpu.shape
+ assert dst_tensor[i].dtype == tensor_cpu.dtype
+ dst_tensor[i].copy_(tensor_cpu)
+ else:
+ return 1
+ return 0
+
+ def dump(
+ self, block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor]
+ ) -> Task:
+ """
+ dump kv cache to device.
+ Mooncake integration uses hash = block_id + offset (default offset=0 if not provided).
+
+ Args:
+ block_ids (List[str]): vLLM block hash.
+ offset(List[int]): tp > 1 scene
+ src_tensor: List[torch.Tensor]: device tensor addr.
+ Returns:
+ task(Task).
+ """
+ if self._shutting_down.is_set():
+ raise RuntimeError("UcmMooncakeStore is shutting down.")
+
+ coro = self._dump_impl(block_ids, offset, src_tensor)
+ future = asyncio.run_coroutine_threadsafe(coro, self.loop)
+ with self.lock:
+ self.task_id += 1
+ self.tasks[self.task_id] = future
+ return MooncakeTask(task_id=self.task_id)
+
+ async def _dump_impl(
+ self, block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor]
+ ) -> int:
+ """Internal implementation of dumping KV cache to Mooncake Store."""
+ assert len(block_ids) == len(
+ src_tensor
+ ), "block_ids and src_tensor have different lengths, please check!"
+ for i in range(len(block_ids)):
+ value_bytes = safetensors_save({"tensor": src_tensor[i]})
+ try:
+ block_hash = f"{block_ids[i]}_{offset[i]}"
+ ret = self.store.put(block_hash, value_bytes)
+ if ret != 0:
+ return ret
+ except TypeError as err:
+ logger.error("Failed to put value into Mooncake Store: %s", err)
+ raise TypeError("Mooncake Store Put Type Error.") from err
+ return 0
+
+ def wait(self, task: Task) -> int:
+ """
+ wait kv cache kv transfer task finished.
+
+ Args:
+ task (Task): transfer engine task.
+ Returns:
+ 0 - success
+ others - failed.
+ """
+ # Safely retrieve the Future object
+ with self.lock:
+ future = self.tasks.pop(task.task_id, None)
+
+ if future is None:
+ logger.error(f"Invalid task ID: {task.task_id}")
+ return 1
+
+ try:
+ ret = future.result(TIMEOUT_S_THR)
+ return ret
+ except TimeoutError:
+ # Cancel the task if it times out
+ future.cancel()
+ logger.error(f"Task {task.task_id} timed out after {TIMEOUT_S_THR}s")
+ return 1
+ except asyncio.CancelledError:
+ logger.error(f"Task {task.task_id} was cancelled")
+ return 1
+ except Exception as e:
+ logger.error(f"Task {task.task_id} failed: {str(e)}")
+ return 1
+
+ def commit(self, block_ids: List[str], is_success: bool = True) -> None:
+ """
+ commit kv cache, now kv cache can be reused (not implemented for Mooncake).
+
+ Args:
+ block_ids (List[str]): vLLM block hash.
+ is_success(bool): if False, we need release block
+ """
+ # Mooncake only has get and put interfaces, this operation is not supported
+ pass
+
+ def shutdown(self):
+ """Safely shutdown all components of the store."""
+ if self._shutting_down.is_set():
+ return
+
+ self._shutting_down.set()
+
+ # Safely cancel all pending tasks (atomic operation)
+ with self.lock:
+ tasks_to_cancel = list(self.tasks.values())
+ self.tasks.clear()
+
+ for future in tasks_to_cancel:
+ if not future.done():
+ future.cancel()
+
+ # Stop the event loop
+ self.loop.call_soon_threadsafe(self.loop.stop)
+
+ # Wait for thread termination
+ if self.thread.is_alive():
+ self.thread.join(TIMEOUT_S_THR)
+
+ # Force close the loop if thread didn't exit
+ if not self.loop.is_closed():
+ self.loop.close()
+
+ self.store.close()