diff --git a/bot/dashscope/dashscope_session.py b/bot/dashscope/dashscope_session.py index 0de57b926..e432f67a4 100644 --- a/bot/dashscope/dashscope_session.py +++ b/bot/dashscope/dashscope_session.py @@ -4,7 +4,7 @@ class DashscopeSession(Session): def __init__(self, session_id, system_prompt=None, model="qwen-turbo"): - super().__init__(session_id) + super().__init__(session_id, system_prompt=system_prompt) self.reset() def discard_exceeding(self, max_tokens, cur_tokens=None): diff --git a/bot/session_manager.py b/bot/session_manager.py index a6e89f956..ce2fe2968 100644 --- a/bot/session_manager.py +++ b/bot/session_manager.py @@ -19,7 +19,8 @@ def reset(self): def set_system_prompt(self, system_prompt): self.system_prompt = system_prompt - self.reset() + if self.messages and self.messages[0]["role"] == "system": + self.messages[0]["content"] = system_prompt def add_query(self, query): user_item = {"role": "user", "content": query} @@ -46,18 +47,18 @@ def __init__(self, sessioncls, **session_args): self.sessioncls = sessioncls self.session_args = session_args - def build_session(self, session_id, system_prompt=None): + def build_session(self, session_id): """ 如果session_id不在sessions中,创建一个新的session并添加到sessions中 如果system_prompt不会空,会更新session的system_prompt并重置session """ if session_id is None: - return self.sessioncls(session_id, system_prompt, **self.session_args) - + return self.sessioncls(session_id, **self.session_args) + if session_id not in self.sessions: - self.sessions[session_id] = self.sessioncls(session_id, system_prompt, **self.session_args) - elif system_prompt is not None: # 如果有新的system_prompt,更新并重置session - self.sessions[session_id].set_system_prompt(system_prompt) + self.sessions[session_id] = self.sessioncls(session_id, **self.session_args) + elif self.session_args.get("system_prompt"): # 如果有新的system_prompt,更新并重置session + self.sessions[session_id].set_system_prompt(self.session_args.get("system_prompt")) session = self.sessions[session_id] return session diff --git a/plugins/role/role.py b/plugins/role/role.py index 8890a6299..e2e80f6c1 100644 --- a/plugins/role/role.py +++ b/plugins/role/role.py @@ -19,7 +19,8 @@ def __init__(self, bot, sessionid, desc, wrapper=None): self.sessionid = sessionid self.wrapper = wrapper or "%s" # 用于包装用户输入 self.desc = desc - self.bot.sessions.build_session(self.sessionid, system_prompt=self.desc) + self.bot.sessions.session_args.update({"system_prompt": self.desc}) + self.bot.sessions.build_session(self.sessionid) def reset(self): self.bot.sessions.clear_session(self.sessionid) @@ -28,6 +29,7 @@ def action(self, user_action): session = self.bot.sessions.build_session(self.sessionid) if session.system_prompt != self.desc: # 目前没有触发session过期事件,这里先简单判断,然后重置 session.set_system_prompt(self.desc) + session.reset() prompt = self.wrapper % user_action return prompt