From de482e264cdb1ccdf1108a0c891f898fb334164d Mon Sep 17 00:00:00 2001 From: Terrdi <675464934@qq.com> Date: Tue, 11 Feb 2025 11:47:38 +0800 Subject: [PATCH 1/2] In the example of werewolf, when the model call fails, it will try to serialize the Weregame (inherited from Team) class into Json and save it, but the BasePlayer::special_actions list may contain the Action class, causing the serialization to fail. Add custom serialization fields to the Team class. --- metagpt/ext/werewolf/schema.py | 35 +++++++++++++++++++++-- metagpt/ext/werewolf/werewolf_game.py | 3 +- metagpt/team.py | 4 ++- metagpt/utils/common.py | 4 +-- setup.py | 2 +- tests/metagpt/ext/werewolf/test_schema.py | 15 ++++++++++ 6 files changed, 56 insertions(+), 7 deletions(-) create mode 100644 tests/metagpt/ext/werewolf/test_schema.py diff --git a/metagpt/ext/werewolf/schema.py b/metagpt/ext/werewolf/schema.py index 1502a2391b..57eab6ad87 100644 --- a/metagpt/ext/werewolf/schema.py +++ b/metagpt/ext/werewolf/schema.py @@ -1,10 +1,10 @@ from typing import Any - +import json from pydantic import BaseModel, Field, field_validator from metagpt.schema import Message from metagpt.utils.common import any_to_str_set - +from metagpt.configs.llm_config import LLMType class RoleExperience(BaseModel): id: str = "" @@ -31,3 +31,34 @@ class WwMessage(Message): @classmethod def check_restricted_to(cls, restricted_to: Any): return any_to_str_set(restricted_to if restricted_to else set()) + +def wrapper_none_error(func): + """Wrapper for ValueError""" + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except ValueError as e: + # the to_jsonable_python of pydantic will raise PydanticSerializationError + # return None to call the custom JSONEncoder + return None + return wrapper + + +class WwJsonEncoder(json.JSONEncoder): + def __init__(self, *, skipkeys=False, ensure_ascii=True, + check_circular=True, allow_nan=True, sort_keys=False, + indent=None, separators=None, default=None): + super().__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii, + check_circular=check_circular, allow_nan=allow_nan, sort_keys=sort_keys, + indent=indent, separators=separators, default=default) + if default is not None: + self.default = wrapper_none_error(default) + + + def _default(self, obj): + if isinstance(obj, type): # handle class + return { + "__type__": obj.__name__, + "__module__": obj.__module__ + } + return super().default(obj) diff --git a/metagpt/ext/werewolf/werewolf_game.py b/metagpt/ext/werewolf/werewolf_game.py index 4deb831a00..81691a5832 100644 --- a/metagpt/ext/werewolf/werewolf_game.py +++ b/metagpt/ext/werewolf/werewolf_game.py @@ -3,7 +3,7 @@ from metagpt.actions.add_requirement import UserRequirement from metagpt.context import Context from metagpt.environment.werewolf.werewolf_env import WerewolfEnv -from metagpt.ext.werewolf.schema import WwMessage +from metagpt.ext.werewolf.schema import WwMessage, WwJsonEncoder from metagpt.team import Team @@ -14,6 +14,7 @@ class WerewolfGame(Team): def __init__(self, context: Context = None, **data: Any): super(Team, self).__init__(**data) + self.json_encoder = WwJsonEncoder ctx = context or Context() if not self.env: self.env = WerewolfEnv(context=ctx) diff --git a/metagpt/team.py b/metagpt/team.py index 2288f9748d..f831c16edd 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -9,6 +9,7 @@ """ import warnings +import json from pathlib import Path from typing import Any, Optional @@ -40,6 +41,7 @@ class Team(BaseModel): env: Optional[Environment] = None investment: float = Field(default=10.0) idea: str = Field(default="") + json_encoder: json.JSONEncoder = Field(default=None, exclude=True) def __init__(self, context: Context = None, **data: Any): super(Team, self).__init__(**data) @@ -59,7 +61,7 @@ def serialize(self, stg_path: Path = None): serialized_data = self.model_dump() serialized_data["context"] = self.env.context.serialize() - write_json_file(team_info_path, serialized_data) + write_json_file(team_info_path, serialized_data, cls=self.json_encoder) @classmethod def deserialize(cls, stg_path: Path, context: Context = None) -> "Team": diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 47f2768cd2..a01d2a240b 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -572,13 +572,13 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: return data -def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): +def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4, cls: json.JSONEncoder=None): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, cls=cls, indent=indent, default=to_jsonable_python) def read_jsonl_file(jsonl_file: str, encoding="utf-8") -> list[dict]: diff --git a/setup.py b/setup.py index 2ffc09ee81..658c82219f 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ def run(self): "llama-index-postprocessor-cohere-rerank==0.1.4", "llama-index-postprocessor-colbert-rerank==0.1.1", "llama-index-postprocessor-flag-embedding-reranker==0.1.2", - # "llama-index-vector-stores-milvus==0.1.23", + "llama-index-vector-stores-milvus==0.1.23", "docx2txt==0.8", ], } diff --git a/tests/metagpt/ext/werewolf/test_schema.py b/tests/metagpt/ext/werewolf/test_schema.py new file mode 100644 index 0000000000..06533beff3 --- /dev/null +++ b/tests/metagpt/ext/werewolf/test_schema.py @@ -0,0 +1,15 @@ +from metagpt.ext.werewolf.schema import WwJsonEncoder +from metagpt.ext.werewolf.actions.common_actions import Speak +from metagpt.environment.werewolf.const import RoleType, RoleState, RoleActionRes +import json +from metagpt.utils.common import to_jsonable_python +def test_ww_json_encoder(): + encoder = WwJsonEncoder + data = { + "test": RoleType.VILLAGER, + "test2": RoleState.ALIVE, + "test3": RoleActionRes.PASS, + "test4": [Speak], + } + encoded = json.dumps(data, cls=encoder, default=to_jsonable_python) + # print(encoded) \ No newline at end of file From d0dc01c06694a37101d2c1bdacc4ca22d41c9427 Mon Sep 17 00:00:00 2001 From: Terrdi <675464934@qq.com> Date: Wed, 19 Feb 2025 15:05:56 +0800 Subject: [PATCH 2/2] The aflow example will use json to save the results. The json contains a field total_cost which is `numpy.int64` and cannot be serialized by json. --- .../scripts/optimizer_utils/data_utils.py | 29 ++++++++++++++++++- metagpt/ext/werewolf/schema.py | 21 -------------- metagpt/utils/common.py | 13 ++++++++- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py index 2a09e08201..0cbd073a21 100644 --- a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py +++ b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py @@ -125,7 +125,7 @@ def create_result_data(self, round: int, score: float, avg_cost: float, total_co return {"round": round, "score": score, "avg_cost": avg_cost, "total_cost": total_cost, "time": now} def save_results(self, json_file_path: str, data: list): - write_json_file(json_file_path, data, encoding="utf-8", indent=4) + write_json_file(json_file_path, data, encoding="utf-8", indent=4, cls=NumpyJSONEncoder) def _load_scores(self, path=None, mode="Graph"): if mode == "Graph": @@ -147,3 +147,30 @@ def _load_scores(self, path=None, mode="Graph"): self.top_scores.sort(key=lambda x: x["score"], reverse=True) return self.top_scores + +class NumpyJSONEncoder(json.JSONEncoder): + """customized JSON encoder for numpy data type + + features: + 1. support numpy array serialization (automatically convert to list) + 2. support numpy scalar type (int32, float64, etc.) serialization + 3. keep the original data type precision + 4. compatible with regular JSON data types + """ + def default(self, obj): + """override the default serialize method""" + + # handle numpy array type + if isinstance(obj, np.ndarray): + return { + '__ndarray__': obj.tolist(), # 转换为Python列表 + 'dtype': str(obj.dtype), # 保留数据类型信息 + 'shape': obj.shape # 保留数组形状 + } + + # handle numpy scalar type + elif isinstance(obj, np.generic): + return obj.item() # convert to python native type + + # handle other type using default method + return super().default(obj) \ No newline at end of file diff --git a/metagpt/ext/werewolf/schema.py b/metagpt/ext/werewolf/schema.py index 57eab6ad87..3304fe9558 100644 --- a/metagpt/ext/werewolf/schema.py +++ b/metagpt/ext/werewolf/schema.py @@ -32,29 +32,8 @@ class WwMessage(Message): def check_restricted_to(cls, restricted_to: Any): return any_to_str_set(restricted_to if restricted_to else set()) -def wrapper_none_error(func): - """Wrapper for ValueError""" - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except ValueError as e: - # the to_jsonable_python of pydantic will raise PydanticSerializationError - # return None to call the custom JSONEncoder - return None - return wrapper - class WwJsonEncoder(json.JSONEncoder): - def __init__(self, *, skipkeys=False, ensure_ascii=True, - check_circular=True, allow_nan=True, sort_keys=False, - indent=None, separators=None, default=None): - super().__init__(skipkeys=skipkeys, ensure_ascii=ensure_ascii, - check_circular=check_circular, allow_nan=allow_nan, sort_keys=sort_keys, - indent=indent, separators=separators, default=default) - if default is not None: - self.default = wrapper_none_error(default) - - def _default(self, obj): if isinstance(obj, type): # handle class return { diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index a01d2a240b..50ba7ba48e 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -571,6 +571,16 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: raise ValueError(f"read json file: {json_file} failed") return data +def wrapper_none_error(func): + """Wrapper for ValueError""" + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except ValueError as e: + # the to_jsonable_python of pydantic will raise PydanticSerializationError + # return None to call the custom JSONEncoder + return None + return wrapper def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4, cls: json.JSONEncoder=None): folder_path = Path(json_file).parent @@ -578,7 +588,8 @@ def write_json_file(json_file: str, data: list, encoding: str = None, indent: in folder_path.mkdir(parents=True, exist_ok=True) with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, cls=cls, indent=indent, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, cls=cls, indent=indent, + default=wrapper_none_error(to_jsonable_python) if cls else to_jsonable_python) def read_jsonl_file(jsonl_file: str, encoding="utf-8") -> list[dict]: