Skip to content

Commit 8798735

Browse files
authored
Adds custom agents to the langchain benchmarking repo (#120)
* This PR adds code for running custom agents to the langchain benchmarking repo. * The agent code is good enough for experimentation / prototyping, but I don't think it's good enough for the langchain repo: -- The abstractions aren't fully implemented and aren't ready for production use -- but OK for research -- For production use, one may want to remove all the intermediate abstractions to keep the agent as simple as possible I was thinking initially of including this in a different repo, but I think it's over-complicating things, probably OK to include some reference implementations inside of langchain benchmarks.
1 parent 7ed859c commit 8798735

File tree

19 files changed

+991
-74
lines changed

19 files changed

+991
-74
lines changed

agents/__init__.py

Whitespace-only changes.

agents/tests/__init__.py

Whitespace-only changes.

langchain_benchmarks/model_registration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@
194194
),
195195
RegisteredModel(
196196
provider="fireworks",
197-
name="mixtral-8x7b-instruct",
197+
name="mixtral-8x7b-instruct-fw",
198198
description="Mistral MoE 8x7B Instruct v0.1 model with Sparse "
199199
"Mixture of Experts. Fine tuned for instruction following",
200200
type="llm",
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from langchain_benchmarks.tool_usage.agents.adapters import apply_agent_executor_adapter
2+
from langchain_benchmarks.tool_usage.agents.experimental.factory import (
3+
CustomAgentFactory,
4+
)
5+
from langchain_benchmarks.tool_usage.agents.openai_functions import OpenAIAgentFactory
6+
7+
__all__ = ["OpenAIAgentFactory", "apply_agent_executor_adapter", "CustomAgentFactory"]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from typing import Optional, Callable, Any
2+
3+
from langchain.agents import AgentExecutor
4+
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
5+
6+
7+
def _ensure_output_exists(inputs: dict) -> dict:
8+
"""Make sure that the output key is always present."""
9+
if "output" not in inputs:
10+
return {"output": "", **inputs}
11+
return inputs
12+
13+
14+
def apply_agent_executor_adapter(
15+
agent_executor: AgentExecutor,
16+
*,
17+
state_reader: Optional[Callable[[], Any]] = None,
18+
) -> Runnable:
19+
"""An adapter for the agent executor to standardize its input and output.
20+
21+
1) Map `question` to `input` (`question` is used in the datasets,
22+
but `input` is used in the agent executor)
23+
2) Ensure that `output` is always returned (will be set to "" if missing) --
24+
note that this may be relaxed after more updates in the eval config.
25+
3) Populate `state` key in the response of the agent with the system state
26+
if a state reader is provided.
27+
28+
Args:
29+
agent_executor: the agent executor
30+
state_reader: A callable without parameters that if invoked will return
31+
the state of the environment. Used to populate the 'state' key.
32+
33+
Returns:
34+
a new runnable with a standardized output.
35+
"""
36+
37+
def _read_state(*args: Any, **kwargs: Any) -> Any:
38+
"""Read the state of the environment."""
39+
if state_reader is not None:
40+
return state_reader()
41+
else:
42+
return None
43+
44+
def _format_input(inputs: dict) -> dict:
45+
"""Make sure that the input is always called `input`."""
46+
47+
if "question" not in inputs:
48+
raise ValueError(
49+
"Expected 'question' to be in the inputs. Found only the following "
50+
f"keys {sorted(inputs.keys())}."
51+
)
52+
53+
inputs = inputs.copy() # Because 'question' is popped below
54+
55+
if "input" not in inputs:
56+
return {"input": inputs.pop("question"), **inputs}
57+
return inputs
58+
59+
runnable = (
60+
RunnableLambda(_format_input).with_config({"run_name": "Format Input"})
61+
| agent_executor
62+
| RunnableLambda(_ensure_output_exists).with_config(
63+
{"run_name": "Ensure Output"}
64+
)
65+
)
66+
67+
if state_reader is not None:
68+
runnable = runnable | RunnablePassthrough.assign(state=_read_state).with_config(
69+
{"run_name": "Read Env State"}
70+
)
71+
return runnable

langchain_benchmarks/tool_usage/agents/experimental/__init__.py

Whitespace-only changes.
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
from typing import List, Literal, Optional, Sequence, Tuple, Union
2+
3+
from langchain.agents import AgentOutputParser
4+
from langchain.prompts.chat import ChatPromptTemplate
5+
from langchain.schema.runnable import Runnable
6+
from langchain.tools import StructuredTool
7+
from langchain_core.agents import AgentAction, AgentFinish
8+
from langchain_core.language_models import BaseChatModel, BaseLanguageModel
9+
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage
10+
from langchain_core.prompts import MessagesPlaceholder
11+
from typing_extensions import NotRequired, TypedDict
12+
13+
from langchain_benchmarks import RateLimiter
14+
from langchain_benchmarks.rate_limiting import with_rate_limit
15+
from langchain_benchmarks.tool_usage.agents.experimental.encoder import (
16+
AstPrinter,
17+
TypeScriptEncoder,
18+
XMLEncoder,
19+
)
20+
from langchain_benchmarks.tool_usage.agents.experimental.encoder import FunctionResult
21+
from langchain_benchmarks.tool_usage.agents.experimental.prompts import (
22+
_AGENT_INSTRUCTIONS_BLOB_STYLE,
23+
)
24+
from langchain_benchmarks.tool_usage.agents.experimental.tool_utils import (
25+
convert_tool_to_function_definition,
26+
)
27+
28+
29+
def format_steps_for_chat(
30+
intermediate_steps: List[Tuple[AgentAction, str]],
31+
ast_printer: AstPrinter,
32+
) -> List[BaseMessage]:
33+
"""Format the steps."""
34+
messages = []
35+
for action, observation in intermediate_steps:
36+
# Action messages contains the tool invocation request from the LLM
37+
# Now add the result of the tool invocation.
38+
39+
if action.tool == "_Exception":
40+
messages.append(
41+
AIMessage(
42+
content=action.log,
43+
)
44+
)
45+
messages.append(
46+
# Tool input is the error message for the exception
47+
HumanMessage(content=action.tool_input)
48+
)
49+
else:
50+
messages.extend(action.messages)
51+
function_result: FunctionResult = {
52+
"name": action.tool,
53+
"error": None,
54+
"result": observation,
55+
}
56+
messages.append(
57+
HumanMessage(
58+
content=ast_printer.visit_function_result(function_result),
59+
)
60+
)
61+
62+
return messages
63+
64+
65+
# PUBLIC API
66+
67+
68+
class AgentInput(TypedDict):
69+
"""The input to the agent."""
70+
71+
input: str
72+
"""The input to the agent."""
73+
intermediate_steps: List[Tuple[AgentAction, str]]
74+
"""The intermediate steps taken by the agent."""
75+
examples: NotRequired[List[BaseMessage]]
76+
"""A list of messages that can be used to form example traces."""
77+
78+
79+
def create_agent(
80+
model: Union[BaseChatModel, BaseLanguageModel],
81+
tools: Sequence[StructuredTool],
82+
parser: AgentOutputParser,
83+
*,
84+
ast_printer: Union[AstPrinter, Literal["xml"]] = "xml",
85+
rate_limiter: Optional[RateLimiter] = None,
86+
) -> Runnable[AgentInput, Union[AgentAction, AgentFinish]]:
87+
"""Create an agent for a chat model."""
88+
if isinstance(ast_printer, str):
89+
if ast_printer == "xml":
90+
ast_printer_ = XMLEncoder()
91+
elif ast_printer == "typescript":
92+
ast_printer_ = TypeScriptEncoder()
93+
else:
94+
raise ValueError(f"Unknown ast printer: {ast_printer}")
95+
elif isinstance(ast_printer, AstPrinter):
96+
ast_printer_ = ast_printer
97+
else:
98+
raise TypeError(
99+
f"Expected AstPrinter or str, got {type(ast_printer)} for `ast_printer`"
100+
)
101+
102+
function_definitions = [convert_tool_to_function_definition(tool) for tool in tools]
103+
tool_description = ast_printer_.visit_function_definitions(function_definitions)
104+
105+
template = ChatPromptTemplate.from_messages(
106+
[
107+
("system", _AGENT_INSTRUCTIONS_BLOB_STYLE),
108+
MessagesPlaceholder("examples"), # Can use to add example traces
109+
("human", "{input}"),
110+
MessagesPlaceholder("history"),
111+
]
112+
).partial(tool_description=tool_description)
113+
114+
# For the time being, hard-coding the fact that we're using a <tool> tag.
115+
model = model.bind(stop=["</tool>"])
116+
117+
if rate_limiter:
118+
# Apply a rate limiter if it was provided
119+
model = with_rate_limit(model, rate_limiter)
120+
121+
agent = (
122+
{
123+
"input": lambda x: x["input"],
124+
"history": lambda x: format_steps_for_chat(
125+
x["intermediate_steps"], ast_printer_
126+
),
127+
"examples": lambda x: x.get("examples", []),
128+
}
129+
| template
130+
| model
131+
| parser
132+
)
133+
return agent

0 commit comments

Comments
 (0)