Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion astrbot/core/provider/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def get_current_key(self) -> str:

def get_keys(self) -> List[str]:
"""获得提供商 Key"""
return self.provider_config.get("key", [])
keys = self.provider_config.get("key", [""])
return keys if keys else [""]

@abc.abstractmethod
def set_key(self, key: str):
Expand Down
26 changes: 16 additions & 10 deletions astrbot/core/provider/sources/anthropic_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
)

self.chosen_api_key: str = ""
self.api_keys: List = provider_config.get("key", [])
self.api_keys: List = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.base_url = provider_config.get("api_base", "https://api.anthropic.com")
self.timeout = provider_config.get("timeout", 120)
Expand Down Expand Up @@ -70,9 +70,13 @@ def _prepare_payload(self, messages: list[dict]):
{
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"])
if isinstance(tool_call["function"]["arguments"], str)
else tool_call["function"]["arguments"],
"input": (
json.loads(tool_call["function"]["arguments"])
if isinstance(
tool_call["function"]["arguments"], str
)
else tool_call["function"]["arguments"]
),
"id": tool_call["id"],
}
)
Expand Down Expand Up @@ -175,9 +179,9 @@ async def _query_stream(
# 累积 JSON 输入
if "input_json" not in tool_use_buffer[event.index]:
tool_use_buffer[event.index]["input_json"] = ""
tool_use_buffer[event.index]["input_json"] += (
event.delta.partial_json
)
tool_use_buffer[event.index][
"input_json"
] += event.delta.partial_json

elif event.type == "content_block_stop":
# 内容块结束
Expand Down Expand Up @@ -355,9 +359,11 @@ async def assemble_context(self, text: str, image_urls: List[str] = None):
"source": {
"type": "base64",
"media_type": mime_type,
"data": image_data.split("base64,")[1]
if "base64," in image_data
else image_data,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
}
)
Expand Down
38 changes: 21 additions & 17 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import random
from typing import Optional
from typing import Optional, List
from collections.abc import AsyncGenerator

from google import genai
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
provider_settings,
default_persona,
)
self.api_keys: list = provider_config.get("key", [])
self.api_keys: List = super().get_keys()
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else ""
self.timeout: int = int(provider_config.get("timeout", 180))

Expand Down Expand Up @@ -218,19 +218,21 @@ async def _prepare_query_config(
response_modalities=modalities,
tools=tool_list,
safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
thinking_config=(
types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
),
24576,
),
24576,
),
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None,
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None
),
automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True
),
Expand Down Expand Up @@ -274,9 +276,11 @@ def append_or_extend(
if role == "user":
if isinstance(content, list):
parts = [
types.Part.from_text(text=item["text"] or " ")
if item["type"] == "text"
else process_image_url(item["image_url"])
(
types.Part.from_text(text=item["text"] or " ")
if item["type"] == "text"
else process_image_url(item["image_url"])
)
for item in content
]
else:
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
default_persona,
)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.api_keys: List = super().get_keys()
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
Expand Down
Loading