diff --git a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py index 2a09e08201..0cbd073a21 100644 --- a/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py +++ b/metagpt/ext/aflow/scripts/optimizer_utils/data_utils.py @@ -125,7 +125,7 @@ def create_result_data(self, round: int, score: float, avg_cost: float, total_co return {"round": round, "score": score, "avg_cost": avg_cost, "total_cost": total_cost, "time": now} def save_results(self, json_file_path: str, data: list): - write_json_file(json_file_path, data, encoding="utf-8", indent=4) + write_json_file(json_file_path, data, encoding="utf-8", indent=4, cls=NumpyJSONEncoder) def _load_scores(self, path=None, mode="Graph"): if mode == "Graph": @@ -147,3 +147,30 @@ def _load_scores(self, path=None, mode="Graph"): self.top_scores.sort(key=lambda x: x["score"], reverse=True) return self.top_scores + +class NumpyJSONEncoder(json.JSONEncoder): + """customized JSON encoder for numpy data type + + features: + 1. support numpy array serialization (automatically convert to list) + 2. support numpy scalar type (int32, float64, etc.) serialization + 3. keep the original data type precision + 4. compatible with regular JSON data types + """ + def default(self, obj): + """override the default serialize method""" + + # handle numpy array type + if isinstance(obj, np.ndarray): + return { + '__ndarray__': obj.tolist(), # 转换为Python列表 + 'dtype': str(obj.dtype), # 保留数据类型信息 + 'shape': obj.shape # 保留数组形状 + } + + # handle numpy scalar type + elif isinstance(obj, np.generic): + return obj.item() # convert to python native type + + # handle other type using default method + return super().default(obj) \ No newline at end of file diff --git a/metagpt/ext/werewolf/schema.py b/metagpt/ext/werewolf/schema.py index 1502a2391b..3304fe9558 100644 --- a/metagpt/ext/werewolf/schema.py +++ b/metagpt/ext/werewolf/schema.py @@ -1,10 +1,10 @@ from typing import Any - +import json from pydantic import BaseModel, Field, field_validator from metagpt.schema import Message from metagpt.utils.common import any_to_str_set - +from metagpt.configs.llm_config import LLMType class RoleExperience(BaseModel): id: str = "" @@ -31,3 +31,13 @@ class WwMessage(Message): @classmethod def check_restricted_to(cls, restricted_to: Any): return any_to_str_set(restricted_to if restricted_to else set()) + + +class WwJsonEncoder(json.JSONEncoder): + def _default(self, obj): + if isinstance(obj, type): # handle class + return { + "__type__": obj.__name__, + "__module__": obj.__module__ + } + return super().default(obj) diff --git a/metagpt/ext/werewolf/werewolf_game.py b/metagpt/ext/werewolf/werewolf_game.py index 4deb831a00..81691a5832 100644 --- a/metagpt/ext/werewolf/werewolf_game.py +++ b/metagpt/ext/werewolf/werewolf_game.py @@ -3,7 +3,7 @@ from metagpt.actions.add_requirement import UserRequirement from metagpt.context import Context from metagpt.environment.werewolf.werewolf_env import WerewolfEnv -from metagpt.ext.werewolf.schema import WwMessage +from metagpt.ext.werewolf.schema import WwMessage, WwJsonEncoder from metagpt.team import Team @@ -14,6 +14,7 @@ class WerewolfGame(Team): def __init__(self, context: Context = None, **data: Any): super(Team, self).__init__(**data) + self.json_encoder = WwJsonEncoder ctx = context or Context() if not self.env: self.env = WerewolfEnv(context=ctx) diff --git a/metagpt/team.py b/metagpt/team.py index 2288f9748d..f831c16edd 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -9,6 +9,7 @@ """ import warnings +import json from pathlib import Path from typing import Any, Optional @@ -40,6 +41,7 @@ class Team(BaseModel): env: Optional[Environment] = None investment: float = Field(default=10.0) idea: str = Field(default="") + json_encoder: json.JSONEncoder = Field(default=None, exclude=True) def __init__(self, context: Context = None, **data: Any): super(Team, self).__init__(**data) @@ -59,7 +61,7 @@ def serialize(self, stg_path: Path = None): serialized_data = self.model_dump() serialized_data["context"] = self.env.context.serialize() - write_json_file(team_info_path, serialized_data) + write_json_file(team_info_path, serialized_data, cls=self.json_encoder) @classmethod def deserialize(cls, stg_path: Path, context: Context = None) -> "Team": diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 47f2768cd2..50ba7ba48e 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -571,14 +571,25 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]: raise ValueError(f"read json file: {json_file} failed") return data +def wrapper_none_error(func): + """Wrapper for ValueError""" + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except ValueError as e: + # the to_jsonable_python of pydantic will raise PydanticSerializationError + # return None to call the custom JSONEncoder + return None + return wrapper -def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4): +def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4, cls: json.JSONEncoder=None): folder_path = Path(json_file).parent if not folder_path.exists(): folder_path.mkdir(parents=True, exist_ok=True) with open(json_file, "w", encoding=encoding) as fout: - json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python) + json.dump(data, fout, ensure_ascii=False, cls=cls, indent=indent, + default=wrapper_none_error(to_jsonable_python) if cls else to_jsonable_python) def read_jsonl_file(jsonl_file: str, encoding="utf-8") -> list[dict]: diff --git a/setup.py b/setup.py index 2ffc09ee81..658c82219f 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ def run(self): "llama-index-postprocessor-cohere-rerank==0.1.4", "llama-index-postprocessor-colbert-rerank==0.1.1", "llama-index-postprocessor-flag-embedding-reranker==0.1.2", - # "llama-index-vector-stores-milvus==0.1.23", + "llama-index-vector-stores-milvus==0.1.23", "docx2txt==0.8", ], } diff --git a/tests/metagpt/ext/werewolf/test_schema.py b/tests/metagpt/ext/werewolf/test_schema.py new file mode 100644 index 0000000000..06533beff3 --- /dev/null +++ b/tests/metagpt/ext/werewolf/test_schema.py @@ -0,0 +1,15 @@ +from metagpt.ext.werewolf.schema import WwJsonEncoder +from metagpt.ext.werewolf.actions.common_actions import Speak +from metagpt.environment.werewolf.const import RoleType, RoleState, RoleActionRes +import json +from metagpt.utils.common import to_jsonable_python +def test_ww_json_encoder(): + encoder = WwJsonEncoder + data = { + "test": RoleType.VILLAGER, + "test2": RoleState.ALIVE, + "test3": RoleActionRes.PASS, + "test4": [Speak], + } + encoded = json.dumps(data, cls=encoder, default=to_jsonable_python) + # print(encoded) \ No newline at end of file