diff --git a/examples/hello_world.py b/examples/hello_world.py
index 641ac0dcf7..bcda95e11b 100644
--- a/examples/hello_world.py
+++ b/examples/hello_world.py
@@ -13,7 +13,13 @@ async def ask_and_print(question: str, llm: LLM, system_prompt) -> str:
     logger.info(f"Q: {question}")
-    rsp = await llm.aask(question, system_msgs=[system_prompt], stream=True)
+
+    context = "agent name or other tag"
+    def stream_callback(content: str) -> None:
+        print(context, content)
+        # e.g., send the content over a websocket instead of printing it
+
+    rsp = await llm.aask(question, system_msgs=[system_prompt], stream=True, stream_callback=stream_callback)
     if hasattr(llm, "reasoning_content") and llm.reasoning_content:
         logger.info(f"A reasoning: {llm.reasoning_content}")
     logger.info(f"A: {rsp}")
@@ -29,7 +35,13 @@ async def lowlevel_api_example(llm: LLM):
     logger.info(await llm.acompletion_text(hello_msg))
 
     # streaming mode, much slower
-    await llm.acompletion_text(hello_msg, stream=True)
+
+    context = "agent name or other tag"
+    def stream_callback(content: str) -> None:
+        print(context, content)
+        # e.g., send the content over a websocket instead of printing it
+
+    await llm.acompletion_text(hello_msg, stream=True, stream_callback=stream_callback)
 
     # check completion if exist to test llm complete functions
     if hasattr(llm, "completion"):
diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py
index 7fae4939a2..b527a59624 100644
--- a/metagpt/provider/anthropic_api.py
+++ b/metagpt/provider/anthropic_api.py
@@ -57,7 +57,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
     async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
         collected_content = []
         collected_reasoning_content = []
@@ -74,6 +74,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             elif delta_type == "text_delta":
                 content = event.delta.text
                 log_llm_stream(content)
+                if stream_callback:
+                    stream_callback(content)
                 collected_content.append(content)
         elif event_type == "message_delta":
             usage.output_tokens = event.usage.output_tokens  # update final output_tokens
diff --git a/metagpt/provider/ark_api.py b/metagpt/provider/ark_api.py
index 0c5704b910..45c5bd9fb9 100644
--- a/metagpt/provider/ark_api.py
+++ b/metagpt/provider/ark_api.py
@@ -71,7 +71,7 @@ def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
         if self.pricing_plan in self.cost_manager.token_costs:
             super()._update_costs(usage, self.pricing_plan, local_calc_usage)
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
             stream=True,
         )
@@ -82,6 +82,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
         async for chunk in response:
             chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
             log_llm_stream(chunk_message)
+            if stream_callback:
+                stream_callback(chunk_message)
             collected_messages.append(chunk_message)
             if chunk.usage:
                 # Volcengine Ark's streaming call returns usage in the last chunk, whose choices list is []
diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index 781963434f..7bbfa91bb1 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -184,6 +184,7 @@ async def aask(
         images: Optional[Union[str, list[str]]] = None,
         timeout=USE_CONFIG_TIMEOUT,
         stream=None,
+        stream_callback=None,
     ) -> str:
         if system_msgs:
             message = self._system_msgs(system_msgs)
@@ -205,7 +206,7 @@ async def aask(
             logger.debug(masked_message)
 
         compressed_message = self.compress_messages(message, compress_type=self.config.compress_type)
-        rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
+        rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout), stream_callback=stream_callback)
         # rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
         return rsp
 
@@ -243,7 +244,7 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         """
 
     @abstractmethod
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         """_achat_completion_stream implemented by inherited class"""
 
     @retry(
@@ -254,11 +255,11 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         retry_error_callback=log_and_reraise,
     )
     async def acompletion_text(
-        self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT
+        self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None
     ) -> str:
         """Asynchronous version of completion. Return str.
         Support stream-print"""
         if stream:
-            return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout))
+            return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout), stream_callback=stream_callback)
         resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
         return self.get_choice_text(resp)
diff --git a/metagpt/provider/bedrock_api.py b/metagpt/provider/bedrock_api.py
index c4d9e21834..cd122850c3 100644
--- a/metagpt/provider/bedrock_api.py
+++ b/metagpt/provider/bedrock_api.py
@@ -119,11 +119,13 @@ async def acompletion(self, messages: list[dict]) -> dict:
     async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self.acompletion(messages)
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         if self.model in NOT_SUPPORT_STREAM_MODELS:
             rsp = await self.acompletion(messages)
             full_text = self.get_choice_text(rsp)
             log_llm_stream(full_text)
+            if stream_callback:
+                stream_callback(full_text)
             return full_text
 
         request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py
index 837377edcd..42797475a1 100644
--- a/metagpt/provider/dashscope_api.py
+++ b/metagpt/provider/dashscope_api.py
@@ -216,7 +216,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> GenerationOutput:
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> GenerationOutput:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
         collected_content = []
         usage = {}
@@ -225,6 +225,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             content = chunk.output.choices[0]["message"]["content"]
             usage = dict(chunk.usage)  # each chunk has usage
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)
             collected_content.append(content)
         log_llm_stream("\n")
         self._update_costs(usage)
diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py
index de534216d9..bbaa84fb48 100644
--- a/metagpt/provider/google_gemini_api.py
+++ b/metagpt/provider/google_gemini_api.py
@@ -136,7 +136,7 @@ async def _achat_completion(
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
             **self._const_kwargs(messages, stream=True)
         )
@@ -148,6 +148,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
                 logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}")
                 raise BlockedPromptException(str(chunk))
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)
             collected_content.append(content)
         log_llm_stream("\n")
diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py
index 9c032d4b33..e372c3d383 100644
--- a/metagpt/provider/human_provider.py
+++ b/metagpt/provider/human_provider.py
@@ -45,10 +45,10 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         """dummy implementation of abstract method in base"""
         return []
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         pass
 
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         """dummy implementation of abstract method in base"""
         return ""
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
index 1663bf2b76..ad89682049 100644
--- a/metagpt/provider/ollama_api.py
+++ b/metagpt/provider/ollama_api.py
@@ -239,7 +239,7 @@ def get_choice_text(self, rsp):
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.ollama_message.api_suffix,
@@ -248,7 +248,7 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             stream=True,
         )
         if isinstance(resp, AsyncGenerator):
-            return await self._processing_openai_response_async_generator(resp)
+            return await self._processing_openai_response_async_generator(resp, stream_callback=stream_callback)
         elif isinstance(resp, OpenAIResponse):
             return self._processing_openai_response(resp)
         else:
@@ -260,7 +260,7 @@ def _processing_openai_response(self, openai_resp: OpenAIResponse):
         self._update_costs(usage)
         return resp
 
-    async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
+    async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None], stream_callback=None):
         collected_content = []
         usage = {}
         async for raw_chunk in ag_openai_resp:
@@ -270,6 +270,8 @@ async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
                 content = self.ollama_message.get_choice(chunk)
                 collected_content.append(content)
                 log_llm_stream(content)
+                if stream_callback:
+                    stream_callback(content)
             else:
                 # stream finished
                 usage = self.get_usage(chunk)
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 877bd71383..6ae70383ac 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -89,7 +89,7 @@ def _get_proxy_params(self) -> dict:
 
         return params
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)), stream=True
         )
@@ -109,6 +109,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
             chunk_message = choice_delta.content or ""  # extract the message
             finish_reason = choice0.finish_reason if hasattr(choice0, "finish_reason") else None
             log_llm_stream(chunk_message)
+            if stream_callback:
+                stream_callback(chunk_message)
             collected_messages.append(chunk_message)
             chunk_has_usage = hasattr(chunk, "usage") and chunk.usage
             if has_finished:
@@ -169,10 +171,10 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
         retry=retry_if_exception_type(APIConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         """when streaming, print each token in place."""
         if stream:
-            return await self._achat_completion_stream(messages, timeout=timeout)
+            return await self._achat_completion_stream(messages, timeout=timeout, stream_callback=stream_callback)
 
         rsp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
         return self.get_choice_text(rsp)
diff --git a/metagpt/provider/openrouter_reasoning.py b/metagpt/provider/openrouter_reasoning.py
index 9a324821a9..8145a328fd 100644
--- a/metagpt/provider/openrouter_reasoning.py
+++ b/metagpt/provider/openrouter_reasoning.py
@@ -56,7 +56,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT):
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         self.headers["Content-Type"] = "text/event-stream"  # update header to adapt the client
         payload = self._const_kwargs(messages, stream=True)
         resp, _, _ = await self.client.arequest(
@@ -75,6 +75,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             elif delta["content"]:
                 collected_content.append(delta["content"])
                 log_llm_stream(delta["content"])
+                if stream_callback:
+                    stream_callback(delta["content"])
 
             usage = chunk.get("usage")
diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py
index d0a95e734f..8e04a49090 100644
--- a/metagpt/provider/qianfan_api.py
+++ b/metagpt/provider/qianfan_api.py
@@ -120,7 +120,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
     async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True), request_timeout=timeout)
         collected_content = []
         usage = {}
@@ -128,6 +128,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             content = chunk.body.get("result", "")
             usage = chunk.body.get("usage", {})
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)
             collected_content.append(content)
         log_llm_stream("\n")
diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py
index 8a38d99c55..bbef46938b 100644
--- a/metagpt/provider/spark_api.py
+++ b/metagpt/provider/spark_api.py
@@ -69,13 +69,15 @@ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         return await self._achat_completion(messages, timeout)
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response = await self.acreate(messages, stream=True)
         collected_content = []
         usage = {}
         async for chunk in response:
             collected_content.append(chunk.content)
             log_llm_stream(chunk.content)
+            if stream_callback:
+                stream_callback(chunk.content)
             if hasattr(chunk, "additional_kwargs"):
                 usage = chunk.additional_kwargs.get("token_usage", {})
diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py
index acac44aaf0..164b025866 100644
--- a/metagpt/provider/zhipuai_api.py
+++ b/metagpt/provider/zhipuai_api.py
@@ -69,7 +69,7 @@ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
         collected_content = []
         usage = {}
@@ -81,6 +81,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
             content = self.get_choice_delta_text(chunk)
             collected_content.append(content)
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)
 
         log_llm_stream("\n")
diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py
index d1a4bd66dc..3efb00eae2 100644
--- a/tests/metagpt/provider/test_base_llm.py
+++ b/tests/metagpt/provider/test_base_llm.py
@@ -36,10 +36,10 @@ async def _achat_completion(self, messages: list[dict], timeout=3):
     async def acompletion(self, messages: list[dict], timeout=3):
         return get_part_chat_completion(name)
 
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3, stream_callback=None) -> str:
        pass
 
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3, stream_callback=None) -> str:
        return default_resp_cont
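
Usage note (not part of the diff): stream_callback is invoked synchronously from inside each provider's streaming loop, so it must not block or await. To push chunks to an async consumer such as a websocket sender, a small asyncio.Queue bridge works. The sketch below is illustrative only — the None sentinel, the question text, and the consumer body are assumptions, not code from this change; only LLM(), aask(), and the stream_callback parameter come from the patched API.

import asyncio

from metagpt.llm import LLM


async def main() -> None:
    llm = LLM()
    queue = asyncio.Queue()

    def stream_callback(content: str) -> None:
        # Called from inside the event loop; put_nowait neither blocks nor awaits.
        queue.put_nowait(content)

    async def consumer() -> None:
        # Drain chunks until the (assumed) end-of-stream sentinel None arrives.
        while (chunk := await queue.get()) is not None:
            print(chunk, end="", flush=True)  # or: await websocket.send(chunk)

    task = asyncio.create_task(consumer())
    await llm.aask("hello world", stream=True, stream_callback=stream_callback)
    queue.put_nowait(None)  # signal end of stream
    await task


if __name__ == "__main__":
    asyncio.run(main())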