Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 45 additions & 94 deletions backend/app/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import inspect
import logging
from datetime import datetime
from typing import Annotated, Any, AsyncIterator, Final, Literal
Expand All @@ -11,9 +10,9 @@
from agents import Agent, RunContextWrapper, Runner, function_tool
from chatkit.agents import (
AgentContext,
ClientToolCall,
ThreadItemConverter,
ClientToolCall,
stream_agent_response,
simple_to_agent_input,
)
from chatkit.server import ChatKitServer, ThreadItemDoneEvent
from chatkit.types import (
Expand All @@ -24,6 +23,8 @@
ThreadMetadata,
ThreadStreamEvent,
UserMessageItem,
AssistantMessageItem,
WidgetItem,
)
from openai.types.responses import ResponseInputContentParam
from pydantic import ConfigDict, Field
Expand Down Expand Up @@ -210,8 +211,7 @@ def __init__(self) -> None:
name="ChatKit Guide",
instructions=INSTRUCTIONS,
tools=tools, # type: ignore[arg-type]
)
self._thread_item_converter = self._init_thread_item_converter()
)

async def respond(
self,
Expand All @@ -232,7 +232,7 @@ async def respond(
if target_item is None or _is_tool_completion_item(target_item):
return

agent_input = await self._to_agent_input(thread, target_item)
agent_input = await self._to_agent_input(thread, target_item, context)
if agent_input is None:
return

Expand All @@ -244,29 +244,8 @@ async def respond(

async for event in stream_agent_response(agent_context, result):
yield event
return

async def to_message_content(self, _input: Attachment) -> ResponseInputContentParam:
raise RuntimeError("File attachments are not supported in this demo.")

def _init_thread_item_converter(self) -> Any | None:
converter_cls = ThreadItemConverter
if converter_cls is None or not callable(converter_cls):
return None

attempts: tuple[dict[str, Any], ...] = (
{"to_message_content": self.to_message_content},
{"message_content_converter": self.to_message_content},
{},
)

for kwargs in attempts:
try:
return converter_cls(**kwargs)
except TypeError:
continue
return None

return

async def _latest_thread_item(
self, thread: ThreadMetadata, context: dict[str, Any]
) -> ThreadItem | None:
Expand All @@ -281,74 +260,46 @@ async def _to_agent_input(
self,
thread: ThreadMetadata,
item: ThreadItem,
context: dict[str, Any],
) -> Any | None:
if _is_tool_completion_item(item):
return None

converter = getattr(self, "_thread_item_converter", None)
if converter is not None:
for attr in (
"to_input_item",
"convert",
"convert_item",
"convert_thread_item",
):
method = getattr(converter, attr, None)
if method is None:
continue
call_args: list[Any] = [item]
call_kwargs: dict[str, Any] = {}
try:
signature = inspect.signature(method)
except (TypeError, ValueError):
signature = None

if signature is not None:
params = [
parameter
for parameter in signature.parameters.values()
if parameter.kind
not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
)
]
if len(params) >= 2:
next_param = params[1]
if next_param.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
):
call_args.append(thread)
else:
call_kwargs[next_param.name] = thread

result = method(*call_args, **call_kwargs)
if inspect.isawaitable(result):
return await result
return result

if isinstance(item, UserMessageItem):
return _user_message_text(item)
# converter = getattr(self, "_thread_item_converter", None)
history: list[ThreadItem] = []
try:
loaded = await self.store.load_thread_items(
thread.id,
after=None,
limit=50,
order="desc",
context=context,
)
history = list(reversed(loaded.data))
except Exception: # noqa: BLE001
history = []

latest_id = getattr(item, "id", None)
if latest_id is None or not any(
getattr(existing, "id", None) == latest_id for existing in history
):
history.append(item)

relevant: list[ThreadItem] = [
entry
for entry in history
if isinstance(
entry,
(
UserMessageItem,
AssistantMessageItem,
ClientToolCallItem,
WidgetItem,
),
)
]

return None
if len(relevant) > 12:
relevant = relevant[-12:]

async def _add_hidden_item(
self,
thread: ThreadMetadata,
context: dict[str, Any],
content: str,
) -> None:
await self.store.add_thread_item(
thread.id,
HiddenContextItem(
id=_gen_id("msg"),
thread_id=thread.id,
created_at=datetime.now(),
content=content,
),
context,
)
return await simple_to_agent_input(relevant)


def create_chatkit_server() -> FactAssistantServer | None:
Expand Down