
Commit 21908e5

Authored by guicho271828 (MASATARO ASAI, Masataro.Asai@ibm.com) and nrfulton
feat: Adds a vllm backend (#122)
* feat: added smaller qwen models for debugging
* feat(vllm): copied from huggingface
* fix(vllm): remove alora and cache
* fix(vllm): remove tool calls
* fix(vllm): finished the implementation with limited functionality: free-form and constrained generation
* fix(vllm): passing mypy and linter
* fix(vllm): added vllm optional dep in pyproject.toml
* feat(vllm test): copied from huggingface
* fix(vllm test): implemented the test
* test: require V0 in vllm test
* refactor: ctx to chat conversion function
* refactor: use_alora function
* refactor: moved _extract_model_tool_requests to mellea.backends.utils
* feat(vllm): added tool calls
* test(tools): run test with mistral
* fix(vllm): rename model_options -> engine_args
* fix(vllm): use FancyLogger
* fix(vllm): ignore type checking for vllm and msgspec
* fix(vllm): fixed the backend name in the log
* feat(vllm): asynchronous call support
* test(vllm): asynchronous call support
* fix(vllm): avoid unnecessary incremental processing in non-streaming mode
* fix(vllm): fix for the new return format
* fix(vllm): fixed vllm test for the new contexts
* fix(vllm): addressed minor comments
* fix(vllm): uv lock
* fix(vllm): mark V0 api test qualitative; will be removed in a future PR that migrates to V1

---------

Signed-off-by: Masataro Asai <guicho2.71828@gmail.com>
Co-authored-by: MASATARO ASAI Masataro.Asai@ibm.com <masataro@login1.bluevela.rmf.ibm.com>
Co-authored-by: Nathan Fulton <nathan@ibm.com>
1 parent 10f6ffa commit 21908e5

File tree: 8 files changed, +2602 −406 lines

mellea/backends/_utils.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
from __future__ import annotations

import inspect
from collections.abc import Callable
from typing import Any, Literal

from mellea.backends.aloras import Alora
from mellea.backends.formatter import Formatter
from mellea.backends.tools import parse_tools
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import CBlock, Component, Context, ModelToolCall
from mellea.stdlib.chat import Message
from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement

# Chat = dict[Literal["role", "content"], str] # external apply_chat_template type hint is weaker
# Chat = dict[str, str | list[dict[str, Any]] ] # for multi-modal models
Chat = dict[str, str]


def to_chat(
    action: Component | CBlock,
    ctx: Context,
    formatter: Formatter,
    system_prompt: str | None,
) -> list[Chat]:
    """Converts a context and an action into a series of dicts to be passed to apply_chat_template.

    This function is used by local inference backends.
    """
    assert ctx.is_chat_context

    linearized_ctx = ctx.view_for_generation()
    assert linearized_ctx is not None, (
        "If ctx.is_chat_context, then the context should be linearizable."
    )
    ctx_as_message_list: list[Message] = formatter.to_chat_messages(linearized_ctx)
    # add action
    ctx_as_message_list.extend(formatter.to_chat_messages([action]))

    ctx_as_conversation: list = [
        {"role": m.role, "content": m.content} for m in ctx_as_message_list
    ]

    # Check that we ddin't accidentally end up with CBlocks.
    for msg in ctx_as_conversation:
        for v in msg.values():
            if "CBlock" in v:
                FancyLogger.get_logger().error(
                    f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}"
                )

    # handle custom system prompts. It's important that we do this before the _parse_and_**clean**_model_options step.
    if system_prompt is not None:
        system_msg: Chat = {"role": "system", "content": system_prompt}
        ctx_as_conversation.insert(0, system_msg)

    return ctx_as_conversation


def use_alora(
    action: Component | CBlock,
    alora: Alora | None,
    default_to_constraint_checking_alora: bool,
) -> bool:
    """Returns True when the condition for using alora is met.

    See `docs/dev/requirement_aLoRA_rerouting.md` for an explanation of the following code block.
    """
    if issubclass(type(action), Requirement):
        # The general rule is that we reroute to the alora if it exists.
        reroute_to_alora = alora is not None
        # However, there are some exceptions:
        if not default_to_constraint_checking_alora:
            reroute_to_alora = False
        if issubclass(type(action), LLMaJRequirement):
            reroute_to_alora = False
        if issubclass(type(action), ALoraRequirement):
            reroute_to_alora = True
        return reroute_to_alora
    else:
        return False


def to_tool_calls(
    tools: dict[str, Callable], decoded_result: str
) -> dict[str, ModelToolCall] | None:
    """Parse a tool call string."""
    model_tool_calls: dict[str, ModelToolCall] = dict()
    for tool_name, tool_args in parse_tools(decoded_result):
        func = tools.get(tool_name)
        if func is None:
            FancyLogger.get_logger().warning(
                f"model attempted to call a non-existing function: {tool_name}"
            )
            continue

        # Clean up the function args slightly. Some models seem to
        # hallucinate parameters when none are required.
        sig = inspect.signature(func)
        if len(sig.parameters) == 0:
            tool_args = {}

        model_tool_calls[tool_name] = ModelToolCall(tool_name, func, tool_args)

    if len(model_tool_calls) > 0:
        return model_tool_calls
    return None
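To make the new helper's contract concrete, the sketch below shows the kind of list `to_chat` returns and later hands to `apply_chat_template`: plain role/content dicts, with an optional system message inserted at the front. The message texts are invented for illustration and do not come from the repository.

# Illustrative only: the contents are made up; the shape mirrors to_chat's return value.
example_chat: list[dict[str, str]] = [
    {"role": "system", "content": "You are a helpful assistant."},  # present only when system_prompt is not None
    {"role": "user", "content": "Draft a short status update."},
    {"role": "assistant", "content": "Here is a draft: ..."},
    {"role": "user", "content": "Check the draft against the requirements."},  # the appended action
]

`use_alora` then decides, for `Requirement` actions only, whether generation is rerouted to the constraint-checking aLoRA: `LLMaJRequirement` always opts out and `ALoraRequirement` always opts in, exactly as the branches above encode.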

mellea/backends/huggingface.py

Lines changed: 24 additions & 77 deletions
@@ -30,6 +30,7 @@
 from transformers.generation.utils import GenerateDecoderOnlyOutput
 
 from mellea.backends import BaseModelSubclass
+from mellea.backends._utils import to_chat, to_tool_calls, use_alora
 from mellea.backends.aloras import Alora, AloraBackendMixin
 from mellea.backends.cache import Cache, SimpleLRUCache
 from mellea.backends.formatter import Formatter, FormatterBackend, TemplateFormatter
@@ -39,7 +40,6 @@
     add_tools_from_context_actions,
     add_tools_from_model_options,
     convert_tools_to_json,
-    parse_tools,
 )
 from mellea.backends.types import ModelOption
 from mellea.helpers.async_helpers import send_to_queue
@@ -198,26 +198,24 @@ def generate_from_context(
         # Upsert model options.
         model_opts = self._simplify_and_merge(model_options)
 
-        # See `docs/dev/requirement_aLoRA_rerouting.md` for an explanation of the following code block.
-        if issubclass(type(action), Requirement):
-            # The general rule is that we reroute to the alora if it exists.
-            reroute_to_alora = self.get_alora("constraint") is not None
-            # However, there are some exceptions:
-            if not self.default_to_constraint_checking_alora:
-                reroute_to_alora = False
-            if issubclass(type(action), LLMaJRequirement):
-                reroute_to_alora = False
-            if issubclass(type(action), ALoraRequirement):
-                reroute_to_alora = True
-            if reroute_to_alora:
-                mot = self._generate_from_context_alora(
-                    action, ctx, _format=format, model_options=model_opts
-                )
-                return mot, ctx.add(mot)
-        mot = self._generate_from_context_standard(
-            action, ctx, _format=format, model_options=model_opts, tool_calls=tool_calls
-        )
-        return mot, ctx.add(action).add(mot)
+        if use_alora(
+            action,
+            self.get_alora("constraint"),
+            self.default_to_constraint_checking_alora,
+        ):
+            mot = self._generate_from_context_alora(
+                action, ctx, _format=format, model_options=model_opts
+            )
+            return mot, ctx.add(mot)
+        else:
+            mot = self._generate_from_context_standard(
+                action,
+                ctx,
+                _format=format,
+                model_options=model_opts,
+                tool_calls=tool_calls,
+            )
+            return mot, ctx.add(action).add(mot)
 
     def _generate_from_context_alora(
         self,
@@ -279,35 +277,8 @@ def _generate_from_context_standard(
         # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
         # Otherwise, we will linearize the context and treat it as a raw input.
         if ctx.is_chat_context:
-            linearized_ctx = ctx.view_for_generation()
-            assert linearized_ctx is not None, (
-                "If ctx.is_chat_context, then the context should be linearizable."
-            )
-            ctx_as_message_list: list[Message] = self.formatter.to_chat_messages(
-                linearized_ctx
-            )
-            # add action
-            ctx_as_message_list.extend(self.formatter.to_chat_messages([action]))
-            ctx_as_conversation = [
-                {"role": m.role, "content": m.content} for m in ctx_as_message_list
-            ]
-
-            # Check that we ddin't accidentally end up with CBlocks.
-            for msg in ctx_as_conversation:
-                for v in msg.values():
-                    if "CBlock" in v:
-                        FancyLogger.get_logger().error(
-                            f"Found the string `CBlock` in what should've been a stringified context: {ctx_as_conversation}"
-                        )
-
-            # handle custom system prompts. It's important that we do this before the _parse_and_**clean**_model_options step.
             system_prompt = model_options.get(ModelOption.SYSTEM_PROMPT, None)
-            if system_prompt is not None:
-                system_msg: dict[str, str] = {
-                    "role": "system",
-                    "content": system_prompt,
-                }
-                ctx_as_conversation.insert(0, system_msg)
+            ctx_as_chat = to_chat(action, ctx, self.formatter, system_prompt)
 
             # Append tool call information if applicable.
             tools: dict[str, Callable] = dict()
@@ -332,7 +303,7 @@ def _generate_from_context_standard(
                 set_seed(seed)
 
             input_ids = self._tokenizer.apply_chat_template(  # type: ignore
-                ctx_as_conversation,
+                ctx_as_chat,
                 tools=convert_tools_to_json(tools),  # type: ignore
                 add_generation_prompt=True,  # If we change this, must modify huggingface granite guardian.
                 return_tensors="pt",
@@ -397,7 +368,7 @@ def _generate_from_context_standard(
             )
 
         output = ModelOutputThunk(None)
-        output._context = linearized_ctx
+        output._context = ctx.view_for_generation()
         output._action = action
         output._model_options = model_options
 
@@ -406,7 +377,7 @@ def _generate_from_context_standard(
         output._process = functools.partial(self.processing, input_ids=input_ids)
         output._post_process = functools.partial(
            self.post_processing,
-            conversation=ctx_as_conversation,
+            conversation=ctx_as_chat,
            input_ids=input_ids,
            _format=_format,
            tool_calls=tool_calls,
@@ -497,7 +468,7 @@ async def post_processing(
 
         # Only scan for tools if we are not doing structured output and tool calls were provided to the model.
         if _format is None and tool_calls:
-            mot.tool_calls = self._extract_model_tool_requests(tools, mot.value)
+            mot.tool_calls = to_tool_calls(tools, mot.value)
 
         assert mot._action is not None, (
             "ModelOutputThunks should have their action assigned during generation"
@@ -698,30 +669,6 @@ def _filter_chat_template_only_options(
         }
         return {k: v for k, v in model_options.items() if k not in chat_template_only}
 
-    def _extract_model_tool_requests(
-        self, tools: dict[str, Callable], decoded_result: str
-    ) -> dict[str, ModelToolCall] | None:
-        model_tool_calls: dict[str, ModelToolCall] = dict()
-        for tool_name, tool_args in parse_tools(decoded_result):
-            func = tools.get(tool_name)
-            if func is None:
-                FancyLogger.get_logger().warning(
-                    f"model attempted to call a non-existing function: {tool_name}"
-                )
-                continue
-
-            # Clean up the function args slightly. Some models seem to
-            # hallucinate parameters when none are required.
-            sig = inspect.signature(func)
-            if len(sig.parameters) == 0:
-                tool_args = {}
-
-            model_tool_calls[tool_name] = ModelToolCall(tool_name, func, tool_args)
-
-        if len(model_tool_calls) > 0:
-            return model_tool_calls
-        return None
-
     # region ALora loading, unloading, and utility functions.
     def add_alora(self, alora: HFAlora):
         """Loads an ALora for this backend.

mellea/backends/model_ids.py

Lines changed: 4 additions & 0 deletions
@@ -126,6 +126,10 @@ class ModelIdentifier:
 #### Qwen models ####
 #####################
 
+QWEN3_0_6B = ModelIdentifier(hf_model_name="Qwen/Qwen3-0.6B", ollama_name="qwen3:0.6b")
+
+QWEN3_1_7B = ModelIdentifier(hf_model_name="Qwen/Qwen3-1.7B", ollama_name="qwen3:1.7b")
+
 QWEN3_8B = ModelIdentifier(hf_model_name="Qwen/Qwen3-8B", ollama_name="qwen3:8b")
 
 QWEN3_14B = ModelIdentifier(hf_model_name="Qwen/Qwen3-14B", ollama_name="qwen3:14b")
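These two entries add the 0.6B and 1.7B Qwen3 checkpoints that the commit message describes as "smaller qwen models for debugging". A small sketch of how they might be consumed, assuming they are module-level constants like the neighboring QWEN3_8B and that ModelIdentifier exposes its constructor fields as attributes (a sketch, not code from the commit):

from mellea.backends.model_ids import QWEN3_0_6B, QWEN3_1_7B

# Pick one of the small checkpoints for a fast local smoke test.
for model_id in (QWEN3_0_6B, QWEN3_1_7B):
    print(model_id.hf_model_name, model_id.ollama_name)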
