From fb3645bd1c9d237914ed126e6ad140db8d18ad43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=85=E6=B0=A2?= Date: Mon, 12 Jan 2026 18:44:14 +0800 Subject: [PATCH 1/2] feat: memorycollection implement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I096c2a825dac99b411fa69db63cc6693cd67630e Co-developed-by: Cursor Signed-off-by: 久氢 --- Makefile | 5 +- agentrun/__init__.py | 33 + .../__client_async_template.py | 178 ++++ agentrun/memory_collection/__init__.py | 37 + .../__memory_collection_async_template.py | 457 ++++++++++ agentrun/memory_collection/api/__init__.py | 5 + agentrun/memory_collection/api/control.py | 610 +++++++++++++ agentrun/memory_collection/client.py | 323 +++++++ .../memory_collection/memory_collection.py | 844 ++++++++++++++++++ agentrun/memory_collection/model.py | 162 ++++ .../memory_collection_control_api.yaml | 53 ++ examples/memory_collection_example.py | 162 ++++ pyproject.toml | 4 + 13 files changed, 2872 insertions(+), 1 deletion(-) create mode 100644 agentrun/memory_collection/__client_async_template.py create mode 100644 agentrun/memory_collection/__init__.py create mode 100644 agentrun/memory_collection/__memory_collection_async_template.py create mode 100644 agentrun/memory_collection/api/__init__.py create mode 100644 agentrun/memory_collection/api/control.py create mode 100644 agentrun/memory_collection/client.py create mode 100644 agentrun/memory_collection/memory_collection.py create mode 100644 agentrun/memory_collection/model.py create mode 100644 codegen/configs/memory_collection_control_api.yaml create mode 100644 examples/memory_collection_example.py diff --git a/Makefile b/Makefile index fdfa26e..a00124f 100644 --- a/Makefile +++ b/Makefile @@ -45,13 +45,15 @@ JINJA2_FILES := \ agentrun/credential/api/control.py \ agentrun/model/api/control.py \ agentrun/toolset/api/control.py \ - agentrun/sandbox/api/control.py + agentrun/sandbox/api/control.py \ + agentrun/memory_collection/api/control.py JINJA2_CONFIGS := \ codegen/configs/agent_runtime_control_api.yaml \ codegen/configs/credential_control_api.yaml \ codegen/configs/model_control_api.yaml \ codegen/configs/toolset_control_api.yaml \ codegen/configs/sandbox_control_api.yaml \ + codegen/configs/memory_collection_control_api.yaml \ define make_jinja2_rule $(1): $(2) @@ -65,6 +67,7 @@ $(eval $(call make_jinja2_rule,agentrun/credential/api/control.py,codegen/config $(eval $(call make_jinja2_rule,agentrun/model/api/control.py,codegen/configs/model_control_api.yaml)) $(eval $(call make_jinja2_rule,agentrun/toolset/api/control.py,codegen/configs/toolset_control_api.yaml)) $(eval $(call make_jinja2_rule,agentrun/sandbox/api/control.py,codegen/configs/sandbox_control_api.yaml)) +$(eval $(call make_jinja2_rule,agentrun/memory_collection/api/control.py,codegen/configs/memory_collection_control_api.yaml)) TEMPLATE_FILES := $(shell find . -name "__*async_template.py" -not -path "*__pycache__*" -not -path "*egg-info*") # 根据模板文件生成对应的输出文件路径 diff --git a/agentrun/__init__.py b/agentrun/__init__.py index f906602..e35a682 100644 --- a/agentrun/__init__.py +++ b/agentrun/__init__.py @@ -55,6 +55,22 @@ CredentialUpdateInput, RelatedResource, ) +# Memory Collection +from agentrun.memory_collection import ( + EmbedderConfig, + EmbedderConfigConfig, + LLMConfig, + LLMConfigConfig, + MemoryCollection, + MemoryCollectionClient, + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionUpdateInput, + NetworkConfiguration, + VectorStoreConfig, + VectorStoreConfigConfig, +) # Model Service from agentrun.model import ( BackendType, @@ -173,6 +189,23 @@ "CredentialCreateInput", "CredentialUpdateInput", "CredentialListInput", + ######## Memory Collection ######## + # base + "MemoryCollection", + "MemoryCollectionClient", + # inner model + "EmbedderConfig", + "EmbedderConfigConfig", + "LLMConfig", + "LLMConfigConfig", + "NetworkConfiguration", + "VectorStoreConfig", + "VectorStoreConfigConfig", + # api model + "MemoryCollectionCreateInput", + "MemoryCollectionUpdateInput", + "MemoryCollectionListInput", + "MemoryCollectionListOutput", ######## Model ######## # base "ModelClient", diff --git a/agentrun/memory_collection/__client_async_template.py b/agentrun/memory_collection/__client_async_template.py new file mode 100644 index 0000000..db37113 --- /dev/null +++ b/agentrun/memory_collection/__client_async_template.py @@ -0,0 +1,178 @@ +"""MemoryCollection 客户端 / MemoryCollection Client + +此模块提供记忆集合管理的客户端API。 +This module provides the client API for memory collection management. +""" + +from typing import Optional + +from alibabacloud_agentrun20250910.models import ( + CreateMemoryCollectionInput, + ListMemoryCollectionsRequest, + UpdateMemoryCollectionInput, +) + +from agentrun.utils.config import Config +from agentrun.utils.exception import HTTPError + +from .api.control import MemoryCollectionControlAPI +from .memory_collection import MemoryCollection +from .model import ( + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionUpdateInput, +) + + +class MemoryCollectionClient: + """MemoryCollection 客户端 / MemoryCollection Client + + 提供记忆集合的创建、删除、更新和查询功能。 + Provides create, delete, update and query functions for memory collections. + """ + + def __init__(self, config: Optional[Config] = None): + """初始化客户端 / Initialize client + + Args: + config: 配置对象,可选 / Configuration object, optional + """ + self.__control_api = MemoryCollectionControlAPI(config) + + async def create_async( + self, + input: MemoryCollectionCreateInput, + config: Optional[Config] = None, + ): + """创建记忆集合(异步) / Create memory collection asynchronously + + Args: + input: 记忆集合输入参数 / Memory collection input parameters + config: 配置对象,可选 / Configuration object, optional + + Returns: + MemoryCollection: 创建的记忆集合对象 / Created memory collection object + + Raises: + ResourceAlreadyExistError: 资源已存在 / Resource already exists + HTTPError: HTTP 请求错误 / HTTP request error + """ + try: + result = await self.__control_api.create_memory_collection_async( + CreateMemoryCollectionInput().from_map(input.model_dump()), + config=config, + ) + + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", input.memory_collection_name + ) from e + + async def delete_async( + self, memory_collection_name: str, config: Optional[Config] = None + ): + """删除记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = await self.__control_api.delete_memory_collection_async( + memory_collection_name, config=config + ) + + return MemoryCollection.from_inner_object(result) + + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + async def update_async( + self, + memory_collection_name: str, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """更新记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = await self.__control_api.update_memory_collection_async( + memory_collection_name, + UpdateMemoryCollectionInput().from_map(input.model_dump()), + config=config, + ) + + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + async def get_async( + self, memory_collection_name: str, config: Optional[Config] = None + ): + """获取记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Returns: + MemoryCollection: 记忆集合对象 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = await self.__control_api.get_memory_collection_async( + memory_collection_name, config=config + ) + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + async def list_async( + self, + input: Optional[MemoryCollectionListInput] = None, + config: Optional[Config] = None, + ): + """列出记忆集合(异步) + + Args: + input: 分页查询参数 + config: 配置 + + Returns: + List[MemoryCollectionListOutput]: 记忆集合列表 + """ + if input is None: + input = MemoryCollectionListInput() + + results = await self.__control_api.list_memory_collections_async( + ListMemoryCollectionsRequest().from_map(input.model_dump()), + config=config, + ) + return [ + MemoryCollectionListOutput.from_inner_object(item) + for item in results.items # type: ignore + ] diff --git a/agentrun/memory_collection/__init__.py b/agentrun/memory_collection/__init__.py new file mode 100644 index 0000000..6f3bd0f --- /dev/null +++ b/agentrun/memory_collection/__init__.py @@ -0,0 +1,37 @@ +"""MemoryCollection 模块 / MemoryCollection Module + +提供记忆集合管理功能。 +Provides memory collection management functionality. +""" + +from .client import MemoryCollectionClient +from .memory_collection import MemoryCollection +from .model import ( + EmbedderConfig, + EmbedderConfigConfig, + LLMConfig, + LLMConfigConfig, + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionUpdateInput, + NetworkConfiguration, + VectorStoreConfig, + VectorStoreConfigConfig, +) + +__all__ = [ + "MemoryCollection", + "MemoryCollectionClient", + "MemoryCollectionCreateInput", + "MemoryCollectionUpdateInput", + "MemoryCollectionListInput", + "MemoryCollectionListOutput", + "EmbedderConfig", + "EmbedderConfigConfig", + "LLMConfig", + "LLMConfigConfig", + "NetworkConfiguration", + "VectorStoreConfig", + "VectorStoreConfigConfig", +] diff --git a/agentrun/memory_collection/__memory_collection_async_template.py b/agentrun/memory_collection/__memory_collection_async_template.py new file mode 100644 index 0000000..6c9dae0 --- /dev/null +++ b/agentrun/memory_collection/__memory_collection_async_template.py @@ -0,0 +1,457 @@ +"""MemoryCollection 高层 API / MemoryCollection High-Level API + +此模块定义记忆集合资源的高级API。 +This module defines the high-level API for memory collection resources. +""" + +from typing import Any, Dict, List, Optional, Tuple + +from agentrun.utils.config import Config +from agentrun.utils.model import PageableInput +from agentrun.utils.resource import ResourceBase + +from .model import ( + MemoryCollectionCreateInput, + MemoryCollectionImmutableProps, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionMutableProps, + MemoryCollectionSystemProps, + MemoryCollectionUpdateInput, +) + + +class MemoryCollection( + MemoryCollectionMutableProps, + MemoryCollectionImmutableProps, + MemoryCollectionSystemProps, + ResourceBase, +): + """记忆集合资源 / MemoryCollection Resource + + 提供记忆集合的完整生命周期管理,包括创建、删除、更新、查询。 + Provides complete lifecycle management for memory collections, including create, delete, update, and query. + """ + + @classmethod + def __get_client(cls): + """获取客户端实例 / Get client instance + + Returns: + MemoryCollectionClient: 客户端实例 / Client instance + """ + from .client import MemoryCollectionClient + + return MemoryCollectionClient() + + @classmethod + async def create_async( + cls, input: MemoryCollectionCreateInput, config: Optional[Config] = None + ): + """创建记忆集合(异步) + + Args: + input: 记忆集合输入参数 + config: 配置 + + Returns: + MemoryCollection: 创建的记忆集合对象 + """ + return await cls.__get_client().create_async(input, config=config) + + @classmethod + async def delete_by_name_async( + cls, memory_collection_name: str, config: Optional[Config] = None + ): + """根据名称删除记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + """ + return await cls.__get_client().delete_async( + memory_collection_name, config=config + ) + + @classmethod + async def update_by_name_async( + cls, + memory_collection_name: str, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """根据名称更新记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + """ + return await cls.__get_client().update_async( + memory_collection_name, input, config=config + ) + + @classmethod + async def get_by_name_async( + cls, memory_collection_name: str, config: Optional[Config] = None + ): + """根据名称获取记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Returns: + MemoryCollection: 记忆集合对象 + """ + return await cls.__get_client().get_async( + memory_collection_name, config=config + ) + + @classmethod + async def _list_page_async( + cls, page_input: PageableInput, config: Config | None = None, **kwargs + ): + return await cls.__get_client().list_async( + input=MemoryCollectionListInput( + **kwargs, + **page_input.model_dump(), + ), + config=config, + ) + + @classmethod + async def list_all_async( + cls, + *, + memory_collection_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> List[MemoryCollectionListOutput]: + """列出所有记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称(可选) + config: 配置 + + Returns: + List[MemoryCollectionListOutput]: 记忆集合列表 + """ + return await cls._list_all_async( + lambda mc: mc.memory_collection_id or "", + config=config, + memory_collection_name=memory_collection_name, + ) + + async def update_async( + self, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """更新记忆集合(异步) + + Args: + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to update a" + " MemoryCollection" + ) + + result = await self.update_by_name_async( + self.memory_collection_name, input, config=config + ) + self.update_self(result) + + return self + + async def delete_async(self, config: Optional[Config] = None): + """删除记忆集合(异步) + + Args: + config: 配置 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to delete a" + " MemoryCollection" + ) + + return await self.delete_by_name_async( + self.memory_collection_name, config=config + ) + + async def get_async(self, config: Optional[Config] = None): + """刷新记忆集合信息(异步) + + Args: + config: 配置 + + Returns: + MemoryCollection: 刷新后的记忆集合对象 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to refresh a" + " MemoryCollection" + ) + + result = await self.get_by_name_async( + self.memory_collection_name, config=config + ) + self.update_self(result) + + return self + + async def refresh_async(self, config: Optional[Config] = None): + """刷新记忆集合信息(异步) + + Args: + config: 配置 + + Returns: + MemoryCollection: 刷新后的记忆集合对象 + """ + return await self.get_async(config=config) + + @classmethod + async def to_mem0_memory_async( + cls, + memory_collection_name: str, + config: Optional[Config] = None, + history_db_path: Optional[str] = None, + ): + """将 MemoryCollection 转换为 agentrun-mem0ai Memory 客户端(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: AgentRun 配置 + history_db_path: mem0 历史数据库路径(可选) + + Returns: + Memory: agentrun-mem0ai Memory 客户端实例 + + Raises: + ImportError: 如果未安装 agentrun-mem0ai 包 + ValueError: 如果配置信息不完整 + + Example: + >>> memory = await MemoryCollection.to_mem0_memory_async( + ... "memoryCollection010901", + ... config=config + ... ) + >>> memory.add("用户喜欢吃苹果", user_id="user123") + """ + try: + from mem0 import Memory + except ImportError as e: + raise ImportError( + "agentrun-mem0ai package is required. Install it with: pip" + " install agentrun-mem0ai" + ) from e + + # 获取 MemoryCollection 配置 + memory_collection = await cls.get_by_name_async( + memory_collection_name, config=config + ) + + # 构建 mem0 配置 + mem0_config = await cls._build_mem0_config_async( + memory_collection, config, history_db_path + ) + + # 创建并返回 Memory 实例 + return Memory.from_config(mem0_config) + + @staticmethod + def _convert_vpc_endpoint_to_public(endpoint: str) -> str: + """将 VPC 内网地址转换为公网地址 + + Args: + endpoint: 原始 endpoint,可能是 VPC 内网地址 + + Returns: + str: 公网地址 + + Example: + >>> _convert_vpc_endpoint_to_public("https://jiuqing.cn-hangzhou.vpc.tablestore.aliyuncs.com") + "https://jiuqing.cn-hangzhou.ots.aliyuncs.com" + """ + if ".vpc.tablestore.aliyuncs.com" in endpoint: + # 将 .vpc.tablestore.aliyuncs.com 替换为 .ots.aliyuncs.com + return endpoint.replace( + ".vpc.tablestore.aliyuncs.com", ".ots.aliyuncs.com" + ) + return endpoint + + @classmethod + async def _build_mem0_config_async( + cls, + memory_collection: "MemoryCollection", + config: Optional[Config], + history_db_path: Optional[str] = None, + ) -> Dict[str, Any]: + """构建 mem0 配置字典(异步) + + Args: + memory_collection: MemoryCollection 对象 + config: AgentRun 配置 + history_db_path: 历史数据库路径 + + Returns: + Dict[str, Any]: mem0 配置字典 + """ + mem0_config: Dict[str, Any] = {} + + # 构建 vector_store 配置 + if memory_collection.vector_store_config: + vector_store_config = memory_collection.vector_store_config + provider = vector_store_config.provider or "" + + if vector_store_config.config: + vs_config = vector_store_config.config + vector_store: Dict[str, Any] = { + "provider": provider, + "config": {}, + } + + # 根据不同的 provider 构建配置 + if provider == "aliyun_tablestore": + # 获取凭证信息 + effective_config = config or Config() + # 将 VPC 内网地址转换为公网地址 + public_endpoint = cls._convert_vpc_endpoint_to_public( + vs_config.endpoint or "" + ) + vector_store["config"] = { + "vector_dimension": vs_config.vector_dimension, + "endpoint": public_endpoint, + "instance_name": vs_config.instance_name, + "collection_name": vs_config.collection_name, + "access_key_id": effective_config.get_access_key_id(), + "access_key_secret": ( + effective_config.get_access_key_secret() + ), + } + # 如果有 security_token,添加它 + security_token = effective_config.get_security_token() + if security_token: + vector_store["config"]["sts_token"] = security_token + else: + # 其他 provider 的通用配置 + vector_store["config"] = { + "endpoint": vs_config.endpoint, + "collection_name": vs_config.collection_name, + } + if vs_config.vector_dimension: + vector_store["config"][ + "vector_dimension" + ] = vs_config.vector_dimension + + mem0_config["vector_store"] = vector_store + + # 构建 llm 配置 + if memory_collection.llm_config: + llm_config = memory_collection.llm_config + model_service_name = llm_config.model_service_name + + if model_service_name and llm_config.config: + # 使用高层 API 获取 ModelService 配置 + base_url, api_key = ( + await cls._resolve_model_service_config_async( + model_service_name, config + ) + ) + + mem0_config["llm"] = { + "provider": "openai", # mem0 使用 openai 兼容接口 + "config": { + "model": llm_config.config.model, + "openai_base_url": base_url, + "api_key": api_key, + }, + } + + # 构建 embedder 配置 + if memory_collection.embedder_config: + embedder_config = memory_collection.embedder_config + model_service_name = embedder_config.model_service_name + + if model_service_name and embedder_config.config: + # 使用高层 API 获取 ModelService 配置 + base_url, api_key = ( + await cls._resolve_model_service_config_async( + model_service_name, config + ) + ) + + mem0_config["embedder"] = { + "provider": "openai", # mem0 使用 openai 兼容接口 + "config": { + "model": embedder_config.config.model, + "openai_base_url": base_url, + "api_key": api_key, + }, + } + + # 添加历史数据库路径 + if history_db_path: + mem0_config["history_db_path"] = history_db_path + + return mem0_config + + @staticmethod + async def _resolve_model_service_config_async( + model_service_name: str, config: Optional[Config] + ) -> Tuple[str, str]: + """解析 ModelService 配置获取 baseUrl 和 apiKey(异步) + + Args: + model_service_name: ModelService 名称 + config: AgentRun 配置 + + Returns: + Tuple[str, str]: (base_url, api_key) + + Raises: + ValueError: 如果配置信息不完整 + """ + from agentrun.credential import Credential + from agentrun.model import ModelService + + # 使用高层 API 获取 ModelService + model_service = await ModelService.get_by_name_async( + model_service_name, config=config + ) + + # 获取 provider_settings + if not model_service.provider_settings: + raise ValueError( + f"ModelService {model_service_name} providerSettings is empty" + ) + + base_url = model_service.provider_settings.base_url or "" + api_key = model_service.provider_settings.api_key or "" + + # 如果有 credentialName,使用高层 API 获取 credential secret + credential_name = model_service.credential_name + if credential_name: + credential = await Credential.get_by_name_async( + credential_name, config=config + ) + if credential.credential_secret: + api_key = credential.credential_secret + + if not base_url: + raise ValueError( + f"ModelService {model_service_name} baseUrl is empty" + ) + + return base_url, api_key diff --git a/agentrun/memory_collection/api/__init__.py b/agentrun/memory_collection/api/__init__.py new file mode 100644 index 0000000..66f2de1 --- /dev/null +++ b/agentrun/memory_collection/api/__init__.py @@ -0,0 +1,5 @@ +"""MemoryCollection API 模块 / MemoryCollection API Module""" + +from .control import MemoryCollectionControlAPI + +__all__ = ["MemoryCollectionControlAPI"] diff --git a/agentrun/memory_collection/api/control.py b/agentrun/memory_collection/api/control.py new file mode 100644 index 0000000..0bca849 --- /dev/null +++ b/agentrun/memory_collection/api/control.py @@ -0,0 +1,610 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: codegen/configs/memory_collection_control_api.yaml + + +Memory Collection 管控链路 API +""" + +from typing import Dict, Optional + +from alibabacloud_agentrun20250910.models import ( + CreateMemoryCollectionInput, + CreateMemoryCollectionRequest, + ListMemoryCollectionsOutput, + ListMemoryCollectionsRequest, + MemoryCollection, + UpdateMemoryCollectionInput, + UpdateMemoryCollectionRequest, +) +from alibabacloud_tea_openapi.exceptions._client import ClientException +from alibabacloud_tea_openapi.exceptions._server import ServerException +from alibabacloud_tea_util.models import RuntimeOptions +import pydash + +from agentrun.utils.config import Config +from agentrun.utils.control_api import ControlAPI +from agentrun.utils.exception import ClientError, ServerError +from agentrun.utils.log import logger + + +class MemoryCollectionControlAPI(ControlAPI): + """Memory Collection 管控链路 API""" + + def __init__(self, config: Optional[Config] = None): + """初始化 API 客户端 + + Args: + config: 全局配置对象 + """ + super().__init__(config) + + def create_memory_collection( + self, + input: CreateMemoryCollectionInput, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """创建 Memory Collection + + Args: + input: Memory Collection 配置 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: 创建的 Memory Collection 对象 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = client.create_memory_collection_with_options( + CreateMemoryCollectionRequest(body=input), + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api create_memory_collection, request Request ID:" + f" {response.body.request_id}\n request: {[input.to_map(),]}\n" + f" response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[ + input, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e + + async def create_memory_collection_async( + self, + input: CreateMemoryCollectionInput, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """创建 Memory Collection + + Args: + input: Memory Collection 配置 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: 创建的 Memory Collection 对象 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = await client.create_memory_collection_with_options_async( + CreateMemoryCollectionRequest(body=input), + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api create_memory_collection, request Request ID:" + f" {response.body.request_id}\n request: {[input.to_map(),]}\n" + f" response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + request=[ + input, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + ) from e + + def delete_memory_collection( + self, + memory_collection_name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """删除 Memory Collection + + Args: + memory_collection_name: Memory Collection 名称 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: 删除结果 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = client.delete_memory_collection_with_options( + memory_collection_name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api delete_memory_collection, request Request ID:" + f" {response.body.request_id}\n request:" + f" {[memory_collection_name,]}\n response:" + f" {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[ + memory_collection_name, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e + + async def delete_memory_collection_async( + self, + memory_collection_name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """删除 Memory Collection + + Args: + memory_collection_name: Memory Collection 名称 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: 删除结果 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = await client.delete_memory_collection_with_options_async( + memory_collection_name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api delete_memory_collection, request Request ID:" + f" {response.body.request_id}\n request:" + f" {[memory_collection_name,]}\n response:" + f" {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + request=[ + memory_collection_name, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + ) from e + + def update_memory_collection( + self, + memory_collection_name: str, + input: UpdateMemoryCollectionInput, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """更新 Memory Collection + + Args: + memory_collection_name: Memory Collection 名称 + input: Memory Collection 配置 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: 更新的 Memory Collection 对象 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = client.update_memory_collection_with_options( + memory_collection_name, + UpdateMemoryCollectionRequest(body=input), + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api update_memory_collection, request Request ID:" + f" {response.body.request_id}\n request:" + f" {[memory_collection_name,input.to_map(),]}\n response:" + f" {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[ + memory_collection_name, + input, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e + + async def update_memory_collection_async( + self, + memory_collection_name: str, + input: UpdateMemoryCollectionInput, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """更新 Memory Collection + + Args: + memory_collection_name: Memory Collection 名称 + input: Memory Collection 配置 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: 更新的 Memory Collection 对象 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = await client.update_memory_collection_with_options_async( + memory_collection_name, + UpdateMemoryCollectionRequest(body=input), + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api update_memory_collection, request Request ID:" + f" {response.body.request_id}\n request:" + f" {[memory_collection_name,input.to_map(),]}\n response:" + f" {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + request=[ + memory_collection_name, + input, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + ) from e + + def get_memory_collection( + self, + memory_collection_name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """获取 Memory Collection + + Args: + memory_collection_name: Memory Collection 名称 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: Memory Collection 对象 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = client.get_memory_collection_with_options( + memory_collection_name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api get_memory_collection, request Request ID:" + f" {response.body.request_id}\n request:" + f" {[memory_collection_name,]}\n response:" + f" {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[ + memory_collection_name, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e + + async def get_memory_collection_async( + self, + memory_collection_name: str, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> MemoryCollection: + """获取 Memory Collection + + Args: + memory_collection_name: Memory Collection 名称 + + headers: 请求头 + config: 配置 + + Returns: + MemoryCollection: Memory Collection 对象 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = await client.get_memory_collection_with_options_async( + memory_collection_name, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api get_memory_collection, request Request ID:" + f" {response.body.request_id}\n request:" + f" {[memory_collection_name,]}\n response:" + f" {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + request=[ + memory_collection_name, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + ) from e + + def list_memory_collections( + self, + input: ListMemoryCollectionsRequest, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> ListMemoryCollectionsOutput: + """枚举 Memory Collection + + Args: + input: 枚举的配置 + + headers: 请求头 + config: 配置 + + Returns: + ListMemoryCollectionsOutput: Memory Collection 列表 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = client.list_memory_collections_with_options( + input, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api list_memory_collections, request Request ID:" + f" {response.body.request_id}\n request: {[input.to_map(),]}\n" + f" response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + request=[ + input, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + pydash.get(e, "data.message", pydash.get(e, "message", "")), + request_id=e.request_id, + ) from e + + async def list_memory_collections_async( + self, + input: ListMemoryCollectionsRequest, + headers: Optional[Dict[str, str]] = None, + config: Optional[Config] = None, + ) -> ListMemoryCollectionsOutput: + """枚举 Memory Collection + + Args: + input: 枚举的配置 + + headers: 请求头 + config: 配置 + + Returns: + ListMemoryCollectionsOutput: Memory Collection 列表 + + Raises: + AgentRuntimeError: 调用失败时抛出 + ClientError: 客户端错误 + ServerError: 服务器错误 + APIError: 运行时错误 + """ + + try: + client = self._get_client(config) + response = await client.list_memory_collections_with_options_async( + input, + headers=headers or {}, + runtime=RuntimeOptions(), + ) + + logger.debug( + "request api list_memory_collections, request Request ID:" + f" {response.body.request_id}\n request: {[input.to_map(),]}\n" + f" response: {response.body.data}" + ) + + return response.body.data + except ClientException as e: + raise ClientError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + request=[ + input, + ], + ) from e + except ServerException as e: + raise ServerError( + e.status_code, + e.data.get("message", e.message), + request_id=e.request_id, + ) from e diff --git a/agentrun/memory_collection/client.py b/agentrun/memory_collection/client.py new file mode 100644 index 0000000..3de6ce8 --- /dev/null +++ b/agentrun/memory_collection/client.py @@ -0,0 +1,323 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/memory_collection/__client_async_template.py + +MemoryCollection 客户端 / MemoryCollection Client + +此模块提供记忆集合管理的客户端API。 +This module provides the client API for memory collection management. +""" + +from typing import Optional + +from alibabacloud_agentrun20250910.models import ( + CreateMemoryCollectionInput, + ListMemoryCollectionsRequest, + UpdateMemoryCollectionInput, +) + +from agentrun.utils.config import Config +from agentrun.utils.exception import HTTPError + +from .api.control import MemoryCollectionControlAPI +from .memory_collection import MemoryCollection +from .model import ( + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionUpdateInput, +) + + +class MemoryCollectionClient: + """MemoryCollection 客户端 / MemoryCollection Client + + 提供记忆集合的创建、删除、更新和查询功能。 + Provides create, delete, update and query functions for memory collections. + """ + + def __init__(self, config: Optional[Config] = None): + """初始化客户端 / Initialize client + + Args: + config: 配置对象,可选 / Configuration object, optional + """ + self.__control_api = MemoryCollectionControlAPI(config) + + async def create_async( + self, + input: MemoryCollectionCreateInput, + config: Optional[Config] = None, + ): + """创建记忆集合(异步) / Create memory collection asynchronously + + Args: + input: 记忆集合输入参数 / Memory collection input parameters + config: 配置对象,可选 / Configuration object, optional + + Returns: + MemoryCollection: 创建的记忆集合对象 / Created memory collection object + + Raises: + ResourceAlreadyExistError: 资源已存在 / Resource already exists + HTTPError: HTTP 请求错误 / HTTP request error + """ + try: + result = await self.__control_api.create_memory_collection_async( + CreateMemoryCollectionInput().from_map(input.model_dump()), + config=config, + ) + + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", input.memory_collection_name + ) from e + + def create( + self, + input: MemoryCollectionCreateInput, + config: Optional[Config] = None, + ): + """创建记忆集合(同步) / Create memory collection asynchronously + + Args: + input: 记忆集合输入参数 / Memory collection input parameters + config: 配置对象,可选 / Configuration object, optional + + Returns: + MemoryCollection: 创建的记忆集合对象 / Created memory collection object + + Raises: + ResourceAlreadyExistError: 资源已存在 / Resource already exists + HTTPError: HTTP 请求错误 / HTTP request error + """ + try: + result = self.__control_api.create_memory_collection( + CreateMemoryCollectionInput().from_map(input.model_dump()), + config=config, + ) + + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", input.memory_collection_name + ) from e + + async def delete_async( + self, memory_collection_name: str, config: Optional[Config] = None + ): + """删除记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = await self.__control_api.delete_memory_collection_async( + memory_collection_name, config=config + ) + + return MemoryCollection.from_inner_object(result) + + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + def delete( + self, memory_collection_name: str, config: Optional[Config] = None + ): + """删除记忆集合(同步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = self.__control_api.delete_memory_collection( + memory_collection_name, config=config + ) + + return MemoryCollection.from_inner_object(result) + + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + async def update_async( + self, + memory_collection_name: str, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """更新记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = await self.__control_api.update_memory_collection_async( + memory_collection_name, + UpdateMemoryCollectionInput().from_map(input.model_dump()), + config=config, + ) + + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + def update( + self, + memory_collection_name: str, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """更新记忆集合(同步) + + Args: + memory_collection_name: 记忆集合名称 + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = self.__control_api.update_memory_collection( + memory_collection_name, + UpdateMemoryCollectionInput().from_map(input.model_dump()), + config=config, + ) + + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + async def get_async( + self, memory_collection_name: str, config: Optional[Config] = None + ): + """获取记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Returns: + MemoryCollection: 记忆集合对象 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = await self.__control_api.get_memory_collection_async( + memory_collection_name, config=config + ) + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + def get(self, memory_collection_name: str, config: Optional[Config] = None): + """获取记忆集合(同步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Returns: + MemoryCollection: 记忆集合对象 + + Raises: + ResourceNotExistError: 记忆集合不存在 + """ + try: + result = self.__control_api.get_memory_collection( + memory_collection_name, config=config + ) + return MemoryCollection.from_inner_object(result) + except HTTPError as e: + raise e.to_resource_error( + "MemoryCollection", memory_collection_name + ) from e + + async def list_async( + self, + input: Optional[MemoryCollectionListInput] = None, + config: Optional[Config] = None, + ): + """列出记忆集合(异步) + + Args: + input: 分页查询参数 + config: 配置 + + Returns: + List[MemoryCollectionListOutput]: 记忆集合列表 + """ + if input is None: + input = MemoryCollectionListInput() + + results = await self.__control_api.list_memory_collections_async( + ListMemoryCollectionsRequest().from_map(input.model_dump()), + config=config, + ) + return [ + MemoryCollectionListOutput.from_inner_object(item) + for item in results.items # type: ignore + ] + + def list( + self, + input: Optional[MemoryCollectionListInput] = None, + config: Optional[Config] = None, + ): + """列出记忆集合(同步) + + Args: + input: 分页查询参数 + config: 配置 + + Returns: + List[MemoryCollectionListOutput]: 记忆集合列表 + """ + if input is None: + input = MemoryCollectionListInput() + + results = self.__control_api.list_memory_collections( + ListMemoryCollectionsRequest().from_map(input.model_dump()), + config=config, + ) + return [ + MemoryCollectionListOutput.from_inner_object(item) + for item in results.items # type: ignore + ] diff --git a/agentrun/memory_collection/memory_collection.py b/agentrun/memory_collection/memory_collection.py new file mode 100644 index 0000000..bc5f5cf --- /dev/null +++ b/agentrun/memory_collection/memory_collection.py @@ -0,0 +1,844 @@ +""" +This file is auto generated by the code generation script. +Do not modify this file manually. +Use the `make codegen` command to regenerate. + +当前文件为自动生成的控制 API 客户端代码。请勿手动修改此文件。 +使用 `make codegen` 命令重新生成。 + +source: agentrun/memory_collection/__memory_collection_async_template.py + +MemoryCollection 高层 API / MemoryCollection High-Level API + +此模块定义记忆集合资源的高级API。 +This module defines the high-level API for memory collection resources. +""" + +from typing import Any, Dict, List, Optional, Tuple + +from agentrun.utils.config import Config +from agentrun.utils.model import PageableInput +from agentrun.utils.resource import ResourceBase + +from .model import ( + MemoryCollectionCreateInput, + MemoryCollectionImmutableProps, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionMutableProps, + MemoryCollectionSystemProps, + MemoryCollectionUpdateInput, +) + + +class MemoryCollection( + MemoryCollectionMutableProps, + MemoryCollectionImmutableProps, + MemoryCollectionSystemProps, + ResourceBase, +): + """记忆集合资源 / MemoryCollection Resource + + 提供记忆集合的完整生命周期管理,包括创建、删除、更新、查询。 + Provides complete lifecycle management for memory collections, including create, delete, update, and query. + """ + + @classmethod + def __get_client(cls): + """获取客户端实例 / Get client instance + + Returns: + MemoryCollectionClient: 客户端实例 / Client instance + """ + from .client import MemoryCollectionClient + + return MemoryCollectionClient() + + @classmethod + async def create_async( + cls, input: MemoryCollectionCreateInput, config: Optional[Config] = None + ): + """创建记忆集合(异步) + + Args: + input: 记忆集合输入参数 + config: 配置 + + Returns: + MemoryCollection: 创建的记忆集合对象 + """ + return await cls.__get_client().create_async(input, config=config) + + @classmethod + def create( + cls, input: MemoryCollectionCreateInput, config: Optional[Config] = None + ): + """创建记忆集合(同步) + + Args: + input: 记忆集合输入参数 + config: 配置 + + Returns: + MemoryCollection: 创建的记忆集合对象 + """ + return cls.__get_client().create(input, config=config) + + @classmethod + async def delete_by_name_async( + cls, memory_collection_name: str, config: Optional[Config] = None + ): + """根据名称删除记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + """ + return await cls.__get_client().delete_async( + memory_collection_name, config=config + ) + + @classmethod + def delete_by_name( + cls, memory_collection_name: str, config: Optional[Config] = None + ): + """根据名称删除记忆集合(同步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + """ + return cls.__get_client().delete(memory_collection_name, config=config) + + @classmethod + async def update_by_name_async( + cls, + memory_collection_name: str, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """根据名称更新记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + """ + return await cls.__get_client().update_async( + memory_collection_name, input, config=config + ) + + @classmethod + def update_by_name( + cls, + memory_collection_name: str, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """根据名称更新记忆集合(同步) + + Args: + memory_collection_name: 记忆集合名称 + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + """ + return cls.__get_client().update( + memory_collection_name, input, config=config + ) + + @classmethod + async def get_by_name_async( + cls, memory_collection_name: str, config: Optional[Config] = None + ): + """根据名称获取记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Returns: + MemoryCollection: 记忆集合对象 + """ + return await cls.__get_client().get_async( + memory_collection_name, config=config + ) + + @classmethod + def get_by_name( + cls, memory_collection_name: str, config: Optional[Config] = None + ): + """根据名称获取记忆集合(同步) + + Args: + memory_collection_name: 记忆集合名称 + config: 配置 + + Returns: + MemoryCollection: 记忆集合对象 + """ + return cls.__get_client().get(memory_collection_name, config=config) + + @classmethod + async def _list_page_async( + cls, page_input: PageableInput, config: Config | None = None, **kwargs + ): + return await cls.__get_client().list_async( + input=MemoryCollectionListInput( + **kwargs, + **page_input.model_dump(), + ), + config=config, + ) + + @classmethod + def _list_page( + cls, page_input: PageableInput, config: Config | None = None, **kwargs + ): + return cls.__get_client().list( + input=MemoryCollectionListInput( + **kwargs, + **page_input.model_dump(), + ), + config=config, + ) + + @classmethod + async def list_all_async( + cls, + *, + memory_collection_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> List[MemoryCollectionListOutput]: + """列出所有记忆集合(异步) + + Args: + memory_collection_name: 记忆集合名称(可选) + config: 配置 + + Returns: + List[MemoryCollectionListOutput]: 记忆集合列表 + """ + return await cls._list_all_async( + lambda mc: mc.memory_collection_id or "", + config=config, + memory_collection_name=memory_collection_name, + ) + + @classmethod + def list_all( + cls, + *, + memory_collection_name: Optional[str] = None, + config: Optional[Config] = None, + ) -> List[MemoryCollectionListOutput]: + """列出所有记忆集合(同步) + + Args: + memory_collection_name: 记忆集合名称(可选) + config: 配置 + + Returns: + List[MemoryCollectionListOutput]: 记忆集合列表 + """ + return cls._list_all( + lambda mc: mc.memory_collection_id or "", + config=config, + memory_collection_name=memory_collection_name, + ) + + async def update_async( + self, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """更新记忆集合(异步) + + Args: + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to update a" + " MemoryCollection" + ) + + result = await self.update_by_name_async( + self.memory_collection_name, input, config=config + ) + self.update_self(result) + + return self + + def update( + self, + input: MemoryCollectionUpdateInput, + config: Optional[Config] = None, + ): + """更新记忆集合(同步) + + Args: + input: 记忆集合更新输入参数 + config: 配置 + + Returns: + MemoryCollection: 更新后的记忆集合对象 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to update a" + " MemoryCollection" + ) + + result = self.update_by_name( + self.memory_collection_name, input, config=config + ) + self.update_self(result) + + return self + + async def delete_async(self, config: Optional[Config] = None): + """删除记忆集合(异步) + + Args: + config: 配置 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to delete a" + " MemoryCollection" + ) + + return await self.delete_by_name_async( + self.memory_collection_name, config=config + ) + + def delete(self, config: Optional[Config] = None): + """删除记忆集合(同步) + + Args: + config: 配置 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to delete a" + " MemoryCollection" + ) + + return self.delete_by_name(self.memory_collection_name, config=config) + + async def get_async(self, config: Optional[Config] = None): + """刷新记忆集合信息(异步) + + Args: + config: 配置 + + Returns: + MemoryCollection: 刷新后的记忆集合对象 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to refresh a" + " MemoryCollection" + ) + + result = await self.get_by_name_async( + self.memory_collection_name, config=config + ) + self.update_self(result) + + return self + + def get(self, config: Optional[Config] = None): + """刷新记忆集合信息(同步) + + Args: + config: 配置 + + Returns: + MemoryCollection: 刷新后的记忆集合对象 + """ + if self.memory_collection_name is None: + raise ValueError( + "memory_collection_name is required to refresh a" + " MemoryCollection" + ) + + result = self.get_by_name(self.memory_collection_name, config=config) + self.update_self(result) + + return self + + async def refresh_async(self, config: Optional[Config] = None): + """刷新记忆集合信息(异步) + + Args: + config: 配置 + + Returns: + MemoryCollection: 刷新后的记忆集合对象 + """ + return await self.get_async(config=config) + + def refresh(self, config: Optional[Config] = None): + """刷新记忆集合信息(同步) + + Args: + config: 配置 + + Returns: + MemoryCollection: 刷新后的记忆集合对象 + """ + return self.get(config=config) + + @classmethod + async def to_mem0_memory_async( + cls, + memory_collection_name: str, + config: Optional[Config] = None, + history_db_path: Optional[str] = None, + ): + """将 MemoryCollection 转换为 agentrun-mem0ai Memory 客户端(异步) + + Args: + memory_collection_name: 记忆集合名称 + config: AgentRun 配置 + history_db_path: mem0 历史数据库路径(可选) + + Returns: + Memory: agentrun-mem0ai Memory 客户端实例 + + Raises: + ImportError: 如果未安装 agentrun-mem0ai 包 + ValueError: 如果配置信息不完整 + + Example: + >>> memory = await MemoryCollection.to_mem0_memory_async( + ... "memoryCollection010901", + ... config=config + ... ) + >>> memory.add("用户喜欢吃苹果", user_id="user123") + """ + try: + from mem0 import Memory + except ImportError as e: + raise ImportError( + "agentrun-mem0ai package is required. Install it with: pip" + " install agentrun-mem0ai" + ) from e + + # 获取 MemoryCollection 配置 + memory_collection = await cls.get_by_name_async( + memory_collection_name, config=config + ) + + # 构建 mem0 配置 + mem0_config = await cls._build_mem0_config_async( + memory_collection, config, history_db_path + ) + + # 创建并返回 Memory 实例 + return Memory.from_config(mem0_config) + + @classmethod + def to_mem0_memory( + cls, + memory_collection_name: str, + config: Optional[Config] = None, + history_db_path: Optional[str] = None, + ): + """将 MemoryCollection 转换为 agentrun-mem0ai Memory 客户端(同步) + + Args: + memory_collection_name: 记忆集合名称 + config: AgentRun 配置 + history_db_path: mem0 历史数据库路径(可选) + + Returns: + Memory: agentrun-mem0ai Memory 客户端实例 + + Raises: + ImportError: 如果未安装 agentrun-mem0ai 包 + ValueError: 如果配置信息不完整 + + Example: + >>> memory = MemoryCollection.to_mem0_memory( + ... "memoryCollection010901", + ... config=config + ... ) + >>> memory.add("用户喜欢吃苹果", user_id="user123") + """ + try: + from mem0 import Memory + except ImportError as e: + raise ImportError( + "agentrun-mem0ai package is required. Install it with: pip" + " install agentrun-mem0ai" + ) from e + + # 获取 MemoryCollection 配置 + memory_collection = cls.get_by_name( + memory_collection_name, config=config + ) + + # 构建 mem0 配置 + mem0_config = cls._build_mem0_config( + memory_collection, config, history_db_path + ) + + # 创建并返回 Memory 实例 + return Memory.from_config(mem0_config) + + @staticmethod + def _convert_vpc_endpoint_to_public(endpoint: str) -> str: + """将 VPC 内网地址转换为公网地址 + + Args: + endpoint: 原始 endpoint,可能是 VPC 内网地址 + + Returns: + str: 公网地址 + + Example: + >>> _convert_vpc_endpoint_to_public("https://jiuqing.cn-hangzhou.vpc.tablestore.aliyuncs.com") + "https://jiuqing.cn-hangzhou.ots.aliyuncs.com" + """ + if ".vpc.tablestore.aliyuncs.com" in endpoint: + # 将 .vpc.tablestore.aliyuncs.com 替换为 .ots.aliyuncs.com + return endpoint.replace( + ".vpc.tablestore.aliyuncs.com", ".ots.aliyuncs.com" + ) + return endpoint + + @classmethod + async def _build_mem0_config_async( + cls, + memory_collection: "MemoryCollection", + config: Optional[Config], + history_db_path: Optional[str] = None, + ) -> Dict[str, Any]: + """构建 mem0 配置字典(异步) + + Args: + memory_collection: MemoryCollection 对象 + config: AgentRun 配置 + history_db_path: 历史数据库路径 + + Returns: + Dict[str, Any]: mem0 配置字典 + """ + mem0_config: Dict[str, Any] = {} + + # 构建 vector_store 配置 + if memory_collection.vector_store_config: + vector_store_config = memory_collection.vector_store_config + provider = vector_store_config.provider or "" + + if vector_store_config.config: + vs_config = vector_store_config.config + vector_store: Dict[str, Any] = { + "provider": provider, + "config": {}, + } + + # 根据不同的 provider 构建配置 + if provider == "aliyun_tablestore": + # 获取凭证信息 + effective_config = config or Config() + # 将 VPC 内网地址转换为公网地址 + public_endpoint = cls._convert_vpc_endpoint_to_public( + vs_config.endpoint or "" + ) + vector_store["config"] = { + "vector_dimension": vs_config.vector_dimension, + "endpoint": public_endpoint, + "instance_name": vs_config.instance_name, + "collection_name": vs_config.collection_name, + "access_key_id": effective_config.get_access_key_id(), + "access_key_secret": ( + effective_config.get_access_key_secret() + ), + } + # 如果有 security_token,添加它 + security_token = effective_config.get_security_token() + if security_token: + vector_store["config"]["sts_token"] = security_token + else: + # 其他 provider 的通用配置 + vector_store["config"] = { + "endpoint": vs_config.endpoint, + "collection_name": vs_config.collection_name, + } + if vs_config.vector_dimension: + vector_store["config"][ + "vector_dimension" + ] = vs_config.vector_dimension + + mem0_config["vector_store"] = vector_store + + # 构建 llm 配置 + if memory_collection.llm_config: + llm_config = memory_collection.llm_config + model_service_name = llm_config.model_service_name + + if model_service_name and llm_config.config: + # 使用高层 API 获取 ModelService 配置 + base_url, api_key = ( + await cls._resolve_model_service_config_async( + model_service_name, config + ) + ) + + mem0_config["llm"] = { + "provider": "openai", # mem0 使用 openai 兼容接口 + "config": { + "model": llm_config.config.model, + "openai_base_url": base_url, + "api_key": api_key, + }, + } + + # 构建 embedder 配置 + if memory_collection.embedder_config: + embedder_config = memory_collection.embedder_config + model_service_name = embedder_config.model_service_name + + if model_service_name and embedder_config.config: + # 使用高层 API 获取 ModelService 配置 + base_url, api_key = ( + await cls._resolve_model_service_config_async( + model_service_name, config + ) + ) + + mem0_config["embedder"] = { + "provider": "openai", # mem0 使用 openai 兼容接口 + "config": { + "model": embedder_config.config.model, + "openai_base_url": base_url, + "api_key": api_key, + }, + } + + # 添加历史数据库路径 + if history_db_path: + mem0_config["history_db_path"] = history_db_path + + return mem0_config + + @classmethod + def _build_mem0_config( + cls, + memory_collection: "MemoryCollection", + config: Optional[Config], + history_db_path: Optional[str] = None, + ) -> Dict[str, Any]: + """构建 mem0 配置字典(同步) + + Args: + memory_collection: MemoryCollection 对象 + config: AgentRun 配置 + history_db_path: 历史数据库路径 + + Returns: + Dict[str, Any]: mem0 配置字典 + """ + mem0_config: Dict[str, Any] = {} + + # 构建 vector_store 配置 + if memory_collection.vector_store_config: + vector_store_config = memory_collection.vector_store_config + provider = vector_store_config.provider or "" + + if vector_store_config.config: + vs_config = vector_store_config.config + vector_store: Dict[str, Any] = { + "provider": provider, + "config": {}, + } + + # 根据不同的 provider 构建配置 + if provider == "aliyun_tablestore": + # 获取凭证信息 + effective_config = config or Config() + # 将 VPC 内网地址转换为公网地址 + public_endpoint = cls._convert_vpc_endpoint_to_public( + vs_config.endpoint or "" + ) + vector_store["config"] = { + "vector_dimension": vs_config.vector_dimension, + "endpoint": public_endpoint, + "instance_name": vs_config.instance_name, + "collection_name": vs_config.collection_name, + "access_key_id": effective_config.get_access_key_id(), + "access_key_secret": ( + effective_config.get_access_key_secret() + ), + } + # 如果有 security_token,添加它 + security_token = effective_config.get_security_token() + if security_token: + vector_store["config"]["sts_token"] = security_token + else: + # 其他 provider 的通用配置 + vector_store["config"] = { + "endpoint": vs_config.endpoint, + "collection_name": vs_config.collection_name, + } + if vs_config.vector_dimension: + vector_store["config"][ + "vector_dimension" + ] = vs_config.vector_dimension + + mem0_config["vector_store"] = vector_store + + # 构建 llm 配置 + if memory_collection.llm_config: + llm_config = memory_collection.llm_config + model_service_name = llm_config.model_service_name + + if model_service_name and llm_config.config: + # 使用高层 API 获取 ModelService 配置 + base_url, api_key = cls._resolve_model_service_config( + model_service_name, config + ) + + mem0_config["llm"] = { + "provider": "openai", # mem0 使用 openai 兼容接口 + "config": { + "model": llm_config.config.model, + "openai_base_url": base_url, + "api_key": api_key, + }, + } + + # 构建 embedder 配置 + if memory_collection.embedder_config: + embedder_config = memory_collection.embedder_config + model_service_name = embedder_config.model_service_name + + if model_service_name and embedder_config.config: + # 使用高层 API 获取 ModelService 配置 + base_url, api_key = cls._resolve_model_service_config( + model_service_name, config + ) + + mem0_config["embedder"] = { + "provider": "openai", # mem0 使用 openai 兼容接口 + "config": { + "model": embedder_config.config.model, + "openai_base_url": base_url, + "api_key": api_key, + }, + } + + # 添加历史数据库路径 + if history_db_path: + mem0_config["history_db_path"] = history_db_path + + return mem0_config + + @staticmethod + async def _resolve_model_service_config_async( + model_service_name: str, config: Optional[Config] + ) -> Tuple[str, str]: + """解析 ModelService 配置获取 baseUrl 和 apiKey(异步) + + Args: + model_service_name: ModelService 名称 + config: AgentRun 配置 + + Returns: + Tuple[str, str]: (base_url, api_key) + + Raises: + ValueError: 如果配置信息不完整 + """ + from agentrun.credential import Credential + from agentrun.model import ModelService + + # 使用高层 API 获取 ModelService + model_service = await ModelService.get_by_name_async( + model_service_name, config=config + ) + + # 获取 provider_settings + if not model_service.provider_settings: + raise ValueError( + f"ModelService {model_service_name} providerSettings is empty" + ) + + base_url = model_service.provider_settings.base_url or "" + api_key = model_service.provider_settings.api_key or "" + + # 如果有 credentialName,使用高层 API 获取 credential secret + credential_name = model_service.credential_name + if credential_name: + credential = await Credential.get_by_name_async( + credential_name, config=config + ) + if credential.credential_secret: + api_key = credential.credential_secret + + if not base_url: + raise ValueError( + f"ModelService {model_service_name} baseUrl is empty" + ) + + return base_url, api_key + + @staticmethod + def _resolve_model_service_config( + model_service_name: str, config: Optional[Config] + ) -> Tuple[str, str]: + """解析 ModelService 配置获取 baseUrl 和 apiKey(同步) + + Args: + model_service_name: ModelService 名称 + config: AgentRun 配置 + + Returns: + Tuple[str, str]: (base_url, api_key) + + Raises: + ValueError: 如果配置信息不完整 + """ + from agentrun.credential import Credential + from agentrun.model import ModelService + + # 使用高层 API 获取 ModelService + model_service = ModelService.get_by_name( + model_service_name, config=config + ) + + # 获取 provider_settings + if not model_service.provider_settings: + raise ValueError( + f"ModelService {model_service_name} providerSettings is empty" + ) + + base_url = model_service.provider_settings.base_url or "" + api_key = model_service.provider_settings.api_key or "" + + # 如果有 credentialName,使用高层 API 获取 credential secret + credential_name = model_service.credential_name + if credential_name: + credential = Credential.get_by_name(credential_name, config=config) + if credential.credential_secret: + api_key = credential.credential_secret + + if not base_url: + raise ValueError( + f"ModelService {model_service_name} baseUrl is empty" + ) + + return base_url, api_key diff --git a/agentrun/memory_collection/model.py b/agentrun/memory_collection/model.py new file mode 100644 index 0000000..90bd9bc --- /dev/null +++ b/agentrun/memory_collection/model.py @@ -0,0 +1,162 @@ +"""MemoryCollection 模型定义 / MemoryCollection Model Definitions + +定义记忆集合相关的数据模型和枚举。 +Defines data models and enumerations related to memory collections. +""" + +from typing import Any, Dict, List, Optional + +from agentrun.utils.config import Config +from agentrun.utils.model import BaseModel, PageableInput + + +class EmbedderConfigConfig(BaseModel): + """嵌入模型内部配置 / Embedder Inner Configuration""" + + model: Optional[str] = None + """模型名称""" + + +class EmbedderConfig(BaseModel): + """嵌入模型配置 / Embedder Configuration""" + + config: Optional[EmbedderConfigConfig] = None + """配置""" + model_service_name: Optional[str] = None + """模型服务名称""" + + +class LLMConfigConfig(BaseModel): + """LLM 内部配置 / LLM Inner Configuration""" + + model: Optional[str] = None + """模型名称""" + + +class LLMConfig(BaseModel): + """LLM 配置 / LLM Configuration""" + + config: Optional[LLMConfigConfig] = None + """配置""" + model_service_name: Optional[str] = None + """模型服务名称""" + + +class NetworkConfiguration(BaseModel): + """网络配置 / Network Configuration""" + + vpc_id: Optional[str] = None + """VPC ID""" + vswitch_ids: Optional[List[str]] = None + """交换机 ID 列表""" + security_group_id: Optional[str] = None + """安全组 ID""" + network_mode: Optional[str] = None + """网络模式""" + + +class VectorStoreConfigConfig(BaseModel): + """向量存储内部配置 / Vector Store Inner Configuration""" + + endpoint: Optional[str] = None + """端点""" + instance_name: Optional[str] = None + """实例名称""" + collection_name: Optional[str] = None + """集合名称""" + vector_dimension: Optional[int] = None + """向量维度""" + + +class VectorStoreConfig(BaseModel): + """向量存储配置 / Vector Store Configuration""" + + provider: Optional[str] = None + """提供商""" + config: Optional[VectorStoreConfigConfig] = None + """配置""" + + +class MemoryCollectionMutableProps(BaseModel): + """MemoryCollection 可变属性""" + + description: Optional[str] = None + """描述""" + embedder_config: Optional[EmbedderConfig] = None + """嵌入模型配置""" + execution_role_arn: Optional[str] = None + """执行角色 ARN""" + llm_config: Optional[LLMConfig] = None + """LLM 配置""" + network_configuration: Optional[NetworkConfiguration] = None + """网络配置""" + vector_store_config: Optional[VectorStoreConfig] = None + """向量存储配置""" + + +class MemoryCollectionImmutableProps(BaseModel): + """MemoryCollection 不可变属性""" + + memory_collection_name: Optional[str] = None + """Memory Collection 名称""" + type: Optional[str] = None + """类型""" + + +class MemoryCollectionSystemProps(BaseModel): + """MemoryCollection 系统属性""" + + memory_collection_id: Optional[str] = None + """Memory Collection ID""" + created_at: Optional[str] = None + """创建时间""" + last_updated_at: Optional[str] = None + """最后更新时间""" + + +class MemoryCollectionCreateInput( + MemoryCollectionImmutableProps, MemoryCollectionMutableProps +): + """MemoryCollection 创建输入参数""" + + pass + + +class MemoryCollectionUpdateInput(MemoryCollectionMutableProps): + """MemoryCollection 更新输入参数""" + + pass + + +class MemoryCollectionListInput(PageableInput): + """MemoryCollection 列表查询输入参数""" + + memory_collection_name: Optional[str] = None + """Memory Collection 名称""" + + +class MemoryCollectionListOutput(BaseModel): + """MemoryCollection 列表输出""" + + memory_collection_id: Optional[str] = None + memory_collection_name: Optional[str] = None + description: Optional[str] = None + type: Optional[str] = None + created_at: Optional[str] = None + last_updated_at: Optional[str] = None + + async def to_memory_collection_async(self, config: Optional[Config] = None): + """转换为完整的 MemoryCollection 对象(异步)""" + from .client import MemoryCollectionClient + + return await MemoryCollectionClient(config).get_async( + self.memory_collection_name or "", config=config + ) + + def to_memory_collection(self, config: Optional[Config] = None): + """转换为完整的 MemoryCollection 对象""" + from .client import MemoryCollectionClient + + return MemoryCollectionClient(config).get( + self.memory_collection_name or "", config=config + ) diff --git a/codegen/configs/memory_collection_control_api.yaml b/codegen/configs/memory_collection_control_api.yaml new file mode 100644 index 0000000..54e3450 --- /dev/null +++ b/codegen/configs/memory_collection_control_api.yaml @@ -0,0 +1,53 @@ +output_path: agentrun/memory_collection/api/control.py +template: control_api.jinja2 +class_name: MemoryCollectionControlAPI +description: Memory Collection 管控链路 API +imports: [] +methods: + # memory collection + - name: create_memory_collection + description: 创建 Memory Collection + params: + - name: input + type: CreateMemoryCollectionInput + wrapper_type: CreateMemoryCollectionRequest + description: Memory Collection 配置 + return_type: MemoryCollection + return_description: 创建的 Memory Collection 对象 + - name: delete_memory_collection + description: 删除 Memory Collection + params: + - name: memory_collection_name + type: str + description: Memory Collection 名称 + return_type: MemoryCollection + return_description: 删除结果 + - name: update_memory_collection + description: 更新 Memory Collection + params: + - name: memory_collection_name + type: str + description: Memory Collection 名称 + - name: input + type: UpdateMemoryCollectionInput + wrapper_type: UpdateMemoryCollectionRequest + description: Memory Collection 配置 + return_type: MemoryCollection + return_description: 更新的 Memory Collection 对象 + - name: get_memory_collection + description: 获取 Memory Collection + params: + - name: memory_collection_name + type: str + description: Memory Collection 名称 + return_type: MemoryCollection + return_description: Memory Collection 对象 + - name: list_memory_collections + description: 枚举 Memory Collection + params: + - name: input + type: ListMemoryCollectionsRequest + description: 枚举的配置 + return_type: ListMemoryCollectionsOutput + return_description: Memory Collection 列表 + diff --git a/examples/memory_collection_example.py b/examples/memory_collection_example.py new file mode 100644 index 0000000..4ccb0b9 --- /dev/null +++ b/examples/memory_collection_example.py @@ -0,0 +1,162 @@ +"""MemoryCollection 使用示例 / MemoryCollection Usage Example + +此示例展示如何使用 MemoryCollection 模块进行记忆集合管理,包括与 mem0ai 的集成。 +This example demonstrates how to use the MemoryCollection module for memory collection management, +including integration with mem0ai. +""" + +import asyncio + +from agentrun.memory_collection import ( + EmbedderConfig, + EmbedderConfigConfig, + LLMConfig, + LLMConfigConfig, + MemoryCollection, + MemoryCollectionClient, + MemoryCollectionCreateInput, + MemoryCollectionUpdateInput, + NetworkConfiguration, + VectorStoreConfig, + VectorStoreConfigConfig, +) +from agentrun.utils.config import Config + + +async def main(): + """主函数 / Main function""" + + # 创建配置 + # Create configuration + config = Config() + + # 方式一:使用 Client + # Method 1: Using Client + print("=== 使用 MemoryCollectionClient ===") + client = MemoryCollectionClient(config) + + # 创建记忆集合 + # Create memory collection + create_input = MemoryCollectionCreateInput( + memory_collection_name="memoryCollection010901", + description="这是一个测试", + execution_role_arn="acs:ram::1760720386195983:role/aliyunfcdefaultrole", + embedder_config=EmbedderConfig( + config=EmbedderConfigConfig(model="text-embedding-v4"), + model_service_name="bailian", + ), + llm_config=LLMConfig( + config=LLMConfigConfig(model="qwen3-max"), + model_service_name="qwen3-max", + ), + vector_store_config=VectorStoreConfig( + provider="aliyun_tablestore", + config=VectorStoreConfigConfig( + endpoint=( + "https://jiuqing.cn-hangzhou.vpc.tablestore.aliyuncs.com" + ), + instance_name="jiuqing", + collection_name="memories010901", + vector_dimension=1536, + ), + ), + network_configuration=NetworkConfiguration( + vpc_id="vpc-bp1r2uvn5xactndk2jdpi", + security_group_id="sg-bp1bsf819nni9h2upltv", + vswitch_ids=["vsw-bp1omyfoztt6mt4h9r8jy"], + ), + ) + + try: + # memory_collection = await client.create_async(create_input) + print(f"创建成功: {memory_collection.memory_collection_name}") + except Exception as e: + print(f"创建失败: {e}") + + # 方式二:使用高层 API + # Method 2: Using high-level API + print("\n=== 使用 MemoryCollection 高层 API ===") + + try: + # memory_collection = await MemoryCollection.create_async(create_input, config=config) + # 获取已创建的模型服务 + memory_collection = MemoryCollection.get_by_name( + "memoryCollection010901" + ) + print(f"获取成功: {memory_collection}") + # # 更新记忆集合 + # # Update memory collection + # update_input = MemoryCollectionUpdateInput( + # description="更新后的描述" + # ) + # await memory_collection.update_async(update_input) + # print("更新成功") + + # # 获取记忆集合 + # # Get memory collection + # await memory_collection.refresh_async() + # print(f"刷新成功: {memory_collection.description}") + + # # 列出所有记忆集合 + # # List all memory collections + # collections = await MemoryCollection.list_all_async(config=config) + # print(f"找到 {len(collections)} 个记忆集合") + + # # 删除记忆集合 + # # Delete memory collection + # # await memory_collection.delete_async() + # # print("删除成功") + + except Exception as e: + print(f"操作失败: {e}") + + # 方式三:转换为 mem0ai Memory 客户端(需要安装 agentrun-mem0ai 依赖) + # Method 3: Convert to mem0ai Memory client (requires mem0 dependency) + print("\n=== 转换为 mem0ai Memory 客户端 ===") + + try: + # 使用高层 API 的 to_mem0_memory 方法 + # Use high-level API's to_mem0_memory method + memory = MemoryCollection.to_mem0_memory("memoryCollection010901") + print(f"✅ 成功创建 mem0ai Memory 客户端") + print(f" 类型: {type(memory)}") + + # 使用 mem0ai Memory 客户端进行操作 + # Use mem0ai Memory client for operations + user_id = "user123" + + # 添加记忆 + # Add memory + result = memory.add( + "我喜欢吃苹果和香蕉", + user_id=user_id, + metadata={"category": "food"}, + ) + print(f"\n✅ 添加记忆成功:") + for idx, res in enumerate(result.get("results", []), 1): + print(f" {idx}. ID: {res.get('id')}, 事件: {res.get('event')}") + + # 搜索记忆 + # Search memory + search_results = memory.search("用户喜欢吃什么水果?", user_id=user_id) + print(f"\n✅ 搜索记忆结果:") + for idx, result in enumerate(search_results.get("results", []), 1): + print( + f" {idx}. 内容: {result.get('memory')}, 相似度:" + f" {result.get('score', 0):.4f}" + ) + + except ImportError as e: + print(f"⚠️ mem0ai 未安装: {e}") + print(" 安装方法: pip install agentrun-sdk[mem0]") + except Exception as e: + print(f"❌ mem0ai 操作失败: {e}") + import traceback + + traceback.print_exc() + + print("\n✅ 示例完成") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index cb34841..ff543dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,10 @@ mcp = [ "mcp>=1.21.2; python_version >= '3.10'", ] +mem0 = [ + "agentrun-mem0ai>=0.0.3", +] + [dependency-groups] dev = [ "coverage>=7.10.7", From 61f64dc69817a13849fb9d855a5d18cfbdf99a20 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=85=E6=B0=A2?= Date: Tue, 13 Jan 2026 12:02:01 +0800 Subject: [PATCH 2/2] feat: memorycollection implement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: I187d0e5d54abc79d4ed3d256253bd442988a9ff7 Co-developed-by: Cursor Signed-off-by: 久氢 --- .../__memory_collection_async_template.py | 2 +- agentrun/memory_collection/client.py | 2 +- .../memory_collection/memory_collection.py | 4 +- agentrun/memory_collection/model.py | 2 +- pyproject.toml | 7 +- tests/unittests/memory_collection/__init__.py | 1 + .../memory_collection/test_client.py | 538 ++++++++++++++++++ .../test_memory_collection.py | 366 ++++++++++++ .../unittests/memory_collection/test_model.py | 327 +++++++++++ 9 files changed, 1239 insertions(+), 10 deletions(-) create mode 100644 tests/unittests/memory_collection/__init__.py create mode 100644 tests/unittests/memory_collection/test_client.py create mode 100644 tests/unittests/memory_collection/test_memory_collection.py create mode 100644 tests/unittests/memory_collection/test_model.py diff --git a/agentrun/memory_collection/__memory_collection_async_template.py b/agentrun/memory_collection/__memory_collection_async_template.py index 6c9dae0..b2f574d 100644 --- a/agentrun/memory_collection/__memory_collection_async_template.py +++ b/agentrun/memory_collection/__memory_collection_async_template.py @@ -250,7 +250,7 @@ async def to_mem0_memory_async( >>> memory.add("用户喜欢吃苹果", user_id="user123") """ try: - from mem0 import Memory + from agentrun_mem0 import Memory except ImportError as e: raise ImportError( "agentrun-mem0ai package is required. Install it with: pip" diff --git a/agentrun/memory_collection/client.py b/agentrun/memory_collection/client.py index 3de6ce8..a31c8d9 100644 --- a/agentrun/memory_collection/client.py +++ b/agentrun/memory_collection/client.py @@ -85,7 +85,7 @@ def create( input: MemoryCollectionCreateInput, config: Optional[Config] = None, ): - """创建记忆集合(同步) / Create memory collection asynchronously + """创建记忆集合(同步) / Create memory collection synchronously Args: input: 记忆集合输入参数 / Memory collection input parameters diff --git a/agentrun/memory_collection/memory_collection.py b/agentrun/memory_collection/memory_collection.py index bc5f5cf..94e613b 100644 --- a/agentrun/memory_collection/memory_collection.py +++ b/agentrun/memory_collection/memory_collection.py @@ -429,7 +429,7 @@ async def to_mem0_memory_async( >>> memory.add("用户喜欢吃苹果", user_id="user123") """ try: - from mem0 import Memory + from agentrun_mem0 import Memory except ImportError as e: raise ImportError( "agentrun-mem0ai package is required. Install it with: pip" @@ -478,7 +478,7 @@ def to_mem0_memory( >>> memory.add("用户喜欢吃苹果", user_id="user123") """ try: - from mem0 import Memory + from agentrun_mem0 import Memory except ImportError as e: raise ImportError( "agentrun-mem0ai package is required. Install it with: pip" diff --git a/agentrun/memory_collection/model.py b/agentrun/memory_collection/model.py index 90bd9bc..8fd95c8 100644 --- a/agentrun/memory_collection/model.py +++ b/agentrun/memory_collection/model.py @@ -4,7 +4,7 @@ Defines data models and enumerations related to memory collections. """ -from typing import Any, Dict, List, Optional +from typing import List, Optional from agentrun.utils.config import Config from agentrun.utils.model import BaseModel, PageableInput diff --git a/pyproject.toml b/pyproject.toml index ff543dd..4162dc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,9 +13,10 @@ dependencies = [ "litellm>=1.79.3", "alibabacloud-devs20230714>=2.4.1", "pydash>=8.0.5", - "alibabacloud-agentrun20250910>=5.0.1", + "alibabacloud-agentrun20250910>=5.2.0", "alibabacloud_tea_openapi>=0.4.2", "alibabacloud_bailian20231229>=2.6.2", + "agentrun-mem0ai>=0.0.6", ] [project.optional-dependencies] @@ -54,10 +55,6 @@ mcp = [ "mcp>=1.21.2; python_version >= '3.10'", ] -mem0 = [ - "agentrun-mem0ai>=0.0.3", -] - [dependency-groups] dev = [ "coverage>=7.10.7", diff --git a/tests/unittests/memory_collection/__init__.py b/tests/unittests/memory_collection/__init__.py new file mode 100644 index 0000000..ae52e93 --- /dev/null +++ b/tests/unittests/memory_collection/__init__.py @@ -0,0 +1 @@ +"""测试 agentrun.memory_collection 模块 / Test agentrun.memory_collection module""" diff --git a/tests/unittests/memory_collection/test_client.py b/tests/unittests/memory_collection/test_client.py new file mode 100644 index 0000000..52ea94c --- /dev/null +++ b/tests/unittests/memory_collection/test_client.py @@ -0,0 +1,538 @@ +"""测试 agentrun.memory_collection.client 模块 / Test agentrun.memory_collection.client module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.memory_collection.client import MemoryCollectionClient +from agentrun.memory_collection.model import ( + EmbedderConfig, + EmbedderConfigConfig, + LLMConfig, + LLMConfigConfig, + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionUpdateInput, + NetworkConfiguration, + VectorStoreConfig, + VectorStoreConfigConfig, +) +from agentrun.utils.config import Config +from agentrun.utils.exception import ( + HTTPError, + ResourceAlreadyExistError, + ResourceNotExistError, +) + + +class MockMemoryCollectionData: + """模拟记忆集合数据""" + + def to_map(self): + return { + "memoryCollectionId": "mc-123", + "memoryCollectionName": "test-memory-collection", + "description": "Test memory collection", + "type": "vector", + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + "embedderConfig": { + "modelServiceName": "test-embedder", + "config": {"model": "text-embedding-3-small"}, + }, + "llmConfig": { + "modelServiceName": "test-llm", + "config": {"model": "gpt-4"}, + }, + "vectorStoreConfig": { + "provider": "dashvector", + "config": { + "endpoint": "https://test.dashvector.cn", + "instanceName": "test-instance", + "collectionName": "test-collection", + "vectorDimension": 1536, + }, + }, + } + + +class MockListResult: + """模拟列表结果""" + + def __init__(self, items): + self.items = items + + +class TestMemoryCollectionClientInit: + """测试 MemoryCollectionClient 初始化""" + + def test_init_without_config(self): + """测试不带配置的初始化""" + client = MemoryCollectionClient() + assert client is not None + + def test_init_with_config(self): + """测试带配置的初始化""" + config = Config(access_key_id="test-ak") + client = MemoryCollectionClient(config=config) + assert client is not None + + +class TestMemoryCollectionClientCreate: + """测试 MemoryCollectionClient.create 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_create_sync(self, mock_control_api_class): + """测试同步创建记忆集合""" + mock_control_api = MagicMock() + mock_control_api.create_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + description="Test memory collection", + embedder_config=EmbedderConfig( + model_service_name="test-embedder", + config=EmbedderConfigConfig(model="text-embedding-3-small"), + ), + ) + + result = client.create(input_obj) + assert result.memory_collection_name == "test-memory-collection" + assert mock_control_api.create_memory_collection.called + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_create_async(self, mock_control_api_class): + """测试异步创建记忆集合""" + mock_control_api = MagicMock() + mock_control_api.create_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + description="Test memory collection", + ) + + result = await client.create_async(input_obj) + assert result.memory_collection_name == "test-memory-collection" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_create_with_full_config(self, mock_control_api_class): + """测试创建记忆集合(完整配置)""" + mock_control_api = MagicMock() + mock_control_api.create_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + description="Test memory collection", + embedder_config=EmbedderConfig( + model_service_name="test-embedder", + config=EmbedderConfigConfig(model="text-embedding-3-small"), + ), + llm_config=LLMConfig( + model_service_name="test-llm", + config=LLMConfigConfig(model="gpt-4"), + ), + vector_store_config=VectorStoreConfig( + provider="dashvector", + config=VectorStoreConfigConfig( + endpoint="https://test.dashvector.cn", + instance_name="test-instance", + collection_name="test-collection", + vector_dimension=1536, + ), + ), + network_configuration=NetworkConfiguration( + vpc_id="vpc-123", + vswitch_ids=["vsw-123"], + security_group_id="sg-123", + network_mode="vpc", + ), + execution_role_arn="acs:ram::123:role/test", + ) + + result = client.create(input_obj) + assert result.memory_collection_name == "test-memory-collection" + assert mock_control_api.create_memory_collection.called + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_create_already_exists(self, mock_control_api_class): + """测试创建已存在的记忆集合""" + mock_control_api = MagicMock() + mock_control_api.create_memory_collection.side_effect = HTTPError( + status_code=409, + message="Resource already exists", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionCreateInput( + memory_collection_name="existing-memory-collection", + type="vector", + ) + + with pytest.raises(ResourceAlreadyExistError): + client.create(input_obj) + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_create_async_already_exists(self, mock_control_api_class): + """测试异步创建已存在的记忆集合""" + mock_control_api = MagicMock() + mock_control_api.create_memory_collection_async = AsyncMock( + side_effect=HTTPError( + status_code=409, + message="Resource already exists", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionCreateInput( + memory_collection_name="existing-memory-collection", + type="vector", + ) + + with pytest.raises(ResourceAlreadyExistError): + await client.create_async(input_obj) + + +class TestMemoryCollectionClientDelete: + """测试 MemoryCollectionClient.delete 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_delete_sync(self, mock_control_api_class): + """测试同步删除记忆集合""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + result = client.delete("test-memory-collection") + assert result is not None + assert mock_control_api.delete_memory_collection.called + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_delete_async(self, mock_control_api_class): + """测试异步删除记忆集合""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + result = await client.delete_async("test-memory-collection") + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_delete_not_exist(self, mock_control_api_class): + """测试删除不存在的记忆集合""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + with pytest.raises(ResourceNotExistError): + client.delete("nonexistent-memory-collection") + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_delete_async_not_exist(self, mock_control_api_class): + """测试异步删除不存在的记忆集合""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + with pytest.raises(ResourceNotExistError): + await client.delete_async("nonexistent-memory-collection") + + +class TestMemoryCollectionClientUpdate: + """测试 MemoryCollectionClient.update 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_update_sync(self, mock_control_api_class): + """测试同步更新记忆集合""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionUpdateInput( + description="Updated description" + ) + result = client.update("test-memory-collection", input_obj) + assert result is not None + assert mock_control_api.update_memory_collection.called + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_update_sync_with_embedder_config(self, mock_control_api_class): + """测试同步更新记忆集合(带嵌入模型配置)""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionUpdateInput( + description="Updated", + embedder_config=EmbedderConfig( + model_service_name="new-embedder", + config=EmbedderConfigConfig(model="text-embedding-ada-002"), + ), + ) + result = client.update("test-memory-collection", input_obj) + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_update_async(self, mock_control_api_class): + """测试异步更新记忆集合""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionUpdateInput(description="Updated") + result = await client.update_async("test-memory-collection", input_obj) + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_update_async_with_llm_config(self, mock_control_api_class): + """测试异步更新记忆集合(带 LLM 配置)""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionUpdateInput( + llm_config=LLMConfig( + model_service_name="new-llm", + config=LLMConfigConfig(model="gpt-4-turbo"), + ) + ) + result = await client.update_async("test-memory-collection", input_obj) + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_update_not_exist(self, mock_control_api_class): + """测试更新不存在的记忆集合""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionUpdateInput(description="Updated") + with pytest.raises(ResourceNotExistError): + client.update("nonexistent-memory-collection", input_obj) + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_update_with_vector_store_config(self, mock_control_api_class): + """测试更新记忆集合(带向量存储配置)""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionUpdateInput( + vector_store_config=VectorStoreConfig( + provider="dashvector", + config=VectorStoreConfigConfig( + endpoint="https://new.dashvector.cn", + instance_name="new-instance", + collection_name="new-collection", + vector_dimension=3072, + ), + ) + ) + result = client.update("test-memory-collection", input_obj) + assert result is not None + + +class TestMemoryCollectionClientGet: + """测试 MemoryCollectionClient.get 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_get_sync(self, mock_control_api_class): + """测试同步获取记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + result = client.get("test-memory-collection") + assert result.memory_collection_name == "test-memory-collection" + assert mock_control_api.get_memory_collection.called + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_get_async(self, mock_control_api_class): + """测试异步获取记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + result = await client.get_async("test-memory-collection") + assert result.memory_collection_name == "test-memory-collection" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_get_not_exist(self, mock_control_api_class): + """测试获取不存在的记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection.side_effect = HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + with pytest.raises(ResourceNotExistError): + client.get("nonexistent-memory-collection") + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_get_async_not_exist(self, mock_control_api_class): + """测试异步获取不存在的记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection_async = AsyncMock( + side_effect=HTTPError( + status_code=404, + message="Resource does not exist", + request_id="req-1", + ) + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + with pytest.raises(ResourceNotExistError): + await client.get_async("nonexistent-memory-collection") + + +class TestMemoryCollectionClientList: + """测试 MemoryCollectionClient.list 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_list_sync(self, mock_control_api_class): + """测试同步列出记忆集合""" + mock_control_api = MagicMock() + mock_control_api.list_memory_collections.return_value = MockListResult([ + MockMemoryCollectionData(), + MockMemoryCollectionData(), + ]) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + result = client.list() + assert len(result) == 2 + assert mock_control_api.list_memory_collections.called + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_list_sync_with_input(self, mock_control_api_class): + """测试同步列出记忆集合(带输入参数)""" + mock_control_api = MagicMock() + mock_control_api.list_memory_collections.return_value = MockListResult( + [MockMemoryCollectionData()] + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionListInput( + page_number=1, + page_size=10, + memory_collection_name="test-memory-collection", + ) + result = client.list(input=input_obj) + assert len(result) == 1 + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_list_async(self, mock_control_api_class): + """测试异步列出记忆集合""" + mock_control_api = MagicMock() + mock_control_api.list_memory_collections_async = AsyncMock( + return_value=MockListResult([MockMemoryCollectionData()]) + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + result = await client.list_async() + assert len(result) == 1 + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_list_async_with_input(self, mock_control_api_class): + """测试异步列出记忆集合(带输入参数)""" + mock_control_api = MagicMock() + mock_control_api.list_memory_collections_async = AsyncMock( + return_value=MockListResult([MockMemoryCollectionData()]) + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + input_obj = MemoryCollectionListInput(page_number=1, page_size=10) + result = await client.list_async(input=input_obj) + assert len(result) == 1 + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_list_empty(self, mock_control_api_class): + """测试列出空记忆集合列表""" + mock_control_api = MagicMock() + mock_control_api.list_memory_collections.return_value = MockListResult( + [] + ) + mock_control_api_class.return_value = mock_control_api + + client = MemoryCollectionClient() + result = client.list() + assert len(result) == 0 diff --git a/tests/unittests/memory_collection/test_memory_collection.py b/tests/unittests/memory_collection/test_memory_collection.py new file mode 100644 index 0000000..d5c465c --- /dev/null +++ b/tests/unittests/memory_collection/test_memory_collection.py @@ -0,0 +1,366 @@ +"""测试 agentrun.memory_collection.memory_collection 模块 / Test agentrun.memory_collection.memory_collection module""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentrun.memory_collection.memory_collection import MemoryCollection +from agentrun.memory_collection.model import ( + EmbedderConfig, + EmbedderConfigConfig, + MemoryCollectionCreateInput, + MemoryCollectionUpdateInput, +) +from agentrun.utils.config import Config + + +class MockMemoryCollectionData: + """模拟记忆集合数据""" + + def to_map(self): + return { + "memoryCollectionId": "mc-123", + "memoryCollectionName": "test-memory-collection", + "description": "Test memory collection", + "type": "vector", + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class MockListResult: + """模拟列表结果""" + + def __init__(self, items): + self.items = items + + +class TestMemoryCollectionCreate: + """测试 MemoryCollection.create 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_create_sync(self, mock_control_api_class): + """测试同步创建记忆集合""" + mock_control_api = MagicMock() + mock_control_api.create_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + description="Test memory collection", + ) + + result = MemoryCollection.create(input_obj) + assert result.memory_collection_name == "test-memory-collection" + assert result.memory_collection_id == "mc-123" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_create_async(self, mock_control_api_class): + """测试异步创建记忆集合""" + mock_control_api = MagicMock() + mock_control_api.create_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + ) + + result = await MemoryCollection.create_async(input_obj) + assert result.memory_collection_name == "test-memory-collection" + + +class TestMemoryCollectionDelete: + """测试 MemoryCollection.delete 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_delete_by_name_sync(self, mock_control_api_class): + """测试根据名称同步删除记忆集合""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + result = MemoryCollection.delete_by_name("test-memory-collection") + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_delete_by_name_async(self, mock_control_api_class): + """测试根据名称异步删除记忆集合""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + result = await MemoryCollection.delete_by_name_async( + "test-memory-collection" + ) + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_delete_instance_sync(self, mock_control_api_class): + """测试实例同步删除""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api.get_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + memory_collection = MemoryCollection.get_by_name( + "test-memory-collection" + ) + + # 删除实例 + result = memory_collection.delete() + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_delete_instance_async(self, mock_control_api_class): + """测试实例异步删除""" + mock_control_api = MagicMock() + mock_control_api.delete_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api.get_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + memory_collection = await MemoryCollection.get_by_name_async( + "test-memory-collection" + ) + + # 删除实例 + result = await memory_collection.delete_async() + assert result is not None + + +class TestMemoryCollectionUpdate: + """测试 MemoryCollection.update 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_update_by_name_sync(self, mock_control_api_class): + """测试根据名称同步更新记忆集合""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = MemoryCollectionUpdateInput( + description="Updated description" + ) + result = MemoryCollection.update_by_name( + "test-memory-collection", input_obj + ) + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_update_by_name_async(self, mock_control_api_class): + """测试根据名称异步更新记忆集合""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + input_obj = MemoryCollectionUpdateInput(description="Updated") + result = await MemoryCollection.update_by_name_async( + "test-memory-collection", input_obj + ) + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_update_instance_sync(self, mock_control_api_class): + """测试实例同步更新""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api.get_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + memory_collection = MemoryCollection.get_by_name( + "test-memory-collection" + ) + + # 更新实例 + input_obj = MemoryCollectionUpdateInput(description="Updated") + result = memory_collection.update(input_obj) + assert result is not None + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_update_instance_async(self, mock_control_api_class): + """测试实例异步更新""" + mock_control_api = MagicMock() + mock_control_api.update_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api.get_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + memory_collection = await MemoryCollection.get_by_name_async( + "test-memory-collection" + ) + + # 更新实例 + input_obj = MemoryCollectionUpdateInput(description="Updated") + result = await memory_collection.update_async(input_obj) + assert result is not None + + +class TestMemoryCollectionGet: + """测试 MemoryCollection.get_by_name 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_get_by_name_sync(self, mock_control_api_class): + """测试同步获取记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + result = MemoryCollection.get_by_name("test-memory-collection") + assert result.memory_collection_name == "test-memory-collection" + assert result.memory_collection_id == "mc-123" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_get_by_name_async(self, mock_control_api_class): + """测试异步获取记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + result = await MemoryCollection.get_by_name_async( + "test-memory-collection" + ) + assert result.memory_collection_name == "test-memory-collection" + + +class TestMemoryCollectionList: + """测试 MemoryCollection.list_all 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_list_all_sync(self, mock_control_api_class): + """测试同步列出记忆集合""" + mock_control_api = MagicMock() + mock_control_api.list_memory_collections.return_value = MockListResult([ + MockMemoryCollectionData(), + MockMemoryCollectionData(), + ]) + mock_control_api_class.return_value = mock_control_api + + result = MemoryCollection.list_all() + # list_all 会对结果去重,所以相同 ID 的记录只会返回一个 + assert len(result) >= 1 + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_list_all_async(self, mock_control_api_class): + """测试异步列出记忆集合""" + mock_control_api = MagicMock() + mock_control_api.list_memory_collections_async = AsyncMock( + return_value=MockListResult([MockMemoryCollectionData()]) + ) + mock_control_api_class.return_value = mock_control_api + + result = await MemoryCollection.list_all_async() + assert len(result) == 1 + + +class TestMemoryCollectionRefresh: + """测试 MemoryCollection.refresh 方法""" + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + def test_refresh_sync(self, mock_control_api_class): + """测试同步刷新记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection.return_value = ( + MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + memory_collection = MemoryCollection.get_by_name( + "test-memory-collection" + ) + + # 刷新实例 + memory_collection.refresh() + assert ( + memory_collection.memory_collection_name == "test-memory-collection" + ) + + @patch("agentrun.memory_collection.client.MemoryCollectionControlAPI") + @pytest.mark.asyncio + async def test_refresh_async(self, mock_control_api_class): + """测试异步刷新记忆集合""" + mock_control_api = MagicMock() + mock_control_api.get_memory_collection_async = AsyncMock( + return_value=MockMemoryCollectionData() + ) + mock_control_api_class.return_value = mock_control_api + + # 先获取实例 + memory_collection = await MemoryCollection.get_by_name_async( + "test-memory-collection" + ) + + # 刷新实例 + await memory_collection.refresh_async() + assert ( + memory_collection.memory_collection_name == "test-memory-collection" + ) + + +class TestMemoryCollectionFromInnerObject: + """测试 MemoryCollection.from_inner_object 方法""" + + def test_from_inner_object(self): + """测试从内部对象创建记忆集合""" + mock_data = MockMemoryCollectionData() + memory_collection = MemoryCollection.from_inner_object(mock_data) + + assert memory_collection.memory_collection_id == "mc-123" + assert ( + memory_collection.memory_collection_name == "test-memory-collection" + ) + assert memory_collection.description == "Test memory collection" + assert memory_collection.type == "vector" + + def test_from_inner_object_with_extra(self): + """测试从内部对象创建记忆集合(带额外字段)""" + mock_data = MockMemoryCollectionData() + extra = {"custom_field": "custom_value"} + memory_collection = MemoryCollection.from_inner_object(mock_data, extra) + + assert ( + memory_collection.memory_collection_name == "test-memory-collection" + ) diff --git a/tests/unittests/memory_collection/test_model.py b/tests/unittests/memory_collection/test_model.py new file mode 100644 index 0000000..8347b5d --- /dev/null +++ b/tests/unittests/memory_collection/test_model.py @@ -0,0 +1,327 @@ +"""测试 agentrun.memory_collection.model 模块 / Test agentrun.memory_collection.model module""" + +import pytest + +from agentrun.memory_collection.model import ( + EmbedderConfig, + EmbedderConfigConfig, + LLMConfig, + LLMConfigConfig, + MemoryCollectionCreateInput, + MemoryCollectionListInput, + MemoryCollectionListOutput, + MemoryCollectionUpdateInput, + NetworkConfiguration, + VectorStoreConfig, + VectorStoreConfigConfig, +) + + +class TestEmbedderConfigConfig: + """测试 EmbedderConfigConfig 模型""" + + def test_create_embedder_config_config(self): + """测试创建嵌入模型内部配置""" + config = EmbedderConfigConfig(model="text-embedding-3-small") + assert config.model == "text-embedding-3-small" + + def test_embedder_config_config_optional(self): + """测试嵌入模型内部配置可选字段""" + config = EmbedderConfigConfig() + assert config.model is None + + +class TestEmbedderConfig: + """测试 EmbedderConfig 模型""" + + def test_create_embedder_config(self): + """测试创建嵌入模型配置""" + config = EmbedderConfig( + model_service_name="test-embedder", + config=EmbedderConfigConfig(model="text-embedding-3-small"), + ) + assert config.model_service_name == "test-embedder" + assert config.config is not None + assert config.config.model == "text-embedding-3-small" + + def test_embedder_config_optional(self): + """测试嵌入模型配置可选字段""" + config = EmbedderConfig() + assert config.model_service_name is None + assert config.config is None + + +class TestLLMConfigConfig: + """测试 LLMConfigConfig 模型""" + + def test_create_llm_config_config(self): + """测试创建 LLM 内部配置""" + config = LLMConfigConfig(model="gpt-4") + assert config.model == "gpt-4" + + def test_llm_config_config_optional(self): + """测试 LLM 内部配置可选字段""" + config = LLMConfigConfig() + assert config.model is None + + +class TestLLMConfig: + """测试 LLMConfig 模型""" + + def test_create_llm_config(self): + """测试创建 LLM 配置""" + config = LLMConfig( + model_service_name="test-llm", + config=LLMConfigConfig(model="gpt-4"), + ) + assert config.model_service_name == "test-llm" + assert config.config is not None + assert config.config.model == "gpt-4" + + def test_llm_config_optional(self): + """测试 LLM 配置可选字段""" + config = LLMConfig() + assert config.model_service_name is None + assert config.config is None + + +class TestNetworkConfiguration: + """测试 NetworkConfiguration 模型""" + + def test_create_network_configuration(self): + """测试创建网络配置""" + config = NetworkConfiguration( + vpc_id="vpc-123", + vswitch_ids=["vsw-123", "vsw-456"], + security_group_id="sg-123", + network_mode="vpc", + ) + assert config.vpc_id == "vpc-123" + assert config.vswitch_ids == ["vsw-123", "vsw-456"] + assert config.security_group_id == "sg-123" + assert config.network_mode == "vpc" + + def test_network_configuration_optional(self): + """测试网络配置可选字段""" + config = NetworkConfiguration() + assert config.vpc_id is None + assert config.vswitch_ids is None + assert config.security_group_id is None + assert config.network_mode is None + + +class TestVectorStoreConfigConfig: + """测试 VectorStoreConfigConfig 模型""" + + def test_create_vector_store_config_config(self): + """测试创建向量存储内部配置""" + config = VectorStoreConfigConfig( + endpoint="https://test.dashvector.cn", + instance_name="test-instance", + collection_name="test-collection", + vector_dimension=1536, + ) + assert config.endpoint == "https://test.dashvector.cn" + assert config.instance_name == "test-instance" + assert config.collection_name == "test-collection" + assert config.vector_dimension == 1536 + + def test_vector_store_config_config_optional(self): + """测试向量存储内部配置可选字段""" + config = VectorStoreConfigConfig() + assert config.endpoint is None + assert config.instance_name is None + assert config.collection_name is None + assert config.vector_dimension is None + + +class TestVectorStoreConfig: + """测试 VectorStoreConfig 模型""" + + def test_create_vector_store_config(self): + """测试创建向量存储配置""" + config = VectorStoreConfig( + provider="dashvector", + config=VectorStoreConfigConfig( + endpoint="https://test.dashvector.cn", + instance_name="test-instance", + collection_name="test-collection", + vector_dimension=1536, + ), + ) + assert config.provider == "dashvector" + assert config.config is not None + assert config.config.endpoint == "https://test.dashvector.cn" + + def test_vector_store_config_optional(self): + """测试向量存储配置可选字段""" + config = VectorStoreConfig() + assert config.provider is None + assert config.config is None + + +class TestMemoryCollectionCreateInput: + """测试 MemoryCollectionCreateInput 模型""" + + def test_create_minimal_input(self): + """测试创建最小输入参数""" + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + ) + assert input_obj.memory_collection_name == "test-memory-collection" + assert input_obj.type == "vector" + + def test_create_full_input(self): + """测试创建完整输入参数""" + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + description="Test memory collection", + embedder_config=EmbedderConfig( + model_service_name="test-embedder", + config=EmbedderConfigConfig(model="text-embedding-3-small"), + ), + llm_config=LLMConfig( + model_service_name="test-llm", + config=LLMConfigConfig(model="gpt-4"), + ), + vector_store_config=VectorStoreConfig( + provider="dashvector", + config=VectorStoreConfigConfig( + endpoint="https://test.dashvector.cn", + instance_name="test-instance", + collection_name="test-collection", + vector_dimension=1536, + ), + ), + network_configuration=NetworkConfiguration( + vpc_id="vpc-123", + vswitch_ids=["vsw-123"], + security_group_id="sg-123", + network_mode="vpc", + ), + execution_role_arn="acs:ram::123:role/test", + ) + assert input_obj.memory_collection_name == "test-memory-collection" + assert input_obj.description == "Test memory collection" + assert input_obj.embedder_config is not None + assert input_obj.llm_config is not None + assert input_obj.vector_store_config is not None + assert input_obj.network_configuration is not None + assert input_obj.execution_role_arn == "acs:ram::123:role/test" + + def test_model_dump(self): + """测试模型序列化""" + input_obj = MemoryCollectionCreateInput( + memory_collection_name="test-memory-collection", + type="vector", + description="Test", + ) + data = input_obj.model_dump() + # 检查序列化后的数据包含必要字段 + assert ( + "memory_collection_name" in data or "memoryCollectionName" in data + ) + assert input_obj.memory_collection_name == "test-memory-collection" + assert input_obj.type == "vector" + assert input_obj.description == "Test" + + +class TestMemoryCollectionUpdateInput: + """测试 MemoryCollectionUpdateInput 模型""" + + def test_create_update_input(self): + """测试创建更新输入参数""" + input_obj = MemoryCollectionUpdateInput( + description="Updated description", + ) + assert input_obj.description == "Updated description" + + def test_update_input_with_embedder_config(self): + """测试更新输入参数(带嵌入模型配置)""" + input_obj = MemoryCollectionUpdateInput( + embedder_config=EmbedderConfig( + model_service_name="new-embedder", + config=EmbedderConfigConfig(model="text-embedding-ada-002"), + ) + ) + assert input_obj.embedder_config is not None + assert input_obj.embedder_config.model_service_name == "new-embedder" + + def test_update_input_with_llm_config(self): + """测试更新输入参数(带 LLM 配置)""" + input_obj = MemoryCollectionUpdateInput( + llm_config=LLMConfig( + model_service_name="new-llm", + config=LLMConfigConfig(model="gpt-4-turbo"), + ) + ) + assert input_obj.llm_config is not None + assert input_obj.llm_config.config is not None + assert input_obj.llm_config.config.model == "gpt-4-turbo" + + def test_update_input_optional(self): + """测试更新输入参数可选字段""" + input_obj = MemoryCollectionUpdateInput() + assert input_obj.description is None + assert input_obj.embedder_config is None + assert input_obj.llm_config is None + + +class TestMemoryCollectionListInput: + """测试 MemoryCollectionListInput 模型""" + + def test_create_list_input(self): + """测试创建列表输入参数""" + input_obj = MemoryCollectionListInput( + page_number=1, + page_size=10, + memory_collection_name="test-memory-collection", + ) + assert input_obj.page_number == 1 + assert input_obj.page_size == 10 + assert input_obj.memory_collection_name == "test-memory-collection" + + def test_list_input_default(self): + """测试列表输入参数默认值""" + input_obj = MemoryCollectionListInput() + assert input_obj.memory_collection_name is None + + def test_list_input_with_pagination(self): + """测试列表输入参数(带分页)""" + input_obj = MemoryCollectionListInput( + page_number=2, + page_size=20, + ) + assert input_obj.page_number == 2 + assert input_obj.page_size == 20 + + +class TestMemoryCollectionListOutput: + """测试 MemoryCollectionListOutput 模型""" + + def test_create_list_output(self): + """测试创建列表输出""" + output = MemoryCollectionListOutput( + memory_collection_id="mc-123", + memory_collection_name="test-memory-collection", + description="Test memory collection", + type="vector", + created_at="2024-01-01T00:00:00Z", + last_updated_at="2024-01-01T00:00:00Z", + ) + assert output.memory_collection_id == "mc-123" + assert output.memory_collection_name == "test-memory-collection" + assert output.description == "Test memory collection" + assert output.type == "vector" + + def test_list_output_optional(self): + """测试列表输出可选字段""" + output = MemoryCollectionListOutput() + assert output.memory_collection_id is None + assert output.memory_collection_name is None + assert output.description is None + assert output.type is None + assert output.created_at is None + assert output.last_updated_at is None