Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions astrbot/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,33 +52,51 @@ def port_checker(port: int, host: str = "localhost"):
return False


def save_temp_img(img: Union[Image.Image, str]) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
# 获得文件创建时间,清除超过 12 小时的
def save_temp_img(img: Union[Image.Image, bytes], save_name: str | None = None) -> str:
"""
保存临时图片:
- 自动清理超过 12 小时的临时文件
- 如果提供了 save_name(含扩展名),直接用作文件名;否则按规则自动生成
- 根据图片模式自动选择保存格式(RGBA -> PNG,其余 -> JPG)
"""
temp_dir = Path(get_astrbot_data_path()) / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)

# 清理超过 12 小时的旧文件
now = time.time()
try:
for f in os.listdir(temp_dir):
path = os.path.join(temp_dir, f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600 * 12:
os.remove(path)
for f in temp_dir.iterdir():
if f.is_file() and now - f.stat().st_ctime > 3600 * 12:
f.unlink(missing_ok=True)
except Exception as e:
print(f"清除临时文件失败: {e}")

# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = os.path.join(temp_dir, f"{timestamp}.jpg")
# 决定文件名
if save_name: # 外部指定了名字
file_name = save_name
path = temp_dir / file_name
else: # 自动生成
timestamp = f"{int(now)}_{uuid.uuid4().hex[:8]}"
if isinstance(img, Image.Image) and img.mode in ("RGBA", "LA"):
file_name = f"{timestamp}.png"
else:
file_name = f"{timestamp}.jpg"
path = temp_dir / file_name

# 保存文件
if isinstance(img, Image.Image):
img.save(p)
else:
with open(p, "wb") as f:
f.write(img)
return p
if path.suffix.lower() == ".png" or img.mode in ("RGBA", "LA"):
img.save(path, format="PNG")
else:
img.convert("RGB").save(path, format="JPEG", quality=95)
else: # bytes
path.write_bytes(img)

return str(path)


async def download_image_by_url(
url: str, post: bool = False, post_data: dict = None, path=None
url: str, post: bool = False, post_data: dict = None, path=None, save_name=None
) -> str:
"""
下载图片, 返回 path
Expand All @@ -94,15 +112,15 @@ async def download_image_by_url(
if post:
async with session.post(url, json=post_data) as resp:
if not path:
return save_temp_img(await resp.read())
return save_temp_img(await resp.read(), save_name)
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
if not path:
return save_temp_img(await resp.read())
return save_temp_img(await resp.read(), save_name)
else:
with open(path, "wb") as f:
f.write(await resp.read())
Expand All @@ -114,14 +132,13 @@ async def download_image_by_url(
async with aiohttp.ClientSession() as session:
if post:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
return save_temp_img(await resp.read(), save_name)
else:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
return save_temp_img(await resp.read(), save_name)
except Exception as e:
raise e


async def download_file(url: str, path: str, show_progress: bool = False):
"""
从指定 url 下载文件到指定路径 path
Expand Down
Loading
Loading