16 changes: 14 additions & 2 deletions examples/hello_world.py
@@ -13,7 +13,13 @@

 async def ask_and_print(question: str, llm: LLM, system_prompt) -> str:
     logger.info(f"Q: {question}")
-    rsp = await llm.aask(question, system_msgs=[system_prompt], stream=True)
+
+    context = "agent name or else"  # any label identifying the streaming source
+    def stream_callback(content) -> None:
+        print(context, content)
+        # e.g., send the content over a websocket, or handle it some other way
+
+    rsp = await llm.aask(question, system_msgs=[system_prompt], stream=True, stream_callback=stream_callback)
     if hasattr(llm, "reasoning_content") and llm.reasoning_content:
         logger.info(f"A reasoning: {llm.reasoning_content}")
     logger.info(f"A: {rsp}")
@@ -29,7 +35,13 @@ async def lowlevel_api_example(llm: LLM):
     logger.info(await llm.acompletion_text(hello_msg))

     # streaming mode, much slower
-    await llm.acompletion_text(hello_msg, stream=True)
+
+    context = "agent name or else"  # any label identifying the streaming source
+    def stream_callback(content) -> None:
+        print(context, content)
+        # e.g., send the content over a websocket, or handle it some other way
+
+    await llm.acompletion_text(hello_msg, stream=True, stream_callback=stream_callback)

     # check completion if exist to test llm complete functions
     if hasattr(llm, "completion"):
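Since the callback is invoked synchronously from inside the provider's streaming loop, forwarding chunks to an async consumer such as a websocket takes a small bridge. A minimal sketch using an `asyncio.Queue` (the question text, the `None` sentinel, and the `print` consumer are illustrative placeholders; a real consumer might `await websocket.send(chunk)` instead):

```python
import asyncio

from metagpt.llm import LLM


async def main() -> None:
    llm = LLM()
    queue: asyncio.Queue = asyncio.Queue()

    def stream_callback(content: str) -> None:
        # Runs synchronously inside the provider's streaming loop,
        # so keep it non-blocking: just enqueue the chunk.
        queue.put_nowait(content)

    async def consumer() -> None:
        # Drain chunks as they arrive on the other side of the bridge.
        while True:
            chunk = await queue.get()
            if chunk is None:  # sentinel: stream finished
                break
            print(chunk, end="", flush=True)  # or: await websocket.send(chunk)

    task = asyncio.create_task(consumer())
    await llm.aask("what's your name?", stream=True, stream_callback=stream_callback)
    queue.put_nowait(None)  # signal end of stream
    await task


if __name__ == "__main__":
    asyncio.run(main())
```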
4 changes: 3 additions & 1 deletion metagpt/provider/anthropic_api.py
@@ -57,7 +57,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
     async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
         collected_content = []
         collected_reasoning_content = []
@@ -74,6 +74,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
                 elif delta_type == "text_delta":
                     content = event.delta.text
                     log_llm_stream(content)
+                    if stream_callback:
+                        stream_callback(content)
                     collected_content.append(content)
             elif event_type == "message_delta":
                 usage.output_tokens = event.usage.output_tokens  # update final output_tokens
4 changes: 3 additions & 1 deletion metagpt/provider/ark_api.py
@@ -71,7 +71,7 @@ def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
         if self.pricing_plan in self.cost_manager.token_costs:
             super()._update_costs(usage, self.pricing_plan, local_calc_usage)

-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)),
             stream=True,
@@ -82,6 +82,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
         async for chunk in response:
             chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else ""  # extract the message
             log_llm_stream(chunk_message)
+            if stream_callback:
+                stream_callback(chunk_message)
             collected_messages.append(chunk_message)
             if chunk.usage:
                 # Volcengine Ark's streaming call returns usage in the final chunk, whose choices list is []
9 changes: 5 additions & 4 deletions metagpt/provider/base_llm.py
@@ -184,6 +184,7 @@ async def aask(
         images: Optional[Union[str, list[str]]] = None,
         timeout=USE_CONFIG_TIMEOUT,
         stream=None,
+        stream_callback=None,
     ) -> str:
         if system_msgs:
             message = self._system_msgs(system_msgs)
@@ -205,7 +206,7 @@
         logger.debug(masked_message)

         compressed_message = self.compress_messages(message, compress_type=self.config.compress_type)
-        rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout))
+        rsp = await self.acompletion_text(compressed_message, stream=stream, timeout=self.get_timeout(timeout), stream_callback=stream_callback)
         # rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
         return rsp

@@ -243,7 +244,7 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         """

     @abstractmethod
-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         """_achat_completion_stream implemented by inherited class"""

     @retry(
@@ -254,11 +255,11 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         retry_error_callback=log_and_reraise,
     )
     async def acompletion_text(
-        self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT
+        self, messages: list[dict], stream: bool = False, timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None
     ) -> str:
         """Asynchronous version of completion. Return str. Support stream-print"""
         if stream:
-            return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout))
+            return await self._achat_completion_stream(messages, timeout=self.get_timeout(timeout), stream_callback=stream_callback)
         resp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
         return self.get_choice_text(resp)
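With the plumbing above, `aask` forwards `stream_callback` through `acompletion_text` into the provider-specific `_achat_completion_stream`, so any callable that accepts a single string works, including a stateful object. A sketch of a callback that collects chunks and records time-to-first-token (it assumes `llm` is any configured provider instance):

```python
import time
from typing import Optional


class StreamStats:
    """Collects streamed chunks and records time-to-first-token."""

    def __init__(self) -> None:
        self.chunks: list[str] = []
        self.started = time.perf_counter()
        self.first_token_at: Optional[float] = None

    def __call__(self, content: str) -> None:
        # First non-empty chunk marks time-to-first-token.
        if self.first_token_at is None and content:
            self.first_token_at = time.perf_counter() - self.started
        self.chunks.append(content)


# usage, assuming `llm` is a configured provider instance:
#   stats = StreamStats()
#   rsp = await llm.aask("hello", stream=True, stream_callback=stats)
#   print(f"first token after {stats.first_token_at:.2f}s across {len(stats.chunks)} chunks")
```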
4 changes: 3 additions & 1 deletion metagpt/provider/bedrock_api.py
@@ -119,11 +119,13 @@ async def acompletion(self, messages: list[dict]) -> dict:
     async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self.acompletion(messages)

-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         if self.model in NOT_SUPPORT_STREAM_MODELS:
             rsp = await self.acompletion(messages)
             full_text = self.get_choice_text(rsp)
             log_llm_stream(full_text)
+            if stream_callback:
+                stream_callback(full_text)
             return full_text

         request_body = self.__provider.get_request_body(messages, self._const_kwargs, stream=True)
4 changes: 3 additions & 1 deletion metagpt/provider/dashscope_api.py
@@ -216,7 +216,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> GenerationOutput:
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> GenerationOutput:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
         collected_content = []
         usage = {}
@@ -225,6 +225,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             content = chunk.output.choices[0]["message"]["content"]
             usage = dict(chunk.usage)  # each chunk has usage
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)
             collected_content.append(content)
         log_llm_stream("\n")
         self._update_costs(usage)
4 changes: 3 additions & 1 deletion metagpt/provider/google_gemini_api.py
@@ -136,7 +136,7 @@ async def _achat_completion(
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
             **self._const_kwargs(messages, stream=True)
         )
@@ -148,6 +148,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
                 logger.warning(f"messages: {messages}\nerrors: {e}\n{BlockedPromptException(str(chunk))}")
                 raise BlockedPromptException(str(chunk))
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)
             collected_content.append(content)
         log_llm_stream("\n")
4 changes: 2 additions & 2 deletions metagpt/provider/human_provider.py
@@ -45,10 +45,10 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         """dummy implementation of abstract method in base"""
         return []

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         pass

-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         """dummy implementation of abstract method in base"""
         return ""
8 changes: 5 additions & 3 deletions metagpt/provider/ollama_api.py
@@ -239,7 +239,7 @@ def get_choice_text(self, rsp):
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.ollama_message.api_suffix,
@@ -248,7 +248,7 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             stream=True,
         )
         if isinstance(resp, AsyncGenerator):
-            return await self._processing_openai_response_async_generator(resp)
+            return await self._processing_openai_response_async_generator(resp, stream_callback=stream_callback)
         elif isinstance(resp, OpenAIResponse):
             return self._processing_openai_response(resp)
         else:
@@ -260,7 +260,7 @@ def _processing_openai_response(self, openai_resp: OpenAIResponse):
         self._update_costs(usage)
         return resp

-    async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
+    async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None], stream_callback=None):
         collected_content = []
         usage = {}
         async for raw_chunk in ag_openai_resp:
@@ -270,6 +270,8 @@ async def _processing_openai_response_async_generator(self, ag_openai_resp: AsyncGenerator[OpenAIResponse, None]):
                 content = self.ollama_message.get_choice(chunk)
                 collected_content.append(content)
                 log_llm_stream(content)
+                if stream_callback:
+                    stream_callback(content)
             else:
                 # stream finished
                 usage = self.get_usage(chunk)
8 changes: 5 additions & 3 deletions metagpt/provider/openai_api.py
@@ -89,7 +89,7 @@ def _get_proxy_params(self) -> dict:

         return params

-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=self.get_timeout(timeout)), stream=True
         )
@@ -109,6 +109,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
             chunk_message = choice_delta.content or ""  # extract the message
             finish_reason = choice0.finish_reason if hasattr(choice0, "finish_reason") else None
             log_llm_stream(chunk_message)
+            if stream_callback:
+                stream_callback(chunk_message)
             collected_messages.append(chunk_message)
             chunk_has_usage = hasattr(chunk, "usage") and chunk.usage
             if has_finished:
@@ -169,10 +171,10 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> ChatCompletion:
         retry=retry_if_exception_type(APIConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         """when streaming, print each token in place."""
         if stream:
-            return await self._achat_completion_stream(messages, timeout=timeout)
+            return await self._achat_completion_stream(messages, timeout=timeout, stream_callback=stream_callback)

         rsp = await self._achat_completion(messages, timeout=self.get_timeout(timeout))
         return self.get_choice_text(rsp)
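OpenAI-style streams can emit deltas many times per second, and the callback fires on every one of them. Batching chunks before pushing them downstream keeps a UI or websocket from being flooded; one possible pattern (a sketch, not part of this PR):

```python
class LineBuffer:
    """Batches streamed deltas and emits them line by line."""

    def __init__(self, emit) -> None:
        self.emit = emit  # any callable taking a str, e.g. a UI updater
        self.buf = ""

    def __call__(self, content: str) -> None:
        # Accumulate deltas; emit only when a full line is available.
        self.buf += content
        while "\n" in self.buf:
            line, self.buf = self.buf.split("\n", 1)
            self.emit(line)

    def flush(self) -> None:
        # Emit whatever remains after the stream ends.
        if self.buf:
            self.emit(self.buf)
            self.buf = ""


# usage, assuming `llm` is a configured provider instance:
#   cb = LineBuffer(print)
#   rsp = await llm.acompletion_text(messages, stream=True, stream_callback=cb)
#   cb.flush()
```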
4 changes: 3 additions & 1 deletion metagpt/provider/openrouter_reasoning.py
@@ -56,7 +56,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         self.headers["Content-Type"] = "text/event-stream"  # update header to adapt the client
         payload = self._const_kwargs(messages, stream=True)
         resp, _, _ = await self.client.arequest(
@@ -75,6 +75,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
             elif delta["content"]:
                 collected_content.append(delta["content"])
                 log_llm_stream(delta["content"])
+                if stream_callback:
+                    stream_callback(delta["content"])

         usage = chunk.get("usage")
4 changes: 3 additions & 1 deletion metagpt/provider/qianfan_api.py
@@ -120,14 +120,16 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
     async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> JsonBody:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True), request_timeout=timeout)
         collected_content = []
         usage = {}
         async for chunk in resp:
             content = chunk.body.get("result", "")
             usage = chunk.body.get("usage", {})
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)
             collected_content.append(content)
         log_llm_stream("\n")
4 changes: 3 additions & 1 deletion metagpt/provider/spark_api.py
@@ -69,13 +69,15 @@ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
         return await self._achat_completion(messages, timeout)

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response = await self.acreate(messages, stream=True)
         collected_content = []
         usage = {}
         async for chunk in response:
             collected_content.append(chunk.content)
             log_llm_stream(chunk.content)
+            if stream_callback:
+                stream_callback(chunk.content)
             if hasattr(chunk, "additional_kwargs"):
                 usage = chunk.additional_kwargs.get("token_usage", {})
4 changes: 3 additions & 1 deletion metagpt/provider/zhipuai_api.py
@@ -69,7 +69,7 @@ async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
     async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

-    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, stream_callback=None) -> str:
         response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
         collected_content = []
         usage = {}
@@ -81,6 +81,8 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> str:
             content = self.get_choice_delta_text(chunk)
             collected_content.append(content)
             log_llm_stream(content)
+            if stream_callback:
+                stream_callback(content)

         log_llm_stream("\n")
4 changes: 2 additions & 2 deletions tests/metagpt/provider/test_base_llm.py
@@ -36,10 +36,10 @@ async def _achat_completion(self, messages: list[dict], timeout=3):
     async def acompletion(self, messages: list[dict], timeout=3):
         return get_part_chat_completion(name)

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3, stream_callback=None) -> str:
         pass

-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3, stream_callback=None) -> str:
         return default_resp_cont
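The mock above stubs the streaming path out entirely, so the callback never fires in this test module. A test that actually exercises the callback plumbing could look like the sketch below; it is hypothetical, subclassing this module's `MockBaseLLM` and assuming it can be constructed the same way as in the existing tests:

```python
import pytest


class StreamingMockLLM(MockBaseLLM):
    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3, stream_callback=None) -> str:
        collected = []
        for chunk in ["he", "llo"]:  # stand-in for real provider chunks
            if stream_callback:
                stream_callback(chunk)  # fire once per chunk, mirroring the providers
            collected.append(chunk)
        return "".join(collected)

    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3, stream_callback=None) -> str:
        if stream:
            return await self._achat_completion_stream(messages, timeout, stream_callback=stream_callback)
        return default_resp_cont


@pytest.mark.asyncio
async def test_stream_callback_receives_chunks():
    llm = StreamingMockLLM()
    received = []
    rsp = await llm.acompletion_text([], stream=True, stream_callback=received.append)
    assert rsp == "hello"
    assert received == ["he", "llo"]
```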