From 3165e72f4e2adab2b0ce05c4e78dd97e4e178ff5 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 20 Aug 2025 15:25:07 -0700 Subject: [PATCH 01/16] add 'litellm_utils' module --- .../jupyter_ai/litellm_utils/__init__.py | 2 + .../litellm_utils/test_toolcall_list.py | 52 ++++++++ .../jupyter_ai/litellm_utils/toolcall_list.py | 121 ++++++++++++++++++ .../litellm_utils/toolcall_types.py | 57 +++++++++ 4 files changed, 232 insertions(+) create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py new file mode 100644 index 000000000..cd95e2b2d --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py @@ -0,0 +1,2 @@ +from .toolcall_list import ToolCallList +from .toolcall_types import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py new file mode 100644 index 000000000..9069eb481 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py @@ -0,0 +1,52 @@ +from litellm.utils import ChatCompletionDeltaToolCall, Function +from .toolcall_list import ToolCallList + +class TestToolCallList(): + + def test_single_tool_stream(self): + """ + Asserts this class works against a sample response from Claude running a + single tool. + """ + # Setup test + ID = "toolu_01TzXi4nFJErYThcdhnixn7e" + toolcall_list = ToolCallList() + toolcall_list += [ChatCompletionDeltaToolCall(id=ID, function=Function(arguments='', name='ls'), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='": "."}', name=None), type='function', index=0)] + + # Verify the resolved list of calls + resolved_toolcalls = toolcall_list.resolve() + assert len(resolved_toolcalls) == 1 + assert resolved_toolcalls[0] + + def test_two_tool_stream(self): + """ + Asserts this class works against a sample response from Claude running a + two tools in parallel. + """ + # Setup test + ID_0 = 'toolu_0141FrNfT2LJg6odqbrdmLM6' + ID_1 = 'toolu_01DKqnaXVcyp1v1ABxhHC5Sg' + toolcall_list = ToolCallList() + toolcall_list += [ChatCompletionDeltaToolCall(id=ID_0, function=Function(arguments='', name='ls'), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"path": ', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='"."}', name=None), type='function', index=0)] + toolcall_list += [ChatCompletionDeltaToolCall(id=ID_1, function=Function(arguments='', name='bash'), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='', name=None), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='{"com', name=None), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='mand": "ech', name=None), type='function', index=1)] + toolcall_list += [ChatCompletionDeltaToolCall(id=None, function=Function(arguments='o \'hello\'"}', name=None), type='function', index=1)] + + # Verify the resolved list of calls + resolved_toolcalls = toolcall_list.resolve() + assert len(resolved_toolcalls) == 2 + assert resolved_toolcalls[0].id == ID_0 + assert resolved_toolcalls[0].function.name == "ls" + assert resolved_toolcalls[0].function.arguments == { "path": "." } + assert resolved_toolcalls[1].id == ID_1 + assert resolved_toolcalls[1].function.name == "bash" + assert resolved_toolcalls[1].function.arguments == { "command": "echo \'hello\'" } + diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py new file mode 100644 index 000000000..1e3effd3a --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py @@ -0,0 +1,121 @@ +from litellm.utils import ChatCompletionDeltaToolCall, Function +import json + +from .toolcall_types import ResolvedToolCall, ResolvedFunction + +class ToolCallList(): + """ + A helper object that defines a custom `__iadd__()` method which accepts a + `tool_call_deltas: list[ChatCompletionDeltaToolCall]` argument. This class + is used to aggregate the tool call deltas yielded from a LiteLLM response + stream and produce a list of tool calls. + + After all tool call deltas are added, the `process()` method may be called + to return a list of resolved tool calls. + + Example usage: + + ```py + tool_call_list = ToolCallList() + reply_stream = await litellm.acompletion(..., stream=True) + + async for chunk in reply_stream: + tool_call_delta = chunk.choices[0].delta.tool_calls + tool_call_list += tool_call_delta + + tool_call_list.resolve() + ``` + """ + + _aggregate: list[ChatCompletionDeltaToolCall] + + def __init__(self): + self.size = None + + # Initialize `_aggregate` + self._aggregate = [] + + + def __iadd__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList': + """ + Adds a list of tool call deltas to this instance. + + NOTE: This assumes the 'index' attribute on each entry in this list to + be accurate. If this assumption doesn't hold, we will need to rework the + logic here. + """ + if other is None: + return self + + # Iterate through each delta + for delta in other: + # Ensure `self._aggregate` is at least of size `delta.index + 1` + for i in range(len(self._aggregate), delta.index + 1): + self._aggregate.append(ChatCompletionDeltaToolCall( + function=Function(arguments=""), + index=i, + )) + + # Find the corresponding target in the `self._aggregate` and add the + # delta on top of it. In most cases, the value of aggregate + # attribute is set as soon as any delta sets it to a non-`None` + # value. However, `delta.function.arguments` is a string that should + # be appended to the aggregate value of that attribute. + target = self._aggregate[delta.index] + if delta.type: + target.type = delta.type + if delta.id: + target.id = delta.id + if delta.function.name: + target.function.name = delta.function.name + if delta.function.arguments: + target.function.arguments += delta.function.arguments + + return self + + + def __add__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList': + """ + Alias for `__iadd__()`. + """ + return self.__iadd__(other) + + + def resolve(self) -> list[ResolvedToolCall]: + """ + Resolve the aggregated tool call delta lists into a list of tool calls. + """ + resolved_toolcalls: list[ResolvedToolCall] = [] + for i, raw_toolcall in enumerate(self._aggregate): + # Verify entries are at the correct index in the aggregated list + assert raw_toolcall.index == i + + # Verify each tool call specifies the name of the tool to run. + # + # TODO: Check if this may cause a runtime error. The docstring on + # `litellm.utils.Function` implies that `name` may be `None`. + assert raw_toolcall.function.name + + # Verify each tool call defines the type of tool it is calling. + assert raw_toolcall.type is not None + + # Parse the function argument string into a dictionary + resolved_fn_args = json.loads(raw_toolcall.function.arguments) + + # Add to the returned list + resolved_fn = ResolvedFunction( + name=raw_toolcall.function.name, + arguments=resolved_fn_args + ) + resolved_toolcall = ResolvedToolCall( + id=raw_toolcall.id, + type=raw_toolcall.type, + index=i, + function=resolved_fn + ) + resolved_toolcalls.append(resolved_toolcall) + + return resolved_toolcalls + + + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py new file mode 100644 index 000000000..9426439f0 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py @@ -0,0 +1,57 @@ +from __future__ import annotations +from pydantic import BaseModel +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Any + +class ResolvedFunction(BaseModel): + """ + A type-safe, parsed representation of `litellm.utils.Function`. + """ + + name: str + """ + Name of the tool function to be called. + + TODO: Check if this attribute is defined for non-function tools, e.g. tools + provided by a MCP server. The docstring on `litellm.utils.Function` implies + that `name` may be `None`. + """ + + arguments: dict + """ + Arguments to the tool function, as a dictionary. + """ + +class ResolvedToolCall(BaseModel): + """ + A type-safe, parsed representation of + `litellm.utils.ChatCompletionDeltaToolCall`. + """ + + id: str | None + """ + The ID of the tool call. This should always be provided by LiteLLM, this + type is left optional as we do not use this attribute. + """ + + type: str + """ + The 'type' of tool call. Usually 'function'. + + TODO: Make this a union of string literals to ensure we are handling every + potential type of tool call. + """ + + function: ResolvedFunction + """ + The resolved function. See `ResolvedFunction` for more info. + """ + + index: int + """ + The index of this tool call. + + This is usually 0 unless the LLM supports parallel tool calling. + """ From d504b348c9ab641a0280b9dca1c70793c28e0f3d Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Sat, 23 Aug 2025 11:59:27 -0700 Subject: [PATCH 02/16] WIP: first working copy of Jupyternaut as an agent --- .../jupyter_ai/personas/base_persona.py | 211 +++++++++++++++--- .../personas/jupyternaut/jupyternaut.py | 53 +++-- .../jupyter-ai/jupyter_ai/tools/models.py | 15 +- 3 files changed, 222 insertions(+), 57 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 310901b1c..b74c778a7 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -1,29 +1,39 @@ +from __future__ import annotations import asyncio import os from abc import ABC, ABCMeta, abstractmethod from dataclasses import asdict from logging import Logger from time import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Tuple from jupyter_ai.config_manager import ConfigManager from jupyterlab_chat.models import Message, NewMessage, User from jupyterlab_chat.ychat import YChat +from litellm import ModelResponseStream, supports_function_calling +from litellm.utils import function_to_dict from pydantic import BaseModel from traitlets import MetaHasTraits from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness +from ..litellm_utils import ToolCallList, ResolvedToolCall + +# Import toolkits +from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit +from jupyter_ai_tools.toolkits.code_execution import toolkit as codeexec_toolkit +from jupyter_ai_tools.toolkits.git import toolkit as git_toolkit -# prevents a circular import -# types imported under this block have to be surrounded in single quotes on use if TYPE_CHECKING: from collections.abc import AsyncIterator - - from litellm import ModelResponseStream - from .persona_manager import PersonaManager + from ..tools import Toolkit +DEFAULT_TOOLKITS: dict[str, Toolkit] = { + "fs": fs_toolkit, + "codeexec": codeexec_toolkit, + "git": git_toolkit, +} class PersonaDefaults(BaseModel): """ @@ -237,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]: async def stream_message( self, reply_stream: "AsyncIterator[ModelResponseStream | str]" - ) -> None: + ) -> Tuple[ResolvedToolCall, ToolCallList]: """ Takes an async iterator, dubbed the 'reply stream', and streams it to a new message by this persona in the YChat. The async iterator may yield @@ -247,21 +257,36 @@ async def stream_message( stream, then continuously updates it until the stream is closed. - Automatically manages its awareness state to show writing status. + + Returns a list of `ResolvedToolCall` objects. If this list is not empty, + the persona should run these tools. """ stream_id: Optional[str] = None stream_interrupted = False try: self.awareness.set_local_state_field("isWriting", True) - async for chunk in reply_stream: - # Coerce LiteLLM stream chunk to a string delta - if not isinstance(chunk, str): - chunk = chunk.choices[0].delta.content + toolcall_list = ToolCallList() + resolved_toolcalls: list[ResolvedToolCall] = [] - # LiteLLM streams always terminate with an empty chunk, so we - # ignore and continue when this occurs. - if not chunk: + async for chunk in reply_stream: + # Compute `content_delta` and `tool_calls_delta` based on the + # type of object yielded by `reply_stream`. + if isinstance(chunk, ModelResponseStream): + delta = chunk.choices[0].delta + content_delta = delta.content + toolcalls_delta = delta.tool_calls + elif isinstance(chunk, str): + content_delta = chunk + toolcalls_delta = None + else: + raise Exception(f"Unrecognized type in stream_message(): {type(chunk)}") + + # LiteLLM streams always terminate with an empty chunk, so + # continue in this case. + if not (content_delta or toolcalls_delta): continue + # Terminate the stream if the user requested it. if ( stream_id and stream_id in self.message_interrupted.keys() @@ -280,34 +305,46 @@ async def stream_message( stream_interrupted = True break - if not stream_id: - stream_id = self.ychat.add_message( - NewMessage(body="", sender=self.id) + # Append `content_delta` to the existing message. + if content_delta: + # Start the stream with an empty message on the initial reply. + # Bind the new message ID to `stream_id`. + if not stream_id: + stream_id = self.ychat.add_message( + NewMessage(body="", sender=self.id) + ) + self.message_interrupted[stream_id] = asyncio.Event() + self.awareness.set_local_state_field("isWriting", stream_id) + assert stream_id + + self.ychat.update_message( + Message( + id=stream_id, + body=content_delta, + time=time(), + sender=self.id, + raw_time=False, + ), + append=True, ) - self.message_interrupted[stream_id] = asyncio.Event() - self.awareness.set_local_state_field("isWriting", stream_id) - - assert stream_id - self.ychat.update_message( - Message( - id=stream_id, - body=chunk, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) + if toolcalls_delta: + toolcall_list += toolcalls_delta + + # After the reply stream is complete, resolve the list of tool calls. + resolved_toolcalls = toolcall_list.resolve() except Exception as e: self.log.error( f"Persona '{self.name}' encountered an exception printed below when attempting to stream output." ) self.log.exception(e) finally: + # Reset local state self.awareness.set_local_state_field("isWriting", False) - if stream_id: - # if stream was interrupted, add a tombstone - if stream_interrupted: + self.message_interrupted.pop(stream_id, None) + + # If stream was interrupted, add a tombstone and return `[]`, + # indicating that no tools should be run afterwards. + if stream_id and stream_interrupted: stream_tombstone = "\n\n(AI response stopped by user)" self.ychat.update_message( Message( @@ -319,8 +356,15 @@ async def stream_message( ), append=True, ) - if stream_id in self.message_interrupted.keys(): - del self.message_interrupted[stream_id] + return None + + # Otherwise return the resolved list. + if len(resolved_toolcalls): + count = len(resolved_toolcalls) + names = sorted([tc.function.name for tc in resolved_toolcalls]) + self.log.info(f"AI response triggered {count} tool calls: {names}") + return resolved_toolcalls, toolcall_list + def send_message(self, body: str) -> None: """ @@ -361,7 +405,7 @@ def get_mcp_config(self) -> dict[str, Any]: Returns the MCP config for the current chat. """ return self.parent.get_mcp_config() - + def process_attachments(self, message: Message) -> Optional[str]: """ Process file attachments in the message and return their content as a string. @@ -431,6 +475,99 @@ def resolve_attachment_to_path(self, attachment_id: str) -> Optional[str]: self.log.error(f"Failed to resolve attachment {attachment_id}: {e}") return None + def get_tools(self, model_id: str) -> list[dict]: + """ + Returns the `tools` parameter which should be passed to + `litellm.acompletion()` for a given LiteLLM model ID. + + If the model does not support tool-calling, this method returns an empty + list. Otherwise, it returns the list of tools available in the current + environment. These may include: + + - The default set of tool functions in Jupyter AI, defined in the + `jupyter_ai_tools` package. + + - (TODO) Tools provided by MCP server configuration, if any. + + - (TODO) Web search. + + - (TODO) File search using vector store IDs. + + TODO: cache this + + TODO: Implement some permissions system so users can control what tools + are allowable. + + NOTE: The returned list is expected by LiteLLM to conform to the `tools` + parameter defintiion defined by the OpenAI API: + https://platform.openai.com/docs/guides/tools#available-tools + + NOTE: This API is a WIP and is very likely to change. + """ + # Return early if the model does not support tool calling + if not supports_function_calling(model=model_id): + return [] + + tool_descriptions = [] + + # Get all tools from `jupyter_ai_tools` and store their object descriptions + for toolkit_name, toolkit in DEFAULT_TOOLKITS.items(): + # TODO: make these tool permissions configurable. + for tool in toolkit.get_tools(): + # Here, we are using a util function from LiteLLM to coerce + # each `Tool` struct into a tool description dictionary expected + # by LiteLLM. + desc = { + "type": "function", + "function": function_to_dict(tool.callable), + } + + # Prepend the toolkit name to each function name, hopefully + # ensuring every tool function has a unique name. + # e.g. 'git_add' => 'git__git_add' + # + # TODO: Actually ensure this instead of hoping. + desc['function']['name'] = f"{toolkit_name}__{desc['function']['name']}" + tool_descriptions.append(desc) + + # Finally, return the tool descriptions + return tool_descriptions + + + async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]: + """ + Runs the tools specified in the list of tool calls returned by + `self.stream_message()`. Returns a list of dictionaries + `toolcall_outputs: list[dict]`, which should be appended directly to the + message history on the next invocation of the LLM. + """ + if not len(tools): + return [] + + tool_outputs: list[dict] = [] + for tool_call in tools: + # Get tool definition from the correct toolkit + toolkit_name, tool_name = tool_call.function.name.split("__") + assert toolkit_name in DEFAULT_TOOLKITS + tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name) + + # Run tool and store its output + output = await tool_defn.callable(**tool_call.function.arguments) + + # Store the tool output in a dictionary accepted by LiteLLM + output_dict = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": output, + } + tool_outputs.append(output_dict) + + self.log.info(f"Ran {len(tools)} tool functions.") + return tool_outputs + + + def shutdown(self) -> None: """ Shuts the persona down. This method should: diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 05ec403e4..66f1805e2 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -9,6 +9,7 @@ JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) +from ...litellm_utils import ResolvedToolCall class JupyternautPersona(BasePersona): @@ -37,22 +38,35 @@ async def process_message(self, message: Message) -> None: return model_id = self.config_manager.chat_model - model_args = self.config_manager.chat_model_args - context_as_messages = self.get_context_as_messages(model_id, message) - response_aiter = await acompletion( - **model_args, - model=model_id, - messages=[ - *context_as_messages, - { - "role": "user", - "content": message.body, - }, - ], - stream=True, - ) - await self.stream_message(response_aiter) + # `True` on the first LLM invocation, `False` on all invocations after. + initial_invocation = True + # List of tool calls requested by the LLM in the previous invocaiton. + tool_calls: list[ResolvedToolCall] = [] + tool_call_list = None + # List of tool call outputs computed in the previous invocation. + tool_call_outputs: list[dict] = [] + + # Loop until the AI is complete running all its tools. + while initial_invocation or len(tool_call_outputs): + messages = self.get_context_as_messages(model_id, message) + + # TODO: Find a better way to track tool calls + if not initial_invocation and tool_calls: + self.log.error(messages[-1]) + messages[-1]['tool_calls'] = tool_call_list._aggregate + messages.extend(tool_call_outputs) + + self.log.error(messages) + response_aiter = await acompletion( + model=model_id, + messages=messages, + tools=self.get_tools(model_id), + stream=True, + ) + tool_calls, tool_call_list = await self.stream_message(response_aiter) + initial_invocation = False + tool_call_outputs = await self.run_tools(tool_calls) def get_context_as_messages( self, model_id: str, message: Message @@ -79,16 +93,17 @@ def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]] """ Returns the current history as a list of messages accepted by `litellm.acompletion()`. + + NOTE: You should usually call the public `get_context_as_messages()` + method instead. """ # TODO: consider bounding history based on message size (e.g. total # char/token count) instead of message count. all_messages = self.ychat.get_messages() # gather last k * 2 messages and return - # we exclude the last message since that is the human message just - # submitted by a user. - start_idx = 0 if k is None else -2 * k - 1 - recent_messages: list[Message] = all_messages[start_idx:-1] + start_idx = 0 if k is None else -2 * k + recent_messages: list[Message] = all_messages[start_idx:] history: list[dict[str, Any]] = [] for msg in recent_messages: diff --git a/packages/jupyter-ai/jupyter_ai/tools/models.py b/packages/jupyter-ai/jupyter_ai/tools/models.py index 5b95b6174..e547f0c15 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/models.py +++ b/packages/jupyter-ai/jupyter_ai/tools/models.py @@ -135,7 +135,7 @@ class Toolkit(BaseModel): name: str description: Optional[str] = None - tools: set = Field(default_factory=set) + tools: set[Tool] = Field(default_factory=set) model_config = ConfigDict(arbitrary_types_allowed=True) def add_tool(self, tool: Tool): @@ -146,6 +146,19 @@ def add_tool(self, tool: Tool): """ self.tools.add(tool) + def get_tool_unsafe(self, tool_name: str) -> Tool: + """ + (WIP) Gets a tool by its name. This is just a temporary method which is + used to make Jupyternaut agentic before we implement the + read/write/execute/delete permissions. + """ + for tool in self.tools: + if tool_name == tool.name: + return tool + + raise Exception(f"Tool not found: {tool_name}") + + def get_tools( self, read: Optional[bool] = None, From 5aa46bf9a4a93d89403ab9f5c9fe3dfc76c13832 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Sat, 23 Aug 2025 19:05:28 -0700 Subject: [PATCH 03/16] clean up tool calling flow & show in chat --- .../jupyter_ai/litellm_utils/__init__.py | 4 +- .../litellm_utils/streaming_utils.py | 13 ++++ .../jupyter_ai/litellm_utils/toolcall_list.py | 78 +++++++++++++++---- .../litellm_utils/toolcall_types.py | 57 -------------- .../jupyter_ai/personas/base_persona.py | 46 ++++++----- .../personas/jupyternaut/jupyternaut.py | 65 +++++++++++----- 6 files changed, 151 insertions(+), 112 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py delete mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py index cd95e2b2d..787493764 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py @@ -1,2 +1,2 @@ -from .toolcall_list import ToolCallList -from .toolcall_types import * +from .toolcall_list import * +from .streaming_utils import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py new file mode 100644 index 000000000..febe3f7f2 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel +from .toolcall_list import ToolCallList + +class StreamResult(BaseModel): + id: str + """ + ID of the new message. + """ + + tool_calls: ToolCallList + """ + Tool calls requested by the LLM in its streamed response. + """ diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py index 1e3effd3a..654939ebb 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py @@ -1,9 +1,61 @@ from litellm.utils import ChatCompletionDeltaToolCall, Function import json +from pydantic import BaseModel +from typing import Any -from .toolcall_types import ResolvedToolCall, ResolvedFunction +class ResolvedFunction(BaseModel): + """ + A type-safe, parsed representation of `litellm.utils.Function`. + """ + + name: str + """ + Name of the tool function to be called. + + TODO: Check if this attribute is defined for non-function tools, e.g. tools + provided by a MCP server. The docstring on `litellm.utils.Function` implies + that `name` may be `None`. + """ + + arguments: dict[str, Any] + """ + Arguments to the tool function, as a dictionary. + """ + + +class ResolvedToolCall(BaseModel): + """ + A type-safe, parsed representation of + `litellm.utils.ChatCompletionDeltaToolCall`. + """ + + id: str | None + """ + The ID of the tool call. This should always be provided by LiteLLM, this + type is left optional as we do not use this attribute. + """ + + type: str + """ + The 'type' of tool call. Usually 'function'. -class ToolCallList(): + TODO: Make this a union of string literals to ensure we are handling every + potential type of tool call. + """ + + function: ResolvedFunction + """ + The resolved function. See `ResolvedFunction` for more info. + """ + + index: int + """ + The index of this tool call. + + This is usually 0 unless the LLM supports parallel tool calling. + """ + +class ToolCallList(BaseModel): """ A helper object that defines a custom `__iadd__()` method which accepts a `tool_call_deltas: list[ChatCompletionDeltaToolCall]` argument. This class @@ -27,14 +79,7 @@ class ToolCallList(): ``` """ - _aggregate: list[ChatCompletionDeltaToolCall] - - def __init__(self): - self.size = None - - # Initialize `_aggregate` - self._aggregate = [] - + _aggregate: list[ChatCompletionDeltaToolCall] = [] def __iadd__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallList': """ @@ -116,6 +161,13 @@ def resolve(self) -> list[ResolvedToolCall]: resolved_toolcalls.append(resolved_toolcall) return resolved_toolcalls - - - \ No newline at end of file + + def to_json(self) -> list[dict[str, Any]]: + """ + Returns the list of tool calls as a Python dictionary that can be + JSON-serialized. + """ + return [ + model.model_dump() for model in self._aggregate + ] + \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py deleted file mode 100644 index 9426439f0..000000000 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_types.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations -from pydantic import BaseModel -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from typing import Any - -class ResolvedFunction(BaseModel): - """ - A type-safe, parsed representation of `litellm.utils.Function`. - """ - - name: str - """ - Name of the tool function to be called. - - TODO: Check if this attribute is defined for non-function tools, e.g. tools - provided by a MCP server. The docstring on `litellm.utils.Function` implies - that `name` may be `None`. - """ - - arguments: dict - """ - Arguments to the tool function, as a dictionary. - """ - -class ResolvedToolCall(BaseModel): - """ - A type-safe, parsed representation of - `litellm.utils.ChatCompletionDeltaToolCall`. - """ - - id: str | None - """ - The ID of the tool call. This should always be provided by LiteLLM, this - type is left optional as we do not use this attribute. - """ - - type: str - """ - The 'type' of tool call. Usually 'function'. - - TODO: Make this a union of string literals to ensure we are handling every - potential type of tool call. - """ - - function: ResolvedFunction - """ - The resolved function. See `ResolvedFunction` for more info. - """ - - index: int - """ - The index of this tool call. - - This is usually 0 unless the LLM supports parallel tool calling. - """ diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index b74c778a7..21718dab7 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -5,7 +5,7 @@ from dataclasses import asdict from logging import Logger from time import time -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional from jupyter_ai.config_manager import ConfigManager from jupyterlab_chat.models import Message, NewMessage, User @@ -17,7 +17,7 @@ from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_utils import ToolCallList, ResolvedToolCall +from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall # Import toolkits from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit @@ -247,7 +247,7 @@ def as_user_dict(self) -> dict[str, Any]: async def stream_message( self, reply_stream: "AsyncIterator[ModelResponseStream | str]" - ) -> Tuple[ResolvedToolCall, ToolCallList]: + ) -> StreamResult: """ Takes an async iterator, dubbed the 'reply stream', and streams it to a new message by this persona in the YChat. The async iterator may yield @@ -263,12 +263,21 @@ async def stream_message( """ stream_id: Optional[str] = None stream_interrupted = False + tool_calls = ToolCallList() try: self.awareness.set_local_state_field("isWriting", True) - toolcall_list = ToolCallList() - resolved_toolcalls: list[ResolvedToolCall] = [] async for chunk in reply_stream: + # Start the stream with an empty message on the initial reply. + # Bind the new message ID to `stream_id`. + if not stream_id: + stream_id = self.ychat.add_message( + NewMessage(body="", sender=self.id) + ) + self.message_interrupted[stream_id] = asyncio.Event() + self.awareness.set_local_state_field("isWriting", stream_id) + assert stream_id + # Compute `content_delta` and `tool_calls_delta` based on the # type of object yielded by `reply_stream`. if isinstance(chunk, ModelResponseStream): @@ -307,16 +316,6 @@ async def stream_message( # Append `content_delta` to the existing message. if content_delta: - # Start the stream with an empty message on the initial reply. - # Bind the new message ID to `stream_id`. - if not stream_id: - stream_id = self.ychat.add_message( - NewMessage(body="", sender=self.id) - ) - self.message_interrupted[stream_id] = asyncio.Event() - self.awareness.set_local_state_field("isWriting", stream_id) - assert stream_id - self.ychat.update_message( Message( id=stream_id, @@ -328,10 +327,8 @@ async def stream_message( append=True, ) if toolcalls_delta: - toolcall_list += toolcalls_delta + tool_calls += toolcalls_delta - # After the reply stream is complete, resolve the list of tool calls. - resolved_toolcalls = toolcall_list.resolve() except Exception as e: self.log.error( f"Persona '{self.name}' encountered an exception printed below when attempting to stream output." @@ -358,12 +355,17 @@ async def stream_message( ) return None - # Otherwise return the resolved list. + # TODO: determine where this should live + resolved_toolcalls = tool_calls.resolve() if len(resolved_toolcalls): count = len(resolved_toolcalls) names = sorted([tc.function.name for tc in resolved_toolcalls]) self.log.info(f"AI response triggered {count} tool calls: {names}") - return resolved_toolcalls, toolcall_list + + return StreamResult( + id=stream_id, + tool_calls=tool_calls + ) def send_message(self, body: str) -> None: @@ -552,7 +554,9 @@ async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]: tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name) # Run tool and store its output - output = await tool_defn.callable(**tool_call.function.arguments) + output = tool_defn.callable(**tool_call.function.arguments) + if asyncio.iscoroutine(output): + output = await output # Store the tool output in a dictionary accepted by LiteLLM output_dict = { diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 66f1805e2..3c350c1f4 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,4 +1,6 @@ from typing import Any, Optional +import time +import json from jupyterlab_chat.models import Message from litellm import acompletion @@ -9,7 +11,6 @@ JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) -from ...litellm_utils import ResolvedToolCall class JupyternautPersona(BasePersona): @@ -39,34 +40,60 @@ async def process_message(self, message: Message) -> None: model_id = self.config_manager.chat_model - # `True` on the first LLM invocation, `False` on all invocations after. - initial_invocation = True - # List of tool calls requested by the LLM in the previous invocaiton. - tool_calls: list[ResolvedToolCall] = [] - tool_call_list = None + # `True` before the first LLM response is sent, `False` afterwards. + initial_response = True # List of tool call outputs computed in the previous invocation. tool_call_outputs: list[dict] = [] - # Loop until the AI is complete running all its tools. - while initial_invocation or len(tool_call_outputs): - messages = self.get_context_as_messages(model_id, message) - - # TODO: Find a better way to track tool calls - if not initial_invocation and tool_calls: - self.log.error(messages[-1]) - messages[-1]['tool_calls'] = tool_call_list._aggregate - messages.extend(tool_call_outputs) + # Initialize list of messages, including history and context + messages: list[dict] = self.get_context_as_messages(model_id, message) - self.log.error(messages) + # Loop until the AI is complete running all its tools. + while initial_response or len(tool_call_outputs): + # Stream message to the chat response_aiter = await acompletion( model=model_id, messages=messages, tools=self.get_tools(model_id), stream=True, ) - tool_calls, tool_call_list = await self.stream_message(response_aiter) - initial_invocation = False - tool_call_outputs = await self.run_tools(tool_calls) + result = await self.stream_message(response_aiter) + initial_response = False + + # Append new reply to `messages` + reply = self.ychat.get_message(result.id) + tool_calls_json = result.tool_calls.to_json() + messages.append({ + "role": "assistant", + "content": reply.body, + "tool_calls": tool_calls_json + }) + + # Show tool call requests to YChat (not synced with `messages`) + if len(tool_calls_json): + self.ychat.update_message(Message( + id=result.id, + body=f"\n\n```\n{json.dumps(tool_calls_json, indent=2)}\n```\n", + sender=self.id, + time=time.time(), + raw_time=False + ), append=True) + + # Run tools and append outputs to `messages` + tool_call_outputs = await self.run_tools(result.tool_calls.resolve()) + messages.extend(tool_call_outputs) + + # Add tool call outputs to YChat (not synced with `messages`) + if tool_call_outputs: + self.ychat.update_message(Message( + id=result.id, + body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n", + sender=self.id, + time=time.time(), + raw_time=False + ), append=True) + + def get_context_as_messages( self, model_id: str, message: Message From 7ba285dc5aaf7ef8780ffa02b6bf21f6a2f7dcf4 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 25 Aug 2025 06:16:51 +0200 Subject: [PATCH 04/16] add temporary default toolkit --- .../jupyter_ai/personas/base_persona.py | 46 +- .../jupyter-ai/jupyter_ai/tools/__init__.py | 3 +- .../jupyter_ai/tools/default_toolkit.py | 255 ++++++++ .../jupyter_ai/tools/test_default_toolkit.py | 595 ++++++++++++++++++ 4 files changed, 867 insertions(+), 32 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py create mode 100644 packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 21718dab7..4e550cfbd 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -20,21 +20,13 @@ from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall # Import toolkits -from jupyter_ai_tools.toolkits.file_system import toolkit as fs_toolkit -from jupyter_ai_tools.toolkits.code_execution import toolkit as codeexec_toolkit -from jupyter_ai_tools.toolkits.git import toolkit as git_toolkit +from ..tools.default_toolkit import DEFAULT_TOOLKIT if TYPE_CHECKING: from collections.abc import AsyncIterator from .persona_manager import PersonaManager from ..tools import Toolkit -DEFAULT_TOOLKITS: dict[str, Toolkit] = { - "fs": fs_toolkit, - "codeexec": codeexec_toolkit, - "git": git_toolkit, -} - class PersonaDefaults(BaseModel): """ Data structure that represents the default settings of a persona. Each persona @@ -512,27 +504,19 @@ def get_tools(self, model_id: str) -> list[dict]: tool_descriptions = [] - # Get all tools from `jupyter_ai_tools` and store their object descriptions - for toolkit_name, toolkit in DEFAULT_TOOLKITS.items(): - # TODO: make these tool permissions configurable. - for tool in toolkit.get_tools(): - # Here, we are using a util function from LiteLLM to coerce - # each `Tool` struct into a tool description dictionary expected - # by LiteLLM. - desc = { - "type": "function", - "function": function_to_dict(tool.callable), - } - - # Prepend the toolkit name to each function name, hopefully - # ensuring every tool function has a unique name. - # e.g. 'git_add' => 'git__git_add' - # - # TODO: Actually ensure this instead of hoping. - desc['function']['name'] = f"{toolkit_name}__{desc['function']['name']}" - tool_descriptions.append(desc) + # Get all tools from the default toolkit and store their object descriptions + for tool in DEFAULT_TOOLKIT.get_tools(): + # Here, we are using a util function from LiteLLM to coerce + # each `Tool` struct into a tool description dictionary expected + # by LiteLLM. + desc = { + "type": "function", + "function": function_to_dict(tool.callable), + } + tool_descriptions.append(desc) # Finally, return the tool descriptions + self.log.info(tool_descriptions) return tool_descriptions @@ -549,9 +533,9 @@ async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]: tool_outputs: list[dict] = [] for tool_call in tools: # Get tool definition from the correct toolkit - toolkit_name, tool_name = tool_call.function.name.split("__") - assert toolkit_name in DEFAULT_TOOLKITS - tool_defn = DEFAULT_TOOLKITS[toolkit_name].get_tool_unsafe(tool_name) + # TODO: validation? + tool_name = tool_call.function.name + tool_defn = DEFAULT_TOOLKIT.get_tool_unsafe(tool_name) # Run tool and store its output output = tool_defn.callable(**tool_call.function.arguments) diff --git a/packages/jupyter-ai/jupyter_ai/tools/__init__.py b/packages/jupyter-ai/jupyter_ai/tools/__init__.py index 0252ac1a9..1f8e3afa3 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/tools/__init__.py @@ -1,5 +1,6 @@ """Tools package for Jupyter AI.""" from .models import Tool, Toolkit +from .default_toolkit import DEFAULT_TOOLKIT -__all__ = ["Tool", "Toolkit"] +__all__ = ["Tool", "Toolkit", "DEFAULT_TOOLKIT"] diff --git a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py new file mode 100644 index 000000000..79c8a7675 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py @@ -0,0 +1,255 @@ +from .models import Tool, Toolkit +from jupyter_ai_tools.toolkits.code_execution import bash + +import pathlib + + +def read(file_path: str, offset: int, limit: int) -> str: + """ + Read a subset of lines from a text file. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be read. + offset : int + The line number at which to start reading (1-based indexing). + limit : int + Number of lines to read starting from *offset*. + If *offset + limit* exceeds the number of lines in the file, + all available lines after *offset* are returned. + + Returns + ------- + List[str] + List of lines (including line-ending characters) that were read. + + Examples + -------- + >>> # Suppose ``/tmp/example.txt`` contains 10 lines + >>> read('/tmp/example.txt', offset=3, limit=4) + ['third line\n', 'fourth line\n', 'fifth line\n', 'sixth line\n'] + """ + path = pathlib.Path(file_path) + if not path.is_file(): + raise FileNotFoundError(f"File not found: {file_path}") + + # Normalize arguments + offset = max(1, int(offset)) + limit = max(0, int(limit)) + lines: list[str] = [] + + with path.open(encoding='utf-8', errors='replace') as f: + # Skip to offset + line_no = 0 + # Loop invariant: line_no := last read line + # After the loop exits, line_no == offset - 1, meaning the + # next line starts at `offset` + while line_no < offset - 1: + line = f.readline() + # Return early if offset exceeds number of lines in file + if line == "": + return "" + line_no += 1 + + # Append lines until limit is reached + while len(lines) < limit: + line = f.readline() + if line == "": + break + lines.append(line) + + return "".join(lines) + + +def edit( + file_path: str, + old_string: str, + new_string: str, + replace_all: bool = False, +) -> None: + """ + Replace occurrences of a substring in a file. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be edited. + old_string : str + Text that should be replaced. + new_string : str + Text that will replace *old_string*. + replace_all : bool, optional + If ``True`` all occurrences of *old_string* are replaced. + If ``False`` (default), only the first occurrence in the file is replaced. + + Returns + ------- + None + + Raises + ------ + FileNotFoundError + If *file_path* does not exist. + ValueError + If *old_string* is empty (replacing an empty string is ambiguous). + + Notes + ----- + The file is overwritten atomically: it is first read into memory, + the substitution is performed, and the file is written back. + This keeps the operation safe for short to medium-sized files. + + Examples + -------- + >>> # Replace only the first occurrence + >>> edit('/tmp/test.txt', 'foo', 'bar', replace_all=False) + >>> # Replace all occurrences + >>> edit('/tmp/test.txt', 'foo', 'bar', replace_all=True) + """ + path = pathlib.Path(file_path) + if not path.is_file(): + raise FileNotFoundError(f"File not found: {file_path}") + + if old_string == "": + raise ValueError("old_string must not be empty") + + # Read the entire file + content = path.read_text(encoding="utf-8", errors="replace") + + # Perform replacement + if replace_all: + new_content = content.replace(old_string, new_string) + else: + new_content = content.replace(old_string, new_string, 1) + + # Write back + path.write_text(new_content, encoding="utf-8") + + +def write(file_path: str, content: str) -> None: + """ + Write content to a file, creating it if it doesn't exist. + + Parameters + ---------- + file_path : str + Absolute path to the file that should be written. + content : str + Content to write to the file. + + Returns + ------- + None + + Raises + ------ + OSError + If the file cannot be written (e.g., permission denied, invalid path). + + Notes + ----- + This function will overwrite the file if it already exists. + The parent directory must exist; this function does not create directories. + + Examples + -------- + >>> write('/tmp/example.txt', 'Hello, world!') + >>> write('/tmp/data.json', '{"key": "value"}') + """ + path = pathlib.Path(file_path) + + # Write the content to the file + path.write_text(content, encoding="utf-8") + + +async def search_grep(pattern: str, include: str = "*") -> str: + """ + Search for text patterns in files using ripgrep. + + This function uses ripgrep (rg) to perform fast regex-based text searching + across files, with optional file filtering based on glob patterns. + + Parameters + ---------- + pattern : str + A regular expression pattern to search for. Ripgrep uses Rust regex + syntax which supports: + - Basic regex features: ., *, +, ?, ^, $, [], (), | + - Character classes: \w, \d, \s, \W, \D, \S + - Unicode categories: \p{L}, \p{N}, \p{P}, etc. + - Word boundaries: \b, \B + - Anchors: ^, $, \A, \z + - Quantifiers: {n}, {n,}, {n,m} + - Groups: (pattern), (?:pattern), (?Ppattern) + - Lookahead/lookbehind: (?=pattern), (?!pattern), (?<=pattern), (?>> search_grep(r"def\s+\w+", "*.py") + 'file.py:10:def my_function():' + + >>> search_grep(r"TODO|FIXME", "**/*.{py,js}") + 'app.py:25:# TODO: implement this + script.js:15:// FIXME: handle edge case' + + >>> search_grep(r"class\s+(\w+)", "src/**/*.py") + 'src/models.py:1:class User:' + """ + # Use bash tool to execute ripgrep + cmd_parts = ["rg", "--color=never", "--line-number", "--with-filename"] + + # Add glob pattern if specified + if include != "*": + cmd_parts.extend(["-g", include]) + + # Add the pattern (always quote it to handle special characters) + cmd_parts.append(pattern) + + # Join command with proper shell escaping + command = " ".join(f'"{part}"' if " " in part or any(c in part for c in "!*?[]{}()") else part for part in cmd_parts) + + try: + result = await bash(command) + return result + except Exception as e: + raise RuntimeError(f"Ripgrep search failed: {str(e)}") from e + + +DEFAULT_TOOLKIT = Toolkit(name="jupyter-ai-default-toolkit") +DEFAULT_TOOLKIT.add_tool(Tool(callable=bash)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=read)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=edit)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=write)) +DEFAULT_TOOLKIT.add_tool(Tool(callable=search_grep)) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py new file mode 100644 index 000000000..9db82dc41 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/tools/test_default_toolkit.py @@ -0,0 +1,595 @@ +"""Tests for default_toolkit.py functions and toolkit configuration.""" + +import pathlib +import tempfile +import pytest +from unittest.mock import patch, mock_open + +from .default_toolkit import read, edit, write, search_grep, DEFAULT_TOOLKIT +from .models import Tool, Toolkit + + +class TestReadFunction: + """Test the read function.""" + + def test_read_valid_file(self): + """Test reading lines from a valid file.""" + # Create a temporary file with known content + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\nline 3\nline 4\nline 5\n") + temp_path = f.name + + try: + # Test reading from offset 2, limit 3 + result = read(temp_path, offset=2, limit=3) + assert result == "line 2\nline 3\nline 4\n" + + # Test reading from offset 1, limit 2 + result = read(temp_path, offset=1, limit=2) + assert result == "line 1\nline 2\n" + + # Test reading all lines from beginning + result = read(temp_path, offset=1, limit=10) + assert result == "line 1\nline 2\nline 3\nline 4\nline 5\n" + + # Test reading from middle to end + result = read(temp_path, offset=4, limit=10) + assert result == "line 4\nline 5\n" + + finally: + # Clean up + pathlib.Path(temp_path).unlink() + + def test_read_empty_file(self): + """Test reading from an empty file.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=5) + assert result == "" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_single_line_file(self): + """Test reading from a file with one line.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("single line\n") + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=1) + assert result == "single line\n" + + result = read(temp_path, offset=1, limit=5) + assert result == "single line\n" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_file_not_found(self): + """Test reading from a non-existent file.""" + with pytest.raises(FileNotFoundError, match="File not found: /nonexistent/path"): + read("/nonexistent/path", offset=1, limit=5) + + def test_read_offset_beyond_file_length(self): + """Test reading with offset beyond file length.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\n") + temp_path = f.name + + try: + # Offset beyond file length should return empty string + result = read(temp_path, offset=10, limit=5) + assert result == "" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_negative_and_zero_values(self): + """Test read function with negative and zero offset/limit values.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nline 2\nline 3\n") + temp_path = f.name + + try: + # Negative offset should be normalized to 1 + result = read(temp_path, offset=-5, limit=2) + assert result == "line 1\nline 2\n" + + # Zero offset should be normalized to 1 + result = read(temp_path, offset=0, limit=2) + assert result == "line 1\nline 2\n" + + # Zero limit should return empty string + result = read(temp_path, offset=1, limit=0) + assert result == "" + + # Negative limit should return empty string + result = read(temp_path, offset=1, limit=-5) + assert result == "" + + finally: + pathlib.Path(temp_path).unlink() + + def test_read_unicode_content(self): + """Test reading file with unicode content.""" + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.txt') as f: + f.write("línea 1 🚀\nlínea 2 ❤️\nlínea 3 🎉\n") + temp_path = f.name + + try: + result = read(temp_path, offset=1, limit=2) + assert result == "línea 1 🚀\nlínea 2 ❤️\n" + finally: + pathlib.Path(temp_path).unlink() + + def test_read_with_encoding_errors(self): + """Test reading file with encoding issues using replace errors handling.""" + # This test ensures the 'replace' error handling works properly + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("valid line\n") + temp_path = f.name + + try: + # The function should handle encoding errors gracefully + result = read(temp_path, offset=1, limit=1) + assert result == "valid line\n" + finally: + pathlib.Path(temp_path).unlink() + + +class TestEditFunction: + """Test the edit function.""" + + def test_edit_replace_first_occurrence(self): + """Test replacing the first occurrence of a string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("foo bar foo baz foo") + temp_path = f.name + + try: + edit(temp_path, "foo", "qux", replace_all=False) + + # Read the file to verify the change + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "qux bar foo baz foo" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_replace_all_occurrences(self): + """Test replacing all occurrences of a string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("foo bar foo baz foo") + temp_path = f.name + + try: + edit(temp_path, "foo", "qux", replace_all=True) + + # Read the file to verify the change + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "qux bar qux baz qux" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_multiline_content(self): + """Test editing multiline content.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line 1\nold content\nline 3\nold content\nline 5") + temp_path = f.name + + try: + edit(temp_path, "old content", "new content", replace_all=True) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "line 1\nnew content\nline 3\nnew content\nline 5" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_string_not_found(self): + """Test editing when the target string is not found.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("hello world") + temp_path = f.name + + try: + # This should not raise an error, just leave the file unchanged + edit(temp_path, "nonexistent", "replacement", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "hello world" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_file_not_found(self): + """Test editing a non-existent file.""" + with pytest.raises(FileNotFoundError, match="File not found: /nonexistent/path"): + edit("/nonexistent/path", "old", "new") + + def test_edit_empty_old_string(self): + """Test editing with an empty old_string.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("hello world") + temp_path = f.name + + try: + with pytest.raises(ValueError, match="old_string must not be empty"): + edit(temp_path, "", "replacement") + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_unicode_content(self): + """Test editing file with unicode content.""" + with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.txt') as f: + f.write("hola 🌟 mundo 🌟 adiós") + temp_path = f.name + + try: + edit(temp_path, "🌟", "⭐", replace_all=True) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "hola ⭐ mundo ⭐ adiós" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_newline_characters(self): + """Test editing with newline characters.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("line1\nold\nline3") + temp_path = f.name + + try: + edit(temp_path, "\nold\n", "\nnew\n", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "line1\nnew\nline3" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_replace_with_empty_string(self): + """Test replacing content with empty string (deletion).""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("keep this DELETE_ME keep this too") + temp_path = f.name + + try: + edit(temp_path, "DELETE_ME ", "", replace_all=False) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == "keep this keep this too" + finally: + pathlib.Path(temp_path).unlink() + + def test_edit_atomicity(self): + """Test that edit operation is atomic (file is either fully updated or unchanged).""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + original_content = "original content" + f.write(original_content) + temp_path = f.name + + try: + # Mock pathlib.Path.write_text to raise an exception + with patch.object(pathlib.Path, 'write_text', side_effect=IOError("Disk full")): + with pytest.raises(IOError): + edit(temp_path, "original", "modified") + + # File should remain unchanged due to the error + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == original_content + + finally: + pathlib.Path(temp_path).unlink() + + +class TestWriteFunction: + """Test the write function.""" + + def test_write_new_file(self): + """Test writing content to a new file.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "new_file.txt" + test_content = "Hello, world!\nThis is a test." + + write(str(temp_path), test_content) + + # Verify the file was created and contains the correct content + assert temp_path.exists() + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_overwrite_existing_file(self): + """Test overwriting an existing file.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("original content") + temp_path = f.name + + try: + new_content = "new content that replaces the old" + write(temp_path, new_content) + + # Verify the file was overwritten + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == new_content + finally: + pathlib.Path(temp_path).unlink() + + def test_write_empty_content(self): + """Test writing empty content to a file.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "empty_file.txt" + + write(str(temp_path), "") + + # Verify the file exists and is empty + assert temp_path.exists() + content = temp_path.read_text(encoding='utf-8') + assert content == "" + + def test_write_multiline_content(self): + """Test writing multiline content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "multiline.txt" + test_content = "Line 1\nLine 2\nLine 3\n" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_unicode_content(self): + """Test writing unicode content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "unicode.txt" + test_content = "Hello 世界! 🌍 Café naïve résumé" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_large_content(self): + """Test writing large content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "large.txt" + # Create content with 10000 lines + test_content = "\n".join([f"Line {i}" for i in range(10000)]) + + write(str(temp_path), test_content) + + content = pathlib.Path(temp_path).read_text(encoding='utf-8') + assert content == test_content + + @pytest.mark.skip("Fix this test for CRLF newlines (Windows problem)") + def test_write_special_characters(self): + """Test writing content with special characters.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "special.txt" + test_content = 'Content with "quotes", \ttabs, and \nnewlines\r\n' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_invalid_directory(self): + """Test writing to a non-existent directory.""" + invalid_path = "/nonexistent/directory/file.txt" + + with pytest.raises(OSError): + write(invalid_path, "test content") + + def test_write_permission_denied(self): + """Test writing to a file without write permissions.""" + with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt') as f: + f.write("original") + temp_path = f.name + + try: + # Make file read-only + pathlib.Path(temp_path).chmod(0o444) + + with pytest.raises(OSError): + write(temp_path, "new content") + + finally: + # Restore write permissions and clean up + pathlib.Path(temp_path).chmod(0o644) + pathlib.Path(temp_path).unlink() + + def test_write_binary_like_content(self): + """Test writing content that looks like binary data.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "binary_like.txt" + # Content with null bytes and other control characters + test_content = "Normal text\x00null byte\x01control char\xff" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_json_content(self): + """Test writing JSON-like content.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "data.json" + test_content = '{"name": "test", "value": 42, "nested": {"key": "value"}}' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + def test_write_code_content(self): + """Test writing code content with proper indentation.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "code.py" + test_content = '''def hello(): + """Say hello.""" + print("Hello, world!") + + if True: + return "success" +''' + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + @pytest.mark.skip("Fix this test for CRLF newlines (Windows problem)") + def test_write_preserves_line_endings(self): + """Test that write preserves different line endings.""" + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = pathlib.Path(temp_dir) / "line_endings.txt" + test_content = "Unix\nWindows\r\nMac\rMixed\r\n" + + write(str(temp_path), test_content) + + content = temp_path.read_text(encoding='utf-8') + assert content == test_content + + +class TestSearchGrepFunction: + """Test the search_grep function.""" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_bash_integration(self, mock_bash): + """Test that search_grep correctly calls bash with proper arguments.""" + mock_bash.return_value = "test.py:1:def test():" + + result = await search_grep("def", "*.py") + + # Verify bash was called + mock_bash.assert_called_once() + call_args = mock_bash.call_args[0][0] + + # Check that the command contains expected parts + assert "rg" in call_args + assert "--color=never" in call_args + assert "--line-number" in call_args + assert "--with-filename" in call_args + assert "-g" in call_args + assert "*.py" in call_args + assert "def" in call_args + + assert result == "test.py:1:def test():" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_default_include(self, mock_bash): + """Test search_grep with default include pattern.""" + mock_bash.return_value = "" + + await search_grep("pattern") + + call_args = mock_bash.call_args[0][0] + # Should not contain -g flag when using default "*" pattern + assert "-g" not in call_args or "\"*\"" not in call_args + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_bash_exception(self, mock_bash): + """Test search_grep handling of bash execution errors.""" + mock_bash.side_effect = Exception("Command failed") + + with pytest.raises(RuntimeError, match="Ripgrep search failed: Command failed"): + await search_grep("pattern", "*.txt") + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_basic_pattern(self, mock_bash): + """Test basic pattern searching.""" + mock_bash.return_value = "test1.py:1:def hello_world():\ntest2.py:2: def method(self):" + + result = await search_grep(r"def\s+\w+", "*.py") + + # Should find function definitions in both files + assert "test1.py" in result + assert "test2.py" in result + assert "def hello_world" in result + assert "def method" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_no_matches(self, mock_bash): + """Test search with no matches.""" + mock_bash.return_value = "" + + result = await search_grep("nonexistent_pattern", "*.txt") + assert result == "" + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_with_include_pattern(self, mock_bash): + """Test search with file include pattern.""" + mock_bash.return_value = "script.py:1:import os" + + result = await search_grep("import", "*.py") + assert "script.py" in result + assert "readme.txt" not in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_special_characters(self, mock_bash): + """Test searching for patterns with special regex characters.""" + # Mock different return values for different calls + mock_bash.side_effect = [ + "special.txt:2:email: user@domain.com", + "special.txt:1:price: $10.99" + ] + + # Search for email pattern + result = await search_grep(r"\w+@\w+\.\w+", "*.txt") + assert "user@domain.com" in result + + # Search for price pattern + result = await search_grep(r"\$\d+\.\d+", "*.txt") + assert "$10.99" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_unicode_content(self, mock_bash): + """Test searching in files with unicode content.""" + mock_bash.return_value = "unicode.txt:1:Hello 世界" + + result = await search_grep("世界", "*.txt") + assert "世界" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_line_anchors(self, mock_bash): + """Test line anchor patterns (^ and $).""" + mock_bash.side_effect = [ + "anchors.txt:1:start of line", + "anchors.txt:3:line with end" + ] + + # Search for lines starting with specific text + result = await search_grep("^start", "*.txt") + assert "start of line" in result + + # Search for lines ending with specific text + result = await search_grep("end$", "*.txt") + assert "line with end" in result + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_case_insensitive_pattern(self, mock_bash): + """Test case insensitive regex patterns.""" + mock_bash.return_value = "mixed_case.txt:1:TODO: fix this\nmixed_case.txt:2:todo: also this\nmixed_case.txt:3:ToDo: and this" + + # Case insensitive search + result = await search_grep("(?i)todo", "*.txt") + lines = result.strip().split('\n') if result.strip() else [] + assert len(lines) == 3 # Should match all three variants + + @patch('jupyter_ai.tools.default_toolkit.bash') + @pytest.mark.asyncio + async def test_search_grep_complex_glob_patterns(self, mock_bash): + """Test various complex glob patterns.""" + mock_bash.return_value = "src/main.py:1:import sys\nsrc/utils.py:1:import os" + + # Test recursive search in src directory + result = await search_grep("import", "src/**/*.py") + assert "src/main.py" in result + assert "src/utils.py" in result + assert "test_main.py" not in result \ No newline at end of file From 4059489d7af208e74cbd023b5db8ccd2b31c81a6 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 25 Aug 2025 06:22:04 +0200 Subject: [PATCH 05/16] update tool calling APIs --- .../jupyter_ai/litellm_utils/__init__.py | 1 + .../jupyter_ai/litellm_utils/run_tools.py | 50 ++++++++++++++++++ .../litellm_utils/streaming_utils.py | 2 +- .../jupyter_ai/litellm_utils/toolcall_list.py | 4 ++ .../jupyter_ai/personas/base_persona.py | 51 +++++-------------- .../personas/jupyternaut/jupyternaut.py | 4 +- 6 files changed, 70 insertions(+), 42 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py index 787493764..ff1c7d8c3 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py @@ -1,2 +1,3 @@ from .toolcall_list import * from .streaming_utils import * +from .run_tools import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py new file mode 100644 index 000000000..2b148b505 --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py @@ -0,0 +1,50 @@ +import asyncio +from pydantic import BaseModel +from .toolcall_list import ToolCallList +from ..tools import Toolkit + + +class ToolCallOutput(BaseModel): + tool_call_id: str + role: str = "tool" + name: str + content: str + + +async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict]: + """ + Runs the tools specified in the list of tool calls returned by + `self.stream_message()`. + + Returns `list[ToolCallOutput]`. The outputs should be appended directly to + the message history on the next request made to the LLM. + """ + tool_calls = tool_call_list.resolve() + if not len(tool_calls): + return [] + + tool_outputs: list[dict] = [] + for tool_call in tool_calls: + # Get tool definition from the correct toolkit + # TODO: validation? + tool_name = tool_call.function.name + tool_defn = toolkit.get_tool_unsafe(tool_name) + + # Run tool and store its output + try: + output = tool_defn.callable(**tool_call.function.arguments) + if asyncio.iscoroutine(output): + output = await output + except Exception as e: + output = str(e) + + # Store the tool output in a dictionary accepted by LiteLLM + output_dict = { + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_call.function.name, + "content": output, + } + tool_outputs.append(output_dict) + + return tool_outputs diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py index febe3f7f2..7251c88ed 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py @@ -7,7 +7,7 @@ class StreamResult(BaseModel): ID of the new message. """ - tool_calls: ToolCallList + tool_call_list: ToolCallList """ Tool calls requested by the LLM in its streamed response. """ diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py index 654939ebb..e7094e4f9 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py @@ -170,4 +170,8 @@ def to_json(self) -> list[dict[str, Any]]: return [ model.model_dump() for model in self._aggregate ] + + + def __len__(self) -> int: + return len(self._aggregate) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 4e550cfbd..a33ad69f3 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -17,7 +17,7 @@ from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_utils import ToolCallList, StreamResult, ResolvedToolCall +from ..litellm_utils import ToolCallList, StreamResult, run_tools, ToolCallOutput # Import toolkits from ..tools.default_toolkit import DEFAULT_TOOLKIT @@ -255,7 +255,7 @@ async def stream_message( """ stream_id: Optional[str] = None stream_interrupted = False - tool_calls = ToolCallList() + tool_call_list = ToolCallList() try: self.awareness.set_local_state_field("isWriting", True) @@ -319,7 +319,7 @@ async def stream_message( append=True, ) if toolcalls_delta: - tool_calls += toolcalls_delta + tool_call_list += toolcalls_delta except Exception as e: self.log.error( @@ -348,15 +348,13 @@ async def stream_message( return None # TODO: determine where this should live - resolved_toolcalls = tool_calls.resolve() - if len(resolved_toolcalls): - count = len(resolved_toolcalls) - names = sorted([tc.function.name for tc in resolved_toolcalls]) - self.log.info(f"AI response triggered {count} tool calls: {names}") + count = len(tool_call_list) + if count > 0: + self.log.info(f"AI response triggered {count} tool calls.") return StreamResult( id=stream_id, - tool_calls=tool_calls + tool_call_list=tool_call_list ) @@ -520,38 +518,13 @@ def get_tools(self, model_id: str) -> list[dict]: return tool_descriptions - async def run_tools(self, tools: list[ResolvedToolCall]) -> list[dict]: + async def run_tools(self, tool_call_list: ToolCallList) -> list[ToolCallOutput]: """ - Runs the tools specified in the list of tool calls returned by - `self.stream_message()`. Returns a list of dictionaries - `toolcall_outputs: list[dict]`, which should be appended directly to the - message history on the next invocation of the LLM. + Runs the tools specified in a given tool call list using the default + toolkit. """ - if not len(tools): - return [] - - tool_outputs: list[dict] = [] - for tool_call in tools: - # Get tool definition from the correct toolkit - # TODO: validation? - tool_name = tool_call.function.name - tool_defn = DEFAULT_TOOLKIT.get_tool_unsafe(tool_name) - - # Run tool and store its output - output = tool_defn.callable(**tool_call.function.arguments) - if asyncio.iscoroutine(output): - output = await output - - # Store the tool output in a dictionary accepted by LiteLLM - output_dict = { - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_call.function.name, - "content": output, - } - tool_outputs.append(output_dict) - - self.log.info(f"Ran {len(tools)} tool functions.") + tool_outputs = await run_tools(tool_call_list, toolkit=DEFAULT_TOOLKIT) + self.log.info(f"Ran {len(tool_outputs)} tool functions.") return tool_outputs diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 3c350c1f4..914372ddf 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -62,7 +62,7 @@ async def process_message(self, message: Message) -> None: # Append new reply to `messages` reply = self.ychat.get_message(result.id) - tool_calls_json = result.tool_calls.to_json() + tool_calls_json = result.tool_call_list.to_json() messages.append({ "role": "assistant", "content": reply.body, @@ -80,7 +80,7 @@ async def process_message(self, message: Message) -> None: ), append=True) # Run tools and append outputs to `messages` - tool_call_outputs = await self.run_tools(result.tool_calls.resolve()) + tool_call_outputs = await self.run_tools(result.tool_call_list) messages.extend(tool_call_outputs) # Add tool call outputs to YChat (not synced with `messages`) From 056676eeb9f13b406e62ceabb3b54074bbb8c7c1 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 25 Aug 2025 06:50:30 +0200 Subject: [PATCH 06/16] improve bash tool reliability, drop jupyter_ai_tools for now --- .../jupyter_ai/personas/base_persona.py | 2 +- .../jupyter_ai/tools/default_toolkit.py | 54 +++++++++++++++++-- 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index a33ad69f3..37271e1ba 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -477,7 +477,7 @@ def get_tools(self, model_id: str) -> list[dict]: environment. These may include: - The default set of tool functions in Jupyter AI, defined in the - `jupyter_ai_tools` package. + the default toolkit from `jupyter_ai.tools`. - (TODO) Tools provided by MCP server configuration, if any. diff --git a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py index 79c8a7675..d850775af 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py +++ b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py @@ -1,7 +1,9 @@ -from .models import Tool, Toolkit -from jupyter_ai_tools.toolkits.code_execution import bash - +import asyncio import pathlib +import shlex +from typing import Optional + +from .models import Tool, Toolkit def read(file_path: str, offset: int, limit: int) -> str: @@ -247,6 +249,52 @@ async def search_grep(pattern: str, include: str = "*") -> str: raise RuntimeError(f"Ripgrep search failed: {str(e)}") from e +async def bash(command: str, timeout: Optional[int] = None) -> str: + """Executes a bash command and returns the result + + Args: + command: The bash command to execute + timeout: Optional timeout in seconds + + Returns: + The command output (stdout and stderr combined) + """ + # coerce `timeout` to the correct type. sometimes LLMs pass this as a string + if isinstance(timeout, str): + timeout = int(timeout) + + proc = await asyncio.create_subprocess_exec( + *shlex.split(command), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout) + stdout = stdout.decode("utf-8") + stderr = stderr.decode("utf-8") + + if proc.returncode != 0: + info = f"Command returned non-zero exit code {proc.returncode}. This usually indicates an error." + info += "\n\n" + fr"Original command: {command}" + if not (stdout or stderr): + info += "\n\nNo further information was given in stdout or stderr." + return info + if stdout: + info += f"stdout:\n\n```\n{stdout}\n```\n\n" + if stderr: + info += f"stderr:\n\n```\n{stderr}\n```\n\n" + return info + + if stdout: + return stdout + return "Command executed successfully with exit code 0. No stdout/stderr was returned." + + except asyncio.TimeoutError: + proc.kill() + return f"Command timed out after {timeout} seconds" + + DEFAULT_TOOLKIT = Toolkit(name="jupyter-ai-default-toolkit") DEFAULT_TOOLKIT.add_tool(Tool(callable=bash)) DEFAULT_TOOLKIT.add_tool(Tool(callable=read)) From 1311989613d03f982a0f1f0fe167339d0a5c76f3 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 25 Aug 2025 14:58:00 +0200 Subject: [PATCH 07/16] add jai-tool-call web component to show tool calls & outputs --- .../jupyter_ai/personas/base_persona.py | 2 +- .../personas/jupyternaut/jupyternaut.py | 83 +++++++++--- packages/jupyter-ai/package.json | 1 + packages/jupyter-ai/src/index.ts | 2 + .../jupyter-ai/src/web-components/index.ts | 2 + .../src/web-components/jai-tool-call.tsx | 121 ++++++++++++++++++ .../jupyter-ai/src/web-components/plugin.ts | 65 ++++++++++ yarn.lock | 20 +++ 8 files changed, 278 insertions(+), 18 deletions(-) create mode 100644 packages/jupyter-ai/src/web-components/index.ts create mode 100644 packages/jupyter-ai/src/web-components/jai-tool-call.tsx create mode 100644 packages/jupyter-ai/src/web-components/plugin.ts diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 37271e1ba..a6abef496 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -518,7 +518,7 @@ def get_tools(self, model_id: str) -> list[dict]: return tool_descriptions - async def run_tools(self, tool_call_list: ToolCallList) -> list[ToolCallOutput]: + async def run_tools(self, tool_call_list: ToolCallList) -> list[dict]: """ Runs the tools specified in a given tool call list using the default toolkit. diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 914372ddf..1e9669c6d 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -5,6 +5,7 @@ from jupyterlab_chat.models import Message from litellm import acompletion +from ...litellm_utils import StreamResult, ToolCallOutput from ..base_persona import BasePersona, PersonaDefaults from ..persona_manager import SYSTEM_USERNAME from .prompt_template import ( @@ -69,30 +70,78 @@ async def process_message(self, message: Message) -> None: "tool_calls": tool_calls_json }) - # Show tool call requests to YChat (not synced with `messages`) - if len(tool_calls_json): - self.ychat.update_message(Message( - id=result.id, - body=f"\n\n```\n{json.dumps(tool_calls_json, indent=2)}\n```\n", - sender=self.id, - time=time.time(), - raw_time=False - ), append=True) + # Render tool calls in new message + if len(result.tool_call_list): + self.render_tool_calls(result) # Run tools and append outputs to `messages` tool_call_outputs = await self.run_tools(result.tool_call_list) messages.extend(tool_call_outputs) - # Add tool call outputs to YChat (not synced with `messages`) + # Render tool call outputs in new message if tool_call_outputs: - self.ychat.update_message(Message( - id=result.id, - body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n", - sender=self.id, - time=time.time(), - raw_time=False - ), append=True) + self.render_tool_call_outputs( + message_id=result.id, + tool_call_outputs=tool_call_outputs + ) + def render_tool_calls(self, stream_result: StreamResult): + """ + Renders tool calls by appending the tool calls to a message. + """ + message_id = stream_result.id + tool_call_list = stream_result.tool_call_list + + for tool_call in tool_call_list.resolve(): + id = tool_call.id + index = tool_call.index + type_val = tool_call.type + function = tool_call.function.model_dump_json() + # We have to HTML-escape double quotes in the JSON string. + function = function.replace('"', """) + + self.ychat.update_message(Message( + id=message_id, + body=f'\n\n\n', + sender=self.id, + time=time.time(), + raw_time=False + ), append=True) + + + def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict]): + # TODO + # self.ychat.update_message(Message( + # id=message_id, + # body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n", + # sender=self.id, + # time=time.time(), + # raw_time=False + # ), append=True) + + # Updates the content of the last message directly + message = self.ychat.get_message(message_id) + body = message.body + for output in tool_call_outputs: + if not output['content']: + output['content'] = "" + output = ToolCallOutput(**output) + tool_id = output.tool_call_id + tool_output = output.model_dump_json() + tool_output = tool_output.replace('"', '"') + body = body.replace( + f'; + }; + index: number; + output?: { + tool_call_id: string; + role: string; + name: string; + content: string | null; + }; +}; + +export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { + const [expanded, setExpanded] = useState(false); + console.log({ + output: props.output + }); + const toolComplete = !!(props.output && Object.keys(props.output).length > 0); + const hasOutput = !!(toolComplete && props.output?.content?.length); + + const handleExpandClick = () => { + setExpanded(!expanded); + }; + + const statusIcon: JSX.Element = toolComplete ? ( + + ) : ( + + ); + + const statusText: JSX.Element = ( + + {toolComplete ? 'Ran' : 'Running'}{' '} + + {props.function.name} + {' '} + tool + {toolComplete ? '.' : '...'} + + ); + + const toolArgsJson = useMemo( + () => JSON.stringify(props.function.arguments, null, 2), + [props.function.arguments] + ); + + const toolArgsSection: JSX.Element | null = + toolArgsJson === '{}' ? null : ( + + + Tool arguments + +
+          {toolArgsJson}
+        
+
+ ); + + const toolOutputSection: JSX.Element | null = hasOutput ? ( + + + Tool output + +
{props.output?.content}
+
+ ) : null; + + if (!props.id || !props.type || !props.function) { + return null; + } + + return ( + + + {statusIcon} + {statusText} + + + + + + + + + {toolArgsSection} + {toolOutputSection} + + + + ); +} diff --git a/packages/jupyter-ai/src/web-components/plugin.ts b/packages/jupyter-ai/src/web-components/plugin.ts new file mode 100644 index 000000000..ac0b0674a --- /dev/null +++ b/packages/jupyter-ai/src/web-components/plugin.ts @@ -0,0 +1,65 @@ +import { + JupyterFrontEnd, + JupyterFrontEndPlugin +} from '@jupyterlab/application'; +import r2wc from '@r2wc/react-to-web-component'; + +import { JaiToolCall } from './jai-tool-call'; +import { ISanitizer, Sanitizer } from '@jupyterlab/apputils'; +import { IRenderMime } from '@jupyterlab/rendermime'; + +/** + * Plugin that registers custom web components for usage in AI responses. + */ +export const webComponentsPlugin: JupyterFrontEndPlugin = + { + id: '@jupyter-ai/core:web-components', + autoStart: true, + provides: ISanitizer, + activate: (app: JupyterFrontEnd) => { + // Define the JaiToolCall web component + // ['id', 'type', 'function', 'index', 'output'] + const JaiToolCallWebComponent = r2wc(JaiToolCall, { + props: { + id: 'string', + type: 'string', + function: 'json', + index: 'number', + output: 'json' + } + }); + + // Register the web component + customElements.define('jai-tool-call', JaiToolCallWebComponent); + console.log("Registered custom 'jai-tool-call' web component."); + + // Finally, override the default Rendermime sanitizer to allow custom web + // components in the output. + class CustomSanitizer + extends Sanitizer + implements IRenderMime.ISanitizer + { + sanitize( + dirty: string, + customOptions: IRenderMime.ISanitizerOptions + ): string { + const options: IRenderMime.ISanitizerOptions = { + // default sanitizer options + ...(this as any)._options, + // custom sanitizer options (variable per call) + ...customOptions + }; + + return super.sanitize(dirty, { + ...options, + allowedTags: [...(options?.allowedTags ?? []), 'jai-tool-call'], + allowedAttributes: { + ...options?.allowedAttributes, + 'jai-tool-call': ['id', 'type', 'function', 'index', 'output'] + } + }); + } + } + return new CustomSanitizer(); + } + }; diff --git a/yarn.lock b/yarn.lock index 173eef6b3..d2ff65f5a 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2256,6 +2256,7 @@ __metadata: "@lumino/widgets": ^2.3.2 "@mui/icons-material": ^5.11.0 "@mui/material": ^5.11.0 + "@r2wc/react-to-web-component": ^2.0.4 "@stylistic/eslint-plugin": ^3.0.1 "@types/jest": ^29 "@types/react-dom": ^18.2.0 @@ -4826,6 +4827,25 @@ __metadata: languageName: node linkType: hard +"@r2wc/core@npm:^1.0.0": + version: 1.2.0 + resolution: "@r2wc/core@npm:1.2.0" + checksum: e0dc23e8fd1f0d96193b67f5eb04b74b25b9f4609778e6ea2427c565eb590f458553cad307a2fdb3fc4614f6a576d7701b9bacf11775958bc560cc3b3b5aaae7 + languageName: node + linkType: hard + +"@r2wc/react-to-web-component@npm:^2.0.4": + version: 2.0.4 + resolution: "@r2wc/react-to-web-component@npm:2.0.4" + dependencies: + "@r2wc/core": ^1.0.0 + peerDependencies: + react: ^18.0.0 || ^19.0.0 + react-dom: ^18.0.0 || ^19.0.0 + checksum: 7b140ffd612173a30d74717d18efcf554774ef0ed0fe72f207ec21df707685ef5f4c34521e6840041665550c6461171dc32f12835f35beb1788ccac0c66c0e5c + languageName: node + linkType: hard + "@rjsf/core@npm:^5.13.4": version: 5.17.0 resolution: "@rjsf/core@npm:5.17.0" From 5223fa644b4c6d7eab78f341b9443cada9e36273 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Sat, 30 Aug 2025 15:23:27 +0200 Subject: [PATCH 08/16] move litellm_utils => litellm_lib --- .../{litellm_utils => litellm_lib}/__init__.py | 4 ++-- .../{litellm_utils => litellm_lib}/run_tools.py | 8 -------- .../test_toolcall_list.py | 0 .../toolcall_list.py | 15 +++++++++------ .../streaming_utils.py => litellm_lib/types.py} | 8 ++++++++ .../jupyter_ai/personas/base_persona.py | 5 +---- .../personas/jupyternaut/jupyternaut.py | 13 +------------ 7 files changed, 21 insertions(+), 32 deletions(-) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/__init__.py (63%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/run_tools.py (90%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/test_toolcall_list.py (100%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils => litellm_lib}/toolcall_list.py (92%) rename packages/jupyter-ai/jupyter_ai/{litellm_utils/streaming_utils.py => litellm_lib/types.py} (64%) diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py similarity index 63% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py index ff1c7d8c3..edc2e51cc 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/__init__.py @@ -1,3 +1,3 @@ -from .toolcall_list import * -from .streaming_utils import * from .run_tools import * +from .toolcall_list import * +from .types import * diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py similarity index 90% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py index 2b148b505..0bb815d20 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/run_tools.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py @@ -1,16 +1,8 @@ import asyncio -from pydantic import BaseModel from .toolcall_list import ToolCallList from ..tools import Toolkit -class ToolCallOutput(BaseModel): - tool_call_id: str - role: str = "tool" - name: str - content: str - - async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict]: """ Runs the tools specified in the list of tool calls returned by diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/test_toolcall_list.py similarity index 100% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/test_toolcall_list.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/test_toolcall_list.py diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py similarity index 92% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py index e7094e4f9..311e20a6a 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py @@ -29,10 +29,9 @@ class ResolvedToolCall(BaseModel): `litellm.utils.ChatCompletionDeltaToolCall`. """ - id: str | None + id: str """ - The ID of the tool call. This should always be provided by LiteLLM, this - type is left optional as we do not use this attribute. + The ID of the tool call. """ type: str @@ -62,7 +61,7 @@ class ToolCallList(BaseModel): is used to aggregate the tool call deltas yielded from a LiteLLM response stream and produce a list of tool calls. - After all tool call deltas are added, the `process()` method may be called + After all tool call deltas are added, the `resolve()` method may be called to return a list of resolved tool calls. Example usage: @@ -75,7 +74,7 @@ class ToolCallList(BaseModel): tool_call_delta = chunk.choices[0].delta.tool_calls tool_call_list += tool_call_delta - tool_call_list.resolve() + tool_calls = tool_call_list.resolve() ``` """ @@ -128,7 +127,11 @@ def __add__(self, other: list[ChatCompletionDeltaToolCall] | None) -> 'ToolCallL def resolve(self) -> list[ResolvedToolCall]: """ - Resolve the aggregated tool call delta lists into a list of tool calls. + Returns the aggregated tool calls as `list[ResolvedToolCall]`. + + Raises an exception if any function arguments could not be parsed from + JSON into a dictionary. This method should only be called after the + stream completed without errors. """ resolved_toolcalls: list[ResolvedToolCall] = [] for i, raw_toolcall in enumerate(self._aggregate): diff --git a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py similarity index 64% rename from packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py rename to packages/jupyter-ai/jupyter_ai/litellm_lib/types.py index 7251c88ed..b901711c3 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_utils/streaming_utils.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py @@ -1,3 +1,4 @@ +from __future__ import annotations from pydantic import BaseModel from .toolcall_list import ToolCallList @@ -11,3 +12,10 @@ class StreamResult(BaseModel): """ Tool calls requested by the LLM in its streamed response. """ + +class ToolCallOutput(BaseModel): + tool_call_id: str + role: str = "tool" + name: str + content: str + diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index a6abef496..99763f409 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -17,15 +17,12 @@ from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_utils import ToolCallList, StreamResult, run_tools, ToolCallOutput - -# Import toolkits +from ..litellm_lib import ToolCallList, StreamResult, run_tools from ..tools.default_toolkit import DEFAULT_TOOLKIT if TYPE_CHECKING: from collections.abc import AsyncIterator from .persona_manager import PersonaManager - from ..tools import Toolkit class PersonaDefaults(BaseModel): """ diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 1e9669c6d..7a6261a67 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,11 +1,10 @@ from typing import Any, Optional import time -import json from jupyterlab_chat.models import Message from litellm import acompletion -from ...litellm_utils import StreamResult, ToolCallOutput +from ...litellm_lib import StreamResult, ToolCallOutput from ..base_persona import BasePersona, PersonaDefaults from ..persona_manager import SYSTEM_USERNAME from .prompt_template import ( @@ -110,15 +109,6 @@ def render_tool_calls(self, stream_result: StreamResult): def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict]): - # TODO - # self.ychat.update_message(Message( - # id=message_id, - # body=f"\n\n```\n{json.dumps(tool_call_outputs, indent=2)}\n```\n", - # sender=self.id, - # time=time.time(), - # raw_time=False - # ), append=True) - # Updates the content of the last message directly message = self.ychat.get_message(message_id) body = message.body @@ -134,7 +124,6 @@ def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict f' Date: Thu, 18 Sep 2025 10:22:44 -0700 Subject: [PATCH 09/16] migrate all agent logic to module using pocketflow --- .../jupyter_ai/default_flow/__init__.py | 1 + .../jupyter_ai/default_flow/default_flow.py | 335 ++++++++++++++++++ .../jupyter_ai/litellm_lib/run_tools.py | 22 +- .../jupyter_ai/litellm_lib/toolcall_list.py | 81 ++++- .../jupyter_ai/litellm_lib/types.py | 42 ++- .../jupyter_ai/personas/__init__.py | 4 +- .../jupyter_ai/personas/base_persona.py | 189 +--------- .../personas/jupyternaut/jupyternaut.py | 170 ++------- .../jupyter-ai/jupyter_ai/tools/models.py | 22 ++ 9 files changed, 509 insertions(+), 357 deletions(-) create mode 100644 packages/jupyter-ai/jupyter_ai/default_flow/__init__.py create mode 100644 packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py b/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py new file mode 100644 index 000000000..849952ada --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/default_flow/__init__.py @@ -0,0 +1 @@ +from .default_flow import * \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py new file mode 100644 index 000000000..9fbdda70d --- /dev/null +++ b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py @@ -0,0 +1,335 @@ +from pocketflow import AsyncNode, AsyncFlow +from jupyterlab_chat.models import Message, NewMessage +from jupyterlab_chat.ychat import YChat +from typing import Any, Optional, Tuple, TypedDict +from jinja2 import Template +from litellm import acompletion, ModelResponseStream +import time +import logging + +from ..litellm_lib import ToolCallList, run_tools, LitellmToolCallOutput +from ..tools import Toolkit +from ..personas import SYSTEM_USERNAME, PersonaAwareness + +DEFAULT_RESPONSE_TEMPLATE = """ +{{ content }} +{{ tool_call_ui_elements }} +""".strip() + +class DefaultFlowParams(TypedDict): + """ + Parameters expected by the default flow provided by Jupyter AI. + """ + + model_id: str + + ychat: YChat + + awareness: PersonaAwareness + + persona_id: str + + logger: logging.Logger + + model_args: dict[str, Any] | None + """ + Custom keyword arguments forwarded to `litellm.acompletion()`. Defaults to + `{}` if unset. + """ + + system_prompt: Optional[str] + """ + System prompt that will be used as the first message in the list of messages + sent to the language model. Unused if unset. + """ + + response_template: Template | None + """ + Jinja2 template used to template the response. If one is not given, + `DEFAULT_RESPONSE_TEMPLATE` is used. + + It should take `content: str` and `tool_call_ui_elements: str` as format arguments. + """ + + toolkit: Toolkit | None + """ + Toolkit of tools. Unused if unset. + """ + + history_size: int | None + """ + Number of messages preceding the message triggering this flow to include + in the prompt as context. Defaults to 2 if unset. + """ + +class JaiAsyncNode(AsyncNode): + """ + An AsyncNode with custom properties & helper methods used exclusively in the + Jupyter AI extension. + """ + + @property + def model_id(self) -> str: + return self.params["model_id"] + + @property + def ychat(self) -> YChat: + return self.params["ychat"] + + @property + def awareness(self) -> PersonaAwareness: + return self.params["awareness"] + + @property + def persona_id(self) -> str: + return self.params["persona_id"] + + @property + def model_args(self) -> dict[str, Any]: + return self.params.get("model_args", {}) + + @property + def system_prompt(self) -> Optional[str]: + return self.params.get("system_prompt") + + @property + def response_template(self) -> Template: + template = self.params.get("response_template") + # If response template was unspecified, use the default response + # template. + if not template: + template = Template(DEFAULT_RESPONSE_TEMPLATE) + + return template + + @property + def toolkit(self) -> Optional[Toolkit]: + return self.params.get("toolkit") + + @property + def history_size(self) -> int: + return self.params.get("history_size", 2) + + @property + def log(self) -> logging.Logger: + return self.params.get("logger") + + +class RootNode(JaiAsyncNode): + """ + The root node of the default flow provided by Jupyter AI. + """ + + async def prep_async(self, shared): + self.log.info("Running RootNode.prep_async()") + # Initialize `shared.litellm_messages` using the YChat message history + # if it is unset. + if not ('litellm_messages' in shared and isinstance(shared['litellm_messages'], list) and len(shared['litellm_messages']) > 0): + shared['litellm_messages'] = self._init_litellm_messages() + + # Return `shared.litellm_messages`. This is passed as the `prep_res` + # argument to `exec_async()`. + return shared['litellm_messages'] + + + def _init_litellm_messages(self) -> list[dict]: + # Store the invoking message & the previous `params.history_size` messages + # as `ychat_messages`. + # TODO: ensure the invoking message is in this list + all_messages = self.ychat.get_messages() + ychat_messages: list[Message] = all_messages[-self.history_size - 1:] + + # Coerce each `Message` in `ychat_messages` to a dictionary following + # the OpenAI spec, and store it as `litellm_messages`. + litellm_messages: list[dict[str, Any]] = [] + for msg in ychat_messages: + role = ( + "assistant" + if msg.sender.startswith("jupyter-ai-personas::") + else "system" if msg.sender == SYSTEM_USERNAME else "user" + ) + litellm_messages.append({"role": role, "content": msg.body}) + + # Insert system message as a dictionary if present. + if self.system_prompt: + system_litellm_message = { + "role": "system", + "content": self.system_prompt + } + litellm_messages = [system_litellm_message, *litellm_messages] + + # Return `litellm_messages` + return litellm_messages + + + async def exec_async(self, prep_res: list[dict]): + self.log.info("Running RootNode.exec_async()") + # Gather arguments and start a reply stream via LiteLLM + reply_stream = await acompletion( + **self.model_args, + model=self.model_id, + messages=prep_res, + tools=self.toolkit.to_json(), + stream=True, + ) + + # Iterate over reply stream + content = "" + tool_calls = ToolCallList() + stream_id: str | None = None + async for chunk in reply_stream: + assert isinstance(chunk, ModelResponseStream) + delta = chunk.choices[0].delta + content_delta = delta.content + toolcalls_delta = delta.tool_calls + + # Continue early if an empty chunk was emitted. + # This sometimes happens with LiteLLM. + if not (content_delta or toolcalls_delta): + continue + + # Aggregate the content and tool calls from the deltas + if content_delta: + content += content_delta + if toolcalls_delta: + tool_calls += toolcalls_delta + + # Create a new message if one does not yet exist + if not stream_id: + stream_id = self.ychat.add_message(NewMessage( + sender=self.persona_id, + body="" + )) + assert stream_id + + # Update the reply + message_body = self.response_template.render({ + "content": content, + "tool_call_ui_elements": tool_calls.render() + }) + self.log.error(message_body) + self.ychat.update_message( + Message( + id=stream_id, + body=message_body, + time=time.time(), + sender=self.persona_id, + raw_time=False, + ) + ) + + # Return message_id, content, and tool calls + return stream_id, content, tool_calls + + async def post_async(self, shared, prep_res, exec_res: Tuple[str, str, ToolCallList]): + self.log.info("Running RootNode.post_async()") + # Assert that `shared['litellm_messages']` is of the correct type, and + # that any tool calls returned are complete. + message_id, content, tool_calls = exec_res + assert 'litellm_messages' in shared and isinstance(shared['litellm_messages'], list) + assert tool_calls.complete + + # Add AI response to `shared['litellm_messages']`, including tool calls + new_litellm_message = { + "role": "assistant", + "content": content + } + if len(tool_calls): + new_litellm_message['tool_calls'] = tool_calls.as_litellm_tool_calls() + shared['litellm_messages'].append(new_litellm_message) + + # Add message ID to `shared['prev_message_id']` + shared['prev_message_id'] = message_id + + # Add message content to `shared['prev_message_content]` + shared['prev_message_content'] = content + + # Add tool calls to `shared['next_tool_calls']` + shared['next_tool_calls'] = tool_calls + + # Trigger `ToolExecutorNode` if tools were called. + if len(tool_calls): + return "execute-tools" + return 'finish' + +class ToolExecutorNode(JaiAsyncNode): + """ + Node responsible for executing tool calls in the default flow. + """ + + + async def prep_async(self, shared): + self.log.info("Running ToolExecutorNode.prep_async()") + # Extract `shared['next_tool_calls']` and the ID of the last message + assert 'next_tool_calls' in shared and isinstance(shared['next_tool_calls'], ToolCallList) + assert 'prev_message_id' in shared and isinstance(shared['prev_message_id'], str) + + # Return list of tool calls as a list of dictionaries + return shared['prev_message_id'], shared['next_tool_calls'] + + async def exec_async(self, prep_res: Tuple[str, ToolCallList]) -> list[LitellmToolCallOutput]: + self.log.info("Running ToolExecutorNode.exec_async()") + message_id, tool_calls = prep_res + + # TODO: Run 1 tool at a time? + outputs = await run_tools(tool_calls, self.toolkit) + + for output in outputs: + self.log.error(output) + return outputs + + async def post_async(self, shared, prep_res: Tuple[str, ToolCallList], exec_res: list[LitellmToolCallOutput]): + self.log.info("Running ToolExecutorNode.post_async()") + + # Update last message to include outputs + prev_message_id = shared['prev_message_id'] + prev_message_content = shared['prev_message_content'] + tool_calls: ToolCallList = shared['next_tool_calls'] + message_body = self.response_template.render({ + "content": prev_message_content, + "tool_call_ui_elements": tool_calls.render( + outputs=exec_res + ) + }) + self.ychat.update_message( + Message( + id=prev_message_id, + body=message_body, + time=time.time(), + sender=self.persona_id, + raw_time=False, + ) + ) + self.log.error(message_body) + + # Add tool outputs to `shared['litellm_messages']` + shared['litellm_messages'].extend(exec_res) + for msg in shared['litellm_messages']: + self.log.error(msg) + + # Delete shared state that is now stale + del shared['prev_message_id'] + del shared['prev_message_content'] + del shared['next_tool_calls'] + # This node will automatically return to `RootNode` after execution. + +async def run_default_flow(params: DefaultFlowParams): + # Initialize nodes + root_node = RootNode() + tool_executor_node = ToolExecutorNode() + + # Define state transitions + ## Flow to ToolExecutorNode if tool calls were dispatched + root_node - "execute-tools" >> tool_executor_node + ## Always flow back to RootNode after running tools + tool_executor_node >> root_node + ## End the flow if no tool calls were dispatched + root_node - "finish" >> AsyncNode() + + # Initialize flow and set its parameters + flow = AsyncFlow(start=root_node) + flow.set_params(params) + + # Finally, run the async node + await flow.run_async({}) + diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py index 0bb815d20..6ccb22d71 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/run_tools.py @@ -1,21 +1,29 @@ +from __future__ import annotations +from typing import TYPE_CHECKING import asyncio -from .toolcall_list import ToolCallList -from ..tools import Toolkit +if TYPE_CHECKING: + from ..tools import Toolkit + from .toolcall_list import ToolCallList + from .types import LitellmToolCallOutput -async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict]: + +async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[LitellmToolCallOutput]: """ Runs the tools specified in the list of tool calls returned by `self.stream_message()`. - Returns `list[ToolCallOutput]`. The outputs should be appended directly to - the message history on the next request made to the LLM. + Returns `list[LitellmToolCallOutput]`, a list of output dictionaries of the + type expected by LiteLLM. + + Each output in the list should be appended directly to the message history + on the next request made to the LLM. """ tool_calls = tool_call_list.resolve() if not len(tool_calls): return [] - tool_outputs: list[dict] = [] + tool_outputs: list[LitellmToolCallOutput] = [] for tool_call in tool_calls: # Get tool definition from the correct toolkit # TODO: validation? @@ -31,7 +39,7 @@ async def run_tools(tool_call_list: ToolCallList, toolkit: Toolkit) -> list[dict output = str(e) # Store the tool output in a dictionary accepted by LiteLLM - output_dict = { + output_dict: LitellmToolCallOutput = { "tool_call_id": tool_call.id, "role": "tool", "name": tool_call.function.name, diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py index 311e20a6a..563f06746 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py @@ -2,6 +2,8 @@ import json from pydantic import BaseModel from typing import Any +from .types import LitellmToolCall, LitellmToolCallOutput, JaiToolCallProps +from jinja2 import Template class ResolvedFunction(BaseModel): """ @@ -54,6 +56,13 @@ class ResolvedToolCall(BaseModel): This is usually 0 unless the LLM supports parallel tool calling. """ +JAI_TOOL_CALL_TEMPLATE = Template(""" +{% for props in props_list %} + + +{% endfor %} +""".strip()) + class ToolCallList(BaseModel): """ A helper object that defines a custom `__iadd__()` method which accepts a @@ -165,16 +174,80 @@ def resolve(self) -> list[ResolvedToolCall]: return resolved_toolcalls - def to_json(self) -> list[dict[str, Any]]: + @property + def complete(self) -> bool: + for i, tool_call in enumerate(self._aggregate): + if tool_call.index != i: + return False + if not tool_call.function: + return False + if not tool_call.function.name: + return False + if not tool_call.type: + return False + if not tool_call.function.arguments: + return False + try: + json.loads(tool_call.function.arguments) + except Exception: + return False + + return True + + def as_litellm_tool_calls(self) -> list[LitellmToolCall]: """ - Returns the list of tool calls as a Python dictionary that can be - JSON-serialized. + Returns the current list of tool calls as a list of dictionaries. + + This should be set in the `tool_calls` key in the dictionary of the + LiteLLM assistant message responsible for dispatching these tool calls. """ return [ model.model_dump() for model in self._aggregate ] - + def render(self, outputs: list[LitellmToolCallOutput] | None = None) -> str: + """ + Renders this tool call list as a list of `` elements to + be shown in the chat. + """ + # Initialize list of props to render into tool call UI elements + props_list: list[JaiToolCallProps] = [] + + # Index all outputs if passed + outputs_by_id: dict[str, LitellmToolCallOutput] | None = None + if outputs: + outputs_by_id = {} + for output in outputs: + outputs_by_id[output['tool_call_id']] = output + + for tool_call in self._aggregate: + # Build the props for each tool call UI element + props: JaiToolCallProps = { + 'id': tool_call.id, + 'index': tool_call.index, + 'type': tool_call.type, + 'function_name': tool_call.function.name, + 'function_args': tool_call.function.arguments, + } + + # Add the output if present + if outputs_by_id and tool_call.id in outputs_by_id: + output = outputs_by_id[tool_call.id] + # Make sure to manually convert the dictionary to a JSON string + # first. Without doing this, Jinja2 will convert a dictionary to + # JSON using single quotes instead of double quotes, which + # cannot be parsed by the frontend. + output = json.dumps(output) + props['output'] = output + + props_list.append(props) + + # Render the tool call UI elements using the Jinja2 template and return + return JAI_TOOL_CALL_TEMPLATE.render({ + "props_list": props_list + }) + + def __len__(self) -> int: return len(self._aggregate) \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py index b901711c3..314c54990 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py @@ -1,21 +1,39 @@ from __future__ import annotations -from pydantic import BaseModel -from .toolcall_list import ToolCallList +from typing import TypedDict, Literal, Optional -class StreamResult(BaseModel): + +class LitellmToolCall(TypedDict): id: str - """ - ID of the new message. - """ + type: Literal['function'] + function: str + index: int - tool_call_list: ToolCallList - """ - Tool calls requested by the LLM in its streamed response. - """ +class LitellmMessage(TypedDict): + role: Literal['assistant', 'user', 'system'] + content: str + tool_calls: Optional[list[LitellmToolCall]] -class ToolCallOutput(BaseModel): +class LitellmToolCallOutput(TypedDict): tool_call_id: str - role: str = "tool" + role: Literal['tool'] name: str content: str +class JaiToolCallProps(TypedDict): + id: str | None + + type: Literal['function'] | None + + index: int | None + + function_name: str | None + + function_args: str | None + """ + The arguments to the function as a dictionary converted to a JSON string. + """ + + output: str | None + """ + The `LitellmToolCallOutput` as a JSON string. + """ \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/personas/__init__.py b/packages/jupyter-ai/jupyter_ai/personas/__init__.py index 6c0704f52..fb8dd1bfe 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/personas/__init__.py @@ -1,2 +1,2 @@ -from .base_persona import BasePersona, PersonaDefaults -from .persona_manager import PersonaManager +from .base_persona import * +from .persona_manager import * diff --git a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py index 99763f409..40fdb14db 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/base_persona.py +++ b/packages/jupyter-ai/jupyter_ai/personas/base_persona.py @@ -4,24 +4,22 @@ from abc import ABC, ABCMeta, abstractmethod from dataclasses import asdict from logging import Logger -from time import time from typing import TYPE_CHECKING, Any, Optional from jupyter_ai.config_manager import ConfigManager from jupyterlab_chat.models import Message, NewMessage, User from jupyterlab_chat.ychat import YChat -from litellm import ModelResponseStream, supports_function_calling +from litellm import supports_function_calling from litellm.utils import function_to_dict from pydantic import BaseModel from traitlets import MetaHasTraits from traitlets.config import LoggingConfigurable from .persona_awareness import PersonaAwareness -from ..litellm_lib import ToolCallList, StreamResult, run_tools +from ..litellm_lib import ToolCallList, run_tools from ..tools.default_toolkit import DEFAULT_TOOLKIT if TYPE_CHECKING: - from collections.abc import AsyncIterator from .persona_manager import PersonaManager class PersonaDefaults(BaseModel): @@ -234,127 +232,6 @@ def as_user_dict(self) -> dict[str, Any]: user = self.as_user() return asdict(user) - async def stream_message( - self, reply_stream: "AsyncIterator[ModelResponseStream | str]" - ) -> StreamResult: - """ - Takes an async iterator, dubbed the 'reply stream', and streams it to a - new message by this persona in the YChat. The async iterator may yield - either strings or `litellm.ModelResponseStream` objects. Details: - - - Creates a new message upon receiving the first chunk from the reply - stream, then continuously updates it until the stream is closed. - - - Automatically manages its awareness state to show writing status. - - Returns a list of `ResolvedToolCall` objects. If this list is not empty, - the persona should run these tools. - """ - stream_id: Optional[str] = None - stream_interrupted = False - tool_call_list = ToolCallList() - try: - self.awareness.set_local_state_field("isWriting", True) - - async for chunk in reply_stream: - # Start the stream with an empty message on the initial reply. - # Bind the new message ID to `stream_id`. - if not stream_id: - stream_id = self.ychat.add_message( - NewMessage(body="", sender=self.id) - ) - self.message_interrupted[stream_id] = asyncio.Event() - self.awareness.set_local_state_field("isWriting", stream_id) - assert stream_id - - # Compute `content_delta` and `tool_calls_delta` based on the - # type of object yielded by `reply_stream`. - if isinstance(chunk, ModelResponseStream): - delta = chunk.choices[0].delta - content_delta = delta.content - toolcalls_delta = delta.tool_calls - elif isinstance(chunk, str): - content_delta = chunk - toolcalls_delta = None - else: - raise Exception(f"Unrecognized type in stream_message(): {type(chunk)}") - - # LiteLLM streams always terminate with an empty chunk, so - # continue in this case. - if not (content_delta or toolcalls_delta): - continue - - # Terminate the stream if the user requested it. - if ( - stream_id - and stream_id in self.message_interrupted.keys() - and self.message_interrupted[stream_id].is_set() - ): - try: - # notify the model provider that streaming was interrupted - # (this is essential to allow the model to stop generating) - await reply_stream.athrow( # type:ignore[attr-defined] - GenerationInterrupted() - ) - except GenerationInterrupted: - # do not let the exception bubble up in case if - # the provider did not handle it - pass - stream_interrupted = True - break - - # Append `content_delta` to the existing message. - if content_delta: - self.ychat.update_message( - Message( - id=stream_id, - body=content_delta, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) - if toolcalls_delta: - tool_call_list += toolcalls_delta - - except Exception as e: - self.log.error( - f"Persona '{self.name}' encountered an exception printed below when attempting to stream output." - ) - self.log.exception(e) - finally: - # Reset local state - self.awareness.set_local_state_field("isWriting", False) - self.message_interrupted.pop(stream_id, None) - - # If stream was interrupted, add a tombstone and return `[]`, - # indicating that no tools should be run afterwards. - if stream_id and stream_interrupted: - stream_tombstone = "\n\n(AI response stopped by user)" - self.ychat.update_message( - Message( - id=stream_id, - body=stream_tombstone, - time=time(), - sender=self.id, - raw_time=False, - ), - append=True, - ) - return None - - # TODO: determine where this should live - count = len(tool_call_list) - if count > 0: - self.log.info(f"AI response triggered {count} tool calls.") - - return StreamResult( - id=stream_id, - tool_call_list=tool_call_list - ) - - def send_message(self, body: str) -> None: """ Sends a new message to the chat from this persona. @@ -464,68 +341,6 @@ def resolve_attachment_to_path(self, attachment_id: str) -> Optional[str]: self.log.error(f"Failed to resolve attachment {attachment_id}: {e}") return None - def get_tools(self, model_id: str) -> list[dict]: - """ - Returns the `tools` parameter which should be passed to - `litellm.acompletion()` for a given LiteLLM model ID. - - If the model does not support tool-calling, this method returns an empty - list. Otherwise, it returns the list of tools available in the current - environment. These may include: - - - The default set of tool functions in Jupyter AI, defined in the - the default toolkit from `jupyter_ai.tools`. - - - (TODO) Tools provided by MCP server configuration, if any. - - - (TODO) Web search. - - - (TODO) File search using vector store IDs. - - TODO: cache this - - TODO: Implement some permissions system so users can control what tools - are allowable. - - NOTE: The returned list is expected by LiteLLM to conform to the `tools` - parameter defintiion defined by the OpenAI API: - https://platform.openai.com/docs/guides/tools#available-tools - - NOTE: This API is a WIP and is very likely to change. - """ - # Return early if the model does not support tool calling - if not supports_function_calling(model=model_id): - return [] - - tool_descriptions = [] - - # Get all tools from the default toolkit and store their object descriptions - for tool in DEFAULT_TOOLKIT.get_tools(): - # Here, we are using a util function from LiteLLM to coerce - # each `Tool` struct into a tool description dictionary expected - # by LiteLLM. - desc = { - "type": "function", - "function": function_to_dict(tool.callable), - } - tool_descriptions.append(desc) - - # Finally, return the tool descriptions - self.log.info(tool_descriptions) - return tool_descriptions - - - async def run_tools(self, tool_call_list: ToolCallList) -> list[dict]: - """ - Runs the tools specified in a given tool call list using the default - toolkit. - """ - tool_outputs = await run_tools(tool_call_list, toolkit=DEFAULT_TOOLKIT) - self.log.info(f"Ran {len(tool_outputs)} tool functions.") - return tool_outputs - - - def shutdown(self) -> None: """ Shuts the persona down. This method should: diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index 7a6261a67..c89934a51 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,16 +1,12 @@ -from typing import Any, Optional -import time - from jupyterlab_chat.models import Message -from litellm import acompletion -from ...litellm_lib import StreamResult, ToolCallOutput from ..base_persona import BasePersona, PersonaDefaults -from ..persona_manager import SYSTEM_USERNAME +from ...default_flow import run_default_flow, DefaultFlowParams from .prompt_template import ( JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) +from ...tools import DEFAULT_TOOLKIT class JupyternautPersona(BasePersona): @@ -31,6 +27,7 @@ def defaults(self): ) async def process_message(self, message: Message) -> None: + # Return early if no chat model is configured if not self.config_manager.chat_model: self.send_message( "No chat model is configured.\n\n" @@ -38,145 +35,28 @@ async def process_message(self, message: Message) -> None: ) return - model_id = self.config_manager.chat_model - - # `True` before the first LLM response is sent, `False` afterwards. - initial_response = True - # List of tool call outputs computed in the previous invocation. - tool_call_outputs: list[dict] = [] - - # Initialize list of messages, including history and context - messages: list[dict] = self.get_context_as_messages(model_id, message) - - # Loop until the AI is complete running all its tools. - while initial_response or len(tool_call_outputs): - # Stream message to the chat - response_aiter = await acompletion( - model=model_id, - messages=messages, - tools=self.get_tools(model_id), - stream=True, - ) - result = await self.stream_message(response_aiter) - initial_response = False - - # Append new reply to `messages` - reply = self.ychat.get_message(result.id) - tool_calls_json = result.tool_call_list.to_json() - messages.append({ - "role": "assistant", - "content": reply.body, - "tool_calls": tool_calls_json - }) - - # Render tool calls in new message - if len(result.tool_call_list): - self.render_tool_calls(result) - - # Run tools and append outputs to `messages` - tool_call_outputs = await self.run_tools(result.tool_call_list) - messages.extend(tool_call_outputs) - - # Render tool call outputs in new message - if tool_call_outputs: - self.render_tool_call_outputs( - message_id=result.id, - tool_call_outputs=tool_call_outputs - ) - - def render_tool_calls(self, stream_result: StreamResult): - """ - Renders tool calls by appending the tool calls to a message. - """ - message_id = stream_result.id - tool_call_list = stream_result.tool_call_list - - for tool_call in tool_call_list.resolve(): - id = tool_call.id - index = tool_call.index - type_val = tool_call.type - function = tool_call.function.model_dump_json() - # We have to HTML-escape double quotes in the JSON string. - function = function.replace('"', """) - - self.ychat.update_message(Message( - id=message_id, - body=f'\n\n\n', - sender=self.id, - time=time.time(), - raw_time=False - ), append=True) - - - def render_tool_call_outputs(self, message_id: str, tool_call_outputs: list[dict]): - # Updates the content of the last message directly - message = self.ychat.get_message(message_id) - body = message.body - for output in tool_call_outputs: - if not output['content']: - output['content'] = "" - output = ToolCallOutput(**output) - tool_id = output.tool_call_id - tool_output = output.model_dump_json() - tool_output = tool_output.replace('"', '"') - body = body.replace( - f' list[dict[str, Any]]: - """ - Returns the current context, including attachments and recent messages, - as a list of messages accepted by `litellm.acompletion()`. - """ - system_msg_args = JupyternautSystemPromptArgs( - model_id=model_id, - persona_name=self.name, - context=self.process_attachments(message), - ).model_dump() - - system_msg = { - "role": "system", - "content": JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(**system_msg_args), + # Build default flow params + system_prompt = self._build_system_prompt(message) + flow_params: DefaultFlowParams = { + "persona_id": self.id, + "model_id": self.config_manager.chat_model, + "model_args": self.config_manager.chat_model_args, + "ychat": self.ychat, + "awareness": self.awareness, + "system_prompt": system_prompt, + "toolkit": DEFAULT_TOOLKIT, + "logger": self.log, } - context_as_messages = [system_msg, *self._get_history_as_messages()] - return context_as_messages - - def _get_history_as_messages(self, k: Optional[int] = 2) -> list[dict[str, Any]]: - """ - Returns the current history as a list of messages accepted by - `litellm.acompletion()`. + # Run default agent flow + await run_default_flow(flow_params) - NOTE: You should usually call the public `get_context_as_messages()` - method instead. - """ - # TODO: consider bounding history based on message size (e.g. total - # char/token count) instead of message count. - all_messages = self.ychat.get_messages() - - # gather last k * 2 messages and return - start_idx = 0 if k is None else -2 * k - recent_messages: list[Message] = all_messages[start_idx:] - - history: list[dict[str, Any]] = [] - for msg in recent_messages: - role = ( - "assistant" - if msg.sender.startswith("jupyter-ai-personas::") - else "system" if msg.sender == SYSTEM_USERNAME else "user" - ) - history.append({"role": role, "content": msg.body}) - - return history + def _build_system_prompt(self, message: Message) -> str: + context = self.process_attachments(message) + format_args = JupyternautSystemPromptArgs( + persona_name=self.name, + model_id=self.config_manager.chat_model, + context=context, + ) + system_prompt = JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(format_args.model_dump()) + return system_prompt \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tools/models.py b/packages/jupyter-ai/jupyter_ai/tools/models.py index e547f0c15..72b9c69ef 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/models.py +++ b/packages/jupyter-ai/jupyter_ai/tools/models.py @@ -1,5 +1,6 @@ import re from typing import Callable, Optional +from litellm.utils import function_to_dict from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -215,3 +216,24 @@ def get_tools( toolset.add(tool) return toolset + + def to_json(self) -> list[dict]: + """ + Returns a list of tool descriptions in the type expected by LiteLLM. + """ + tool_descriptions = [] + + # Get all tools from the default toolkit and store their object descriptions + for tool in self.get_tools(): + # Here, we are using a util function from LiteLLM to coerce + # each `Tool` struct into a tool description dictionary expected + # by LiteLLM. + desc = { + "type": "function", + "function": function_to_dict(tool.callable), + } + tool_descriptions.append(desc) + + # Finally, return the tool descriptions + return tool_descriptions + From fd950bd2a91bebae28697805cbdce65e3d11fd99 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 18 Sep 2025 10:23:47 -0700 Subject: [PATCH 10/16] add pocketflow as a dependency --- packages/jupyter-ai/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/jupyter-ai/pyproject.toml b/packages/jupyter-ai/pyproject.toml index d82914103..cdd5b4dd2 100644 --- a/packages/jupyter-ai/pyproject.toml +++ b/packages/jupyter-ai/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ "litellm>=1.73,<2", "jinja2>=3.0,<4", "python_dotenv>=1,<2", + "pocketflow==0.0.3", ] dynamic = ["version", "description", "authors", "urls", "keywords"] From 793fb4e11eba703490518552ffbb089b2417dc55 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 18 Sep 2025 10:24:08 -0700 Subject: [PATCH 11/16] update jai-tool-call web component API to do less JSON parsing --- .../jupyter-ai/src/web-components/index.ts | 2 +- .../src/web-components/jai-tool-call.tsx | 52 +++++++++---------- .../{plugin.ts => web-components-plugin.ts} | 14 ++++- 3 files changed, 38 insertions(+), 30 deletions(-) rename packages/jupyter-ai/src/web-components/{plugin.ts => web-components-plugin.ts} (83%) diff --git a/packages/jupyter-ai/src/web-components/index.ts b/packages/jupyter-ai/src/web-components/index.ts index aea66be0d..5f5e58107 100644 --- a/packages/jupyter-ai/src/web-components/index.ts +++ b/packages/jupyter-ai/src/web-components/index.ts @@ -1,2 +1,2 @@ -export * from './plugin'; +export * from './web-components-plugin'; export * from './jai-tool-call'; diff --git a/packages/jupyter-ai/src/web-components/jai-tool-call.tsx b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx index dc6a52cc5..d9530c620 100644 --- a/packages/jupyter-ai/src/web-components/jai-tool-call.tsx +++ b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx @@ -1,4 +1,4 @@ -import React, { useState, useMemo } from 'react'; +import React, { useState } from 'react'; import { Box, Typography, @@ -10,13 +10,11 @@ import ExpandMore from '@mui/icons-material/ExpandMore'; import CheckCircle from '@mui/icons-material/CheckCircle'; type JaiToolCallProps = { - id: string; - type: string; - function: { - name: string; - arguments: Record; - }; - index: number; + id?: string; + type?: string; + function_name?: string; + function_args?: string; + index?: number; output?: { tool_call_id: string; role: string; @@ -26,10 +24,10 @@ type JaiToolCallProps = { }; export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { - const [expanded, setExpanded] = useState(false); console.log({ - output: props.output + props }); + const [expanded, setExpanded] = useState(false); const toolComplete = !!(props.output && Object.keys(props.output).length > 0); const hasOutput = !!(toolComplete && props.output?.content?.length); @@ -47,29 +45,28 @@ export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { {toolComplete ? 'Ran' : 'Running'}{' '} - {props.function.name} + {props.function_name} {' '} tool {toolComplete ? '.' : '...'} ); - const toolArgsJson = useMemo( - () => JSON.stringify(props.function.arguments, null, 2), - [props.function.arguments] - ); + // const toolArgsJson = useMemo( + // () => JSON.stringify(props?.function_args ?? {}, null, 2), + // [props.function_args] + // ); - const toolArgsSection: JSX.Element | null = - toolArgsJson === '{}' ? null : ( - - - Tool arguments - -
-          {toolArgsJson}
-        
-
- ); + const toolArgsSection: JSX.Element | null = props.function_args ? ( + + + Tool arguments + +
+        {props.function_args}
+      
+
+ ) : null; const toolOutputSection: JSX.Element | null = hasOutput ? ( @@ -80,12 +77,13 @@ export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { ) : null; - if (!props.id || !props.type || !props.function) { + if (!props.id || !props.type || !props.function_name) { return null; } return ( props: { id: 'string', type: 'string', - function: 'json', + function_name: 'string', + // this is deliberately not 'json' since `function_args` may be a + // partial JSON string. + function_args: 'string', index: 'number', output: 'json' } @@ -55,7 +58,14 @@ export const webComponentsPlugin: JupyterFrontEndPlugin allowedTags: [...(options?.allowedTags ?? []), 'jai-tool-call'], allowedAttributes: { ...options?.allowedAttributes, - 'jai-tool-call': ['id', 'type', 'function', 'index', 'output'] + 'jai-tool-call': [ + 'id', + 'type', + 'function_name', + 'function_args', + 'index', + 'output' + ] } }); } From 29f075d4d3b97718470abb40385965e9ada836f2 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 18 Sep 2025 10:41:02 -0700 Subject: [PATCH 12/16] remove debug logs --- packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py index 9fbdda70d..b93980f74 100644 --- a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py +++ b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py @@ -207,7 +207,6 @@ async def exec_async(self, prep_res: list[dict]): "content": content, "tool_call_ui_elements": tool_calls.render() }) - self.log.error(message_body) self.ychat.update_message( Message( id=stream_id, @@ -274,8 +273,6 @@ async def exec_async(self, prep_res: Tuple[str, ToolCallList]) -> list[LitellmTo # TODO: Run 1 tool at a time? outputs = await run_tools(tool_calls, self.toolkit) - for output in outputs: - self.log.error(output) return outputs async def post_async(self, shared, prep_res: Tuple[str, ToolCallList], exec_res: list[LitellmToolCallOutput]): @@ -300,12 +297,9 @@ async def post_async(self, shared, prep_res: Tuple[str, ToolCallList], exec_res: raw_time=False, ) ) - self.log.error(message_body) # Add tool outputs to `shared['litellm_messages']` shared['litellm_messages'].extend(exec_res) - for msg in shared['litellm_messages']: - self.log.error(msg) # Delete shared state that is now stale del shared['prev_message_id'] From 95d2615d45788649e8f484ee93df11fcfdb3124a Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Thu, 18 Sep 2025 10:48:46 -0700 Subject: [PATCH 13/16] show writing indicator while processing request --- .../jupyter-ai/jupyter_ai/default_flow/default_flow.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py index b93980f74..f885a6fbf 100644 --- a/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py +++ b/packages/jupyter-ai/jupyter_ai/default_flow/default_flow.py @@ -121,7 +121,6 @@ class RootNode(JaiAsyncNode): """ async def prep_async(self, shared): - self.log.info("Running RootNode.prep_async()") # Initialize `shared.litellm_messages` using the YChat message history # if it is unset. if not ('litellm_messages' in shared and isinstance(shared['litellm_messages'], list) and len(shared['litellm_messages']) > 0): @@ -325,5 +324,12 @@ async def run_default_flow(params: DefaultFlowParams): flow.set_params(params) # Finally, run the async node - await flow.run_async({}) + try: + params['awareness'].set_local_state_field("isWriting", True) + await flow.run_async({}) + except Exception as e: + # TODO: implement error handling + params['logger'].exception("Exception occurred while running default agent flow:") + finally: + params['awareness'].set_local_state_field("isWriting", False) From 9211dd3cb2e26f46a3cdae3d0e42a400679ca3a9 Mon Sep 17 00:00:00 2001 From: joadoumie Date: Sun, 5 Oct 2025 10:46:09 +0900 Subject: [PATCH 14/16] Add context-aware toolkit with workspace directory binding This commit adds workspace directory awareness to tools (bash and search_grep) by creating a dynamic toolkit that binds the workspace directory to tool calls. Also updates the system prompt to guide the LLM on when to use these tools. Note: This version has a bug where tool call UI is not displaying properly. Saving this commit to preserve the directional changes before reverting. --- .../personas/jupyternaut/jupyternaut.py | 54 +++++++++++++++++-- .../personas/jupyternaut/prompt_template.py | 12 +++++ .../jupyter_ai/tools/default_toolkit.py | 12 +++-- 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index c89934a51..db6434e74 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,3 +1,4 @@ +from typing import Optional from jupyterlab_chat.models import Message from ..base_persona import BasePersona, PersonaDefaults @@ -6,7 +7,8 @@ JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE, JupyternautSystemPromptArgs, ) -from ...tools import DEFAULT_TOOLKIT +from ...tools import DEFAULT_TOOLKIT, Toolkit, Tool +from ...tools.default_toolkit import bash, read, edit, write, search_grep class JupyternautPersona(BasePersona): @@ -37,6 +39,7 @@ async def process_message(self, message: Message) -> None: # Build default flow params system_prompt = self._build_system_prompt(message) + toolkit = self._build_toolkit() flow_params: DefaultFlowParams = { "persona_id": self.id, "model_id": self.config_manager.chat_model, @@ -44,7 +47,7 @@ async def process_message(self, message: Message) -> None: "ychat": self.ychat, "awareness": self.awareness, "system_prompt": system_prompt, - "toolkit": DEFAULT_TOOLKIT, + "toolkit": toolkit, "logger": self.log, } @@ -59,4 +62,49 @@ def _build_system_prompt(self, message: Message) -> str: context=context, ) system_prompt = JUPYTERNAUT_SYSTEM_PROMPT_TEMPLATE.render(format_args.model_dump()) - return system_prompt \ No newline at end of file + return system_prompt + + def _build_toolkit(self) -> Toolkit: + """ + Build a context-aware toolkit with the workspace directory bound to tools. + """ + # Get workspace directory for this chat + workspace_dir = self.get_workspace_dir() + + # Create wrapper functions that bind workspace_dir + # We can't use functools.partial because litellm.function_to_dict expects __name__ + async def bash(command: str, timeout: Optional[int] = None) -> str: + """Executes a bash command and returns the result + + Args: + command: The bash command to execute + timeout: Optional timeout in seconds + + Returns: + The command output (stdout and stderr combined) + """ + from ...tools.default_toolkit import bash as bash_orig + return await bash_orig(command, timeout=timeout, cwd=workspace_dir) + + async def search_grep(pattern: str, include: str = "*") -> str: + """Search for text patterns in files using ripgrep. + + Args: + pattern: A regular expression pattern to search for + include: A glob pattern to filter which files to search + + Returns: + The raw output from ripgrep, including file paths, line numbers, and matching lines + """ + from ...tools.default_toolkit import search_grep as search_grep_orig + return await search_grep_orig(pattern, include=include, cwd=workspace_dir) + + # Create toolkit with workspace-aware tools + toolkit = Toolkit(name="jupyter-ai-contextual-toolkit") + toolkit.add_tool(Tool(callable=bash)) + toolkit.add_tool(Tool(callable=search_grep)) + toolkit.add_tool(Tool(callable=read)) + toolkit.add_tool(Tool(callable=edit)) + toolkit.add_tool(Tool(callable=write)) + + return toolkit \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py index 05cb7b956..8683c0442 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/prompt_template.py @@ -39,6 +39,18 @@ {% if context %}The user has shared the following context: {{context}} + +You have access to tools that can help you work with this context: +- `read(file_path, offset, limit)`: Read specific lines from a file +- `search_grep(pattern, include)`: Search for text patterns across files +- `bash(command)`: Execute bash commands to interact with files and the system + +Use these tools strategically based on the user's request. For example: +- If asked about file contents, use `read()` to examine specific portions +- If searching for specific code or text, use `search_grep()` +- For complex operations, use `bash()` commands + +File paths in the context are relative to the workspace directory. {% else %}The user did not share any additional context.{% endif %} """.strip() diff --git a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py index d850775af..80a1735a3 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py +++ b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py @@ -165,7 +165,7 @@ def write(file_path: str, content: str) -> None: path.write_text(content, encoding="utf-8") -async def search_grep(pattern: str, include: str = "*") -> str: +async def search_grep(pattern: str, include: str = "*", cwd: Optional[str] = None) -> str: """ Search for text patterns in files using ripgrep. @@ -200,6 +200,8 @@ async def search_grep(pattern: str, include: str = "*") -> str: - {a,b} matches either "a" or "b" - ! at start negates the pattern Examples: "*.py", "**/*.js", "src/**/*.{ts,tsx}", "!*.test.*" + cwd : str, optional + The directory to search in. Defaults to current working directory. Returns ------- @@ -241,20 +243,21 @@ async def search_grep(pattern: str, include: str = "*") -> str: # Join command with proper shell escaping command = " ".join(f'"{part}"' if " " in part or any(c in part for c in "!*?[]{}()") else part for part in cmd_parts) - + try: - result = await bash(command) + result = await bash(command, cwd=cwd) return result except Exception as e: raise RuntimeError(f"Ripgrep search failed: {str(e)}") from e -async def bash(command: str, timeout: Optional[int] = None) -> str: +async def bash(command: str, timeout: Optional[int] = None, cwd: Optional[str] = None) -> str: """Executes a bash command and returns the result Args: command: The bash command to execute timeout: Optional timeout in seconds + cwd: Optional working directory to execute the command in Returns: The command output (stdout and stderr combined) @@ -267,6 +270,7 @@ async def bash(command: str, timeout: Optional[int] = None) -> str: *shlex.split(command), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, + cwd=cwd, ) try: From c1468ff0b0ed4e5b3f3515261efb377cffb14d40 Mon Sep 17 00:00:00 2001 From: joadoumie Date: Sun, 5 Oct 2025 11:03:53 +0900 Subject: [PATCH 15/16] Fix tool call UI by renaming id to tool_id The id attribute is a reserved HTML attribute and was not being properly passed to the React component when used in web components. This caused the tool call UI to not display because the component returned null when props.id was undefined. Changed all instances of id to tool_id in: - toolcall_list.py: Changed prop key from "id" to "tool_id" - types.py: Updated JaiToolCallProps type definition - web-components-plugin.ts: Updated prop declaration and allowed attributes - jai-tool-call.tsx: Updated TypeScript interface and usage --- packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py | 2 +- packages/jupyter-ai/jupyter_ai/litellm_lib/types.py | 2 +- packages/jupyter-ai/src/web-components/jai-tool-call.tsx | 6 +++--- .../jupyter-ai/src/web-components/web-components-plugin.ts | 6 +++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py index 563f06746..4fbc4b90a 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/toolcall_list.py @@ -223,7 +223,7 @@ def render(self, outputs: list[LitellmToolCallOutput] | None = None) -> str: for tool_call in self._aggregate: # Build the props for each tool call UI element props: JaiToolCallProps = { - 'id': tool_call.id, + 'tool_id': tool_call.id, 'index': tool_call.index, 'type': tool_call.type, 'function_name': tool_call.function.name, diff --git a/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py index 314c54990..c207d8873 100644 --- a/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py +++ b/packages/jupyter-ai/jupyter_ai/litellm_lib/types.py @@ -20,7 +20,7 @@ class LitellmToolCallOutput(TypedDict): content: str class JaiToolCallProps(TypedDict): - id: str | None + tool_id: str | None type: Literal['function'] | None diff --git a/packages/jupyter-ai/src/web-components/jai-tool-call.tsx b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx index d9530c620..465291761 100644 --- a/packages/jupyter-ai/src/web-components/jai-tool-call.tsx +++ b/packages/jupyter-ai/src/web-components/jai-tool-call.tsx @@ -10,7 +10,7 @@ import ExpandMore from '@mui/icons-material/ExpandMore'; import CheckCircle from '@mui/icons-material/CheckCircle'; type JaiToolCallProps = { - id?: string; + tool_id?: string; type?: string; function_name?: string; function_args?: string; @@ -77,13 +77,13 @@ export function JaiToolCall(props: JaiToolCallProps): JSX.Element | null { ) : null; - if (!props.id || !props.type || !props.function_name) { + if (!props.tool_id || !props.type || !props.function_name) { return null; } return ( provides: ISanitizer, activate: (app: JupyterFrontEnd) => { // Define the JaiToolCall web component - // ['id', 'type', 'function', 'index', 'output'] + // ['tool_id', 'type', 'function', 'index', 'output'] const JaiToolCallWebComponent = r2wc(JaiToolCall, { props: { - id: 'string', + tool_id: 'string', type: 'string', function_name: 'string', // this is deliberately not 'json' since `function_args` may be a @@ -59,7 +59,7 @@ export const webComponentsPlugin: JupyterFrontEndPlugin allowedAttributes: { ...options?.allowedAttributes, 'jai-tool-call': [ - 'id', + 'tool_id', 'type', 'function_name', 'function_args', From f51c8472e1c65cc0d06da7fea2956502384b852c Mon Sep 17 00:00:00 2001 From: joadoumie Date: Sun, 5 Oct 2025 11:15:01 +0900 Subject: [PATCH 16/16] Add workspace directory context to file tools Added cwd parameter to read, edit, and write tools to support relative file paths. The tools now resolve relative paths against the workspace directory, allowing the LLM to reference files without absolute paths. Changes: - Added cwd parameter to read(), edit(), and write() in default_toolkit.py - Updated _build_toolkit() to use functools.partial with a helper function - Replaced verbose wrapper functions with cleaner partial binding approach - All file tools now receive workspace directory context automatically --- .../personas/jupyternaut/jupyternaut.py | 44 +++++-------------- .../jupyter_ai/tools/default_toolkit.py | 30 ++++++++++--- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py index db6434e74..3631b91f3 100644 --- a/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py +++ b/packages/jupyter-ai/jupyter_ai/personas/jupyternaut/jupyternaut.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Optional from jupyterlab_chat.models import Message @@ -71,40 +72,19 @@ def _build_toolkit(self) -> Toolkit: # Get workspace directory for this chat workspace_dir = self.get_workspace_dir() - # Create wrapper functions that bind workspace_dir - # We can't use functools.partial because litellm.function_to_dict expects __name__ - async def bash(command: str, timeout: Optional[int] = None) -> str: - """Executes a bash command and returns the result - - Args: - command: The bash command to execute - timeout: Optional timeout in seconds - - Returns: - The command output (stdout and stderr combined) - """ - from ...tools.default_toolkit import bash as bash_orig - return await bash_orig(command, timeout=timeout, cwd=workspace_dir) - - async def search_grep(pattern: str, include: str = "*") -> str: - """Search for text patterns in files using ripgrep. - - Args: - pattern: A regular expression pattern to search for - include: A glob pattern to filter which files to search - - Returns: - The raw output from ripgrep, including file paths, line numbers, and matching lines - """ - from ...tools.default_toolkit import search_grep as search_grep_orig - return await search_grep_orig(pattern, include=include, cwd=workspace_dir) + def bind_cwd(func, **kwargs): + """Create a partial function with custom __name__ and __doc__ preserved""" + bound_func = partial(func, **kwargs) + bound_func.__name__ = func.__name__ + bound_func.__doc__ = func.__doc__ + return bound_func # Create toolkit with workspace-aware tools toolkit = Toolkit(name="jupyter-ai-contextual-toolkit") - toolkit.add_tool(Tool(callable=bash)) - toolkit.add_tool(Tool(callable=search_grep)) - toolkit.add_tool(Tool(callable=read)) - toolkit.add_tool(Tool(callable=edit)) - toolkit.add_tool(Tool(callable=write)) + toolkit.add_tool(Tool(callable=bind_cwd(bash, cwd=workspace_dir))) + toolkit.add_tool(Tool(callable=bind_cwd(search_grep, cwd=workspace_dir))) + toolkit.add_tool(Tool(callable=bind_cwd(read, cwd=workspace_dir))) + toolkit.add_tool(Tool(callable=bind_cwd(edit, cwd=workspace_dir))) + toolkit.add_tool(Tool(callable=bind_cwd(write, cwd=workspace_dir))) return toolkit \ No newline at end of file diff --git a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py index 80a1735a3..a1b71cf01 100644 --- a/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py +++ b/packages/jupyter-ai/jupyter_ai/tools/default_toolkit.py @@ -6,20 +6,23 @@ from .models import Tool, Toolkit -def read(file_path: str, offset: int, limit: int) -> str: +def read(file_path: str, offset: int, limit: int, cwd: Optional[str] = None) -> str: """ Read a subset of lines from a text file. Parameters ---------- file_path : str - Absolute path to the file that should be read. + Path to the file that should be read. Can be absolute or relative to cwd. offset : int The line number at which to start reading (1-based indexing). limit : int - Number of lines to read starting from *offset*. + Number of lines to read starting from *offset. If *offset + limit* exceeds the number of lines in the file, all available lines after *offset* are returned. + cwd : str, optional + The directory to use as the base for relative paths. If not provided, + file_path must be absolute. Returns ------- @@ -33,6 +36,8 @@ def read(file_path: str, offset: int, limit: int) -> str: ['third line\n', 'fourth line\n', 'fifth line\n', 'sixth line\n'] """ path = pathlib.Path(file_path) + if cwd and not path.is_absolute(): + path = pathlib.Path(cwd) / path if not path.is_file(): raise FileNotFoundError(f"File not found: {file_path}") @@ -69,6 +74,7 @@ def edit( old_string: str, new_string: str, replace_all: bool = False, + cwd: Optional[str] = None, ) -> None: """ Replace occurrences of a substring in a file. @@ -76,7 +82,7 @@ def edit( Parameters ---------- file_path : str - Absolute path to the file that should be edited. + Path to the file that should be edited. Can be absolute or relative to cwd. old_string : str Text that should be replaced. new_string : str @@ -84,6 +90,9 @@ def edit( replace_all : bool, optional If ``True`` all occurrences of *old_string* are replaced. If ``False`` (default), only the first occurrence in the file is replaced. + cwd : str, optional + The directory to use as the base for relative paths. If not provided, + file_path must be absolute. Returns ------- @@ -110,6 +119,8 @@ def edit( >>> edit('/tmp/test.txt', 'foo', 'bar', replace_all=True) """ path = pathlib.Path(file_path) + if cwd and not path.is_absolute(): + path = pathlib.Path(cwd) / path if not path.is_file(): raise FileNotFoundError(f"File not found: {file_path}") @@ -129,16 +140,19 @@ def edit( path.write_text(new_content, encoding="utf-8") -def write(file_path: str, content: str) -> None: +def write(file_path: str, content: str, cwd: Optional[str] = None) -> None: """ Write content to a file, creating it if it doesn't exist. Parameters ---------- file_path : str - Absolute path to the file that should be written. + Path to the file that should be written. Can be absolute or relative to cwd. content : str Content to write to the file. + cwd : str, optional + The directory to use as the base for relative paths. If not provided, + file_path must be absolute. Returns ------- @@ -160,7 +174,9 @@ def write(file_path: str, content: str) -> None: >>> write('/tmp/data.json', '{"key": "value"}') """ path = pathlib.Path(file_path) - + if cwd and not path.is_absolute(): + path = pathlib.Path(cwd) / path + # Write the content to the file path.write_text(content, encoding="utf-8")