From 030391935484f83ff439799f789800fa2806c355 Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Tue, 28 Oct 2025 13:48:44 -0700 Subject: [PATCH 1/7] feat: making generate_from_raw public in openai --- mellea/backends/__init__.py | 8 +++-- mellea/backends/openai.py | 66 +++++++++++++++++++++---------------- 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index fd76bc50..bda4fd44 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -58,19 +58,21 @@ def generate_from_context( ... @abc.abstractmethod - def _generate_from_raw( + def generate_from_raw( self, actions: list[Component | CBlock], + ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generates a model output from the provided input. Does not use context or templates. Args: actions: list of actions to generate responses for. Each action is separate. + ctx: context passed to generation. Currently not used in generate_from_raw format: A response format to used for structured outputs / constrained decoding. Note: some backends do not support this parameter. They will log warnings and continue to generate. model_options: Any model options to upsert into the defaults for this call. - generate_logs: a `GenerateLog` instance to add log information to. + tool_calls: Always set to false unless supported by backend. """ diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 39b026a8..0778da91 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -8,7 +8,7 @@ import json from collections.abc import Callable, Coroutine from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, overload from urllib.parse import urlparse import openai @@ -646,13 +646,14 @@ async def post_processing( generate_log.result = mot mot._generate_log = generate_log - def _generate_from_raw( + def generate_from_raw( self, actions: list[Component | CBlock], + ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" extra_body = {} @@ -664,6 +665,10 @@ def _generate_from_raw( # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests. extra_body["guided_json"] = format.model_json_schema() + if tool_calls: + FancyLogger.get_logger().warning( + "The completion endpoint does not support tool calling at the moment." + ) model_opts = self._simplify_and_merge(model_options, is_chat_context=False) @@ -689,32 +694,35 @@ def _generate_from_raw( # Necessary for type checker. 
assert isinstance(completion_response, Completion) - results = [ - ModelOutputThunk( - value=response.text, - meta={"oai_completion_response": response.model_dump()}, - ) - for response in completion_response.choices - ] - - for i, result in enumerate(results): - self.formatter.parse(actions[i], result) - - if generate_logs is not None: - assert isinstance(generate_logs, list) - date = datetime.datetime.now() - - for i in range(len(prompts)): - generate_log = GenerateLog() - generate_log.prompt = prompts[i] - generate_log.backend = f"openai::{self.model_id!s}" - generate_log.model_options = model_opts - generate_log.date = date - generate_log.model_output = completion_response - generate_log.extra = {"seed": model_opts.get("seed", None)} - generate_log.action = actions[i] - generate_log.result = results[i] - generate_logs.append(generate_log) + results = [] + for response, action, prompt in zip( + completion_response.choices, actions, prompts + ): + output = ModelOutputThunk(None) + output.value = response.text + output._context = None # There is no context for generate_from_raw for now + output._action = action + output._model_options = model_opts + output._meta = { + "oai_completion_response": response.model_dump(), + "oai_usage": completion_response.usage.model_dump() + if completion_response.usage + else None, + } + + self.formatter.parse(action, output) + + generate_log = GenerateLog() + generate_log.prompt = prompt + generate_log.backend = f"openai::{self.model_id!s}" + generate_log.model_options = model_opts + generate_log.date = datetime.datetime.now() + generate_log.model_output = completion_response + generate_log.extra = {"seed": model_opts.get("seed", None)} + generate_log.action = action + output._generate_log = generate_log + + results.append(output) return results From 1414c782984d3b984259a7679c263a8c6ac3a8b6 Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Tue, 28 Oct 2025 17:09:00 -0700 Subject: [PATCH 2/7] feat: making generate_from_raw public in ollama --- mellea/backends/ollama.py | 47 ++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 86c4509b..1972d9ec 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -385,13 +385,14 @@ def generate_from_chat_context( return output - def _generate_from_raw( + def generate_from_raw( self, actions: list[Component | CBlock], + ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the generate api. 
Gives the input provided to the model without templating.""" if len(actions) > 1: @@ -438,32 +439,32 @@ async def get_response(): else: result = ModelOutputThunk( value=response.response, - meta={"generate_response": response.model_dump()}, + meta={ + "generate_response": response.model_dump(), + "usage": response.eval_count, + }, ) self.formatter.parse(actions[i], result) results.append(result) - if generate_logs is not None: - # noinspection DuplicatedCode - assert isinstance(generate_logs, list) - generate_log = GenerateLog() - generate_log.prompt = prompts[i] - generate_log.backend = f"ollama::{self.model_id!s}" - generate_log.date = date - generate_log.model_options = model_opts - generate_log.model_output = result.value - generate_log.extra = { - "format": format, - "thinking": model_opts.get(ModelOption.THINKING, None), - "seed": model_opts.get(ModelOption.SEED, None), - } - generate_log.action = actions[i] - generate_log.result = result - - if error: - generate_log.extra["error"] = error - generate_logs.append(generate_log) + generate_log = GenerateLog() + generate_log.prompt = prompts[i] + generate_log.backend = f"ollama::{self.model_id!s}" + generate_log.date = date + generate_log.model_options = model_opts + generate_log.model_output = result.value + generate_log.extra = { + "format": format, + "thinking": model_opts.get(ModelOption.THINKING, None), + "seed": model_opts.get(ModelOption.SEED, None), + } + generate_log.action = actions[i] + generate_log.result = result + + if error: + generate_log.extra["error"] = error + result._generate_log = generate_log return results From 69d06ab306ef494d15a118c9761a88b4f07a9446 Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Tue, 28 Oct 2025 17:10:14 -0700 Subject: [PATCH 3/7] feat: making generate_from_raw public in ollama --- mellea/backends/ollama.py | 9 ++++++++- mellea/backends/openai.py | 2 +- test/backends/test_ollama.py | 8 ++++---- test/backends/test_openai_ollama.py | 16 ++++++++-------- 4 files changed, 21 insertions(+), 14 deletions(-) diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 1972d9ec..ee164831 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -399,6 +399,10 @@ def generate_from_raw( FancyLogger.get_logger().info( "Ollama doesn't support batching; will attempt to process concurrently." ) + if tool_calls: + FancyLogger.get_logger().warning( + "The completion endpoint does not support tool calling at the moment." 
+ ) model_opts = self._simplify_and_merge(model_options) @@ -441,7 +445,10 @@ async def get_response(): value=response.response, meta={ "generate_response": response.model_dump(), - "usage": response.eval_count, + "usage": { + "completion_tokens": response.eval_count, + "prompt_tokens": response.prompt_eval_count, + }, }, ) diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 0778da91..e147dfc4 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -705,7 +705,7 @@ def generate_from_raw( output._model_options = model_opts output._meta = { "oai_completion_response": response.model_dump(), - "oai_usage": completion_response.usage.model_dump() + "usage": completion_response.usage.model_dump() if completion_response.usage else None, } diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 4887c542..4a044e1e 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -100,8 +100,8 @@ class Email(pydantic.BaseModel): def test_generate_from_raw(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - results = session.backend._generate_from_raw( - actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None + results = session.backend.generate_from_raw( + actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx ) assert len(results) == len(prompts) @@ -115,10 +115,10 @@ class Answer(pydantic.BaseModel): name: str value: int - results = session.backend._generate_from_raw( + results = session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], + ctx=session.ctx, format=Answer, - generate_logs=None, ) assert len(results) == len(prompts) diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index b2883e4e..64746308 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -11,7 +11,7 @@ from mellea.backends.model_ids import META_LLAMA_3_2_1B from mellea.backends.openai import OpenAIBackend from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext, SimpleContext +from mellea.stdlib.base import CBlock, ChatContext, ModelOutputThunk, SimpleContext @pytest.fixture(scope="module") @@ -111,15 +111,15 @@ class Email(pydantic.BaseModel): # assert email.to.email_address.endswith("example.com") pass - # Ollama doesn't support batch requests. Cannot run this test unless we switch backend providers. - # def test_generate_from_raw(self): - # prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] +# @pytest.mark.qualitative +# def test_generate_from_raw(m_session): +# prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - # results = self.m.backend._generate_from_raw( - # actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None - # ) +# results = m_session.backend.generate_from_raw( +# actions=[CBlock(value=prompt) for prompt in prompts], ctx=m_session.ctx +# ) - # assert len(results) == len(prompts) +# assert len(results) == len(prompts) # Default OpenAI implementation doesn't support structured outputs for the completions API. 
# def test_generate_from_raw_with_format(self): From 8700c29d1b06439d0f2830a427779222efef75ac Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Wed, 29 Oct 2025 17:02:19 -0700 Subject: [PATCH 4/7] feat: making generate_from_raw public in hf, watsonx --- mellea/backends/huggingface.py | 46 ++++++------ mellea/backends/litellm.py | 5 +- mellea/backends/ollama.py | 4 +- mellea/backends/watsonx.py | 47 ++++++------ test/backends/test_huggingface.py | 47 +++++++++--- .../test_openai_vllm/test_openai_vllm.py | 40 +++++++---- test/backends/test_watsonx.py | 72 +++++++++++++++---- 7 files changed, 170 insertions(+), 91 deletions(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index fb9cacf6..35194e4a 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -526,13 +526,14 @@ async def post_processing( mot._generate_log = generate_log - def _generate_from_raw( + def generate_from_raw( self, actions: list[Component | CBlock], + ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" model_opts = self._simplify_and_merge(model_options) @@ -590,28 +591,31 @@ def _generate_from_raw( sequences_to_decode, skip_special_tokens=True ) - results = [ - ModelOutputThunk(value=decoded_result) for decoded_result in decoded_results - ] + results = [] + for i, decoded_result in enumerate(decoded_results): + result = ModelOutputThunk( + value=decoded_result, + meta={ + "usage": { + "prompt_tokens": inputs["input_ids"][i].size(0), # type: ignore + "completion_tokens": len(sequences_to_decode[i]), + } + }, + ) - for i, result in enumerate(results): self.formatter.parse(actions[i], result) - if generate_logs is not None: - assert isinstance(generate_logs, list) - date = datetime.datetime.now() - - for i in range(len(prompts)): - generate_log = GenerateLog() - generate_log.prompt = prompts[i] - generate_log.backend = f"hf::{self.model_id!s}" - generate_log.model_options = model_opts - generate_log.date = date - generate_log.model_output = decoded_results - generate_log.extra = {"format": format, "seed": seed} - generate_log.action = actions[i] - generate_log.result = results[i] - generate_logs.append(generate_log) + generate_log = GenerateLog() + generate_log.prompt = self.formatter.print(actions[i]) + generate_log.backend = f"hf::{self.model_id!s}" + generate_log.model_options = model_opts + generate_log.date = datetime.datetime.now() + generate_log.model_output = decoded_result + generate_log.extra = {"format": format, "seed": seed} + generate_log.action = actions[i] + + result._generate_log = generate_log + results.append(result) return results diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 89b61536..9af431b0 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -475,13 +475,14 @@ def _extract_tools( FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") return tools - def _generate_from_raw( + def generate_from_raw( self, actions: list[Component | CBlock], + ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. 
Gives the input provided to the model without templating.""" raise NotImplementedError("This method is not implemented yet.") diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index ee164831..7d55b081 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -453,7 +453,6 @@ async def get_response(): ) self.formatter.parse(actions[i], result) - results.append(result) generate_log = GenerateLog() generate_log.prompt = prompts[i] @@ -467,12 +466,13 @@ async def get_response(): "seed": model_opts.get(ModelOption.SEED, None), } generate_log.action = actions[i] - generate_log.result = result if error: generate_log.extra["error"] = error result._generate_log = generate_log + results.append(result) + return results def _extract_model_tool_requests( diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 7e0ce1b1..7ae42b1f 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -478,13 +478,14 @@ async def post_processing( generate_log.action = mot._action mot._generate_log = generate_log - def _generate_from_raw( + def generate_from_raw( self, actions: list[Component | CBlock], + ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generates a completion text. Gives the input provided to the model without templating.""" if format is not None: @@ -502,36 +503,32 @@ def _generate_from_raw( model_opts, is_chat_context=False ), ) + results = [] + date = datetime.datetime.now() - results = [ - ModelOutputThunk( + for i, response in enumerate(responses): + result = ModelOutputThunk( value=response["results"][0]["generated_text"], meta={"oai_completion_response": response["results"][0]}, ) - for response in responses - ] - for i, result in enumerate(results): self.formatter.parse(actions[i], result) - if generate_logs is not None: - assert isinstance(generate_logs, list) - date = datetime.datetime.now() - - for i in range(len(prompts)): - generate_log = GenerateLog() - generate_log.prompt = prompts[i] - generate_log.backend = f"watsonx::{self.model_id!s}" - generate_log.model_options = model_opts - generate_log.date = date - generate_log.model_output = responses - generate_log.extra = { - "format": format, - "seed": model_opts.get(ModelOption.SEED, None), - } - generate_log.action = actions[i] - generate_log.result = results[i] - generate_logs.append(generate_log) + generate_log = GenerateLog() + generate_log.prompt = prompts[i] + generate_log.backend = f"watsonx::{self.model_id!s}" + generate_log.model_options = model_opts + generate_log.date = date + generate_log.model_output = responses + generate_log.extra = { + "format": format, + "seed": model_opts.get(ModelOption.SEED, None), + } + generate_log.action = actions[i] + + result._generate_log = generate_log + + results.append(result) return results diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 3f40d4cd..c8cd3a74 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -38,6 +38,7 @@ def session(backend): yield session session.reset() + @pytest.mark.qualitative def test_system_prompt(session): result = session.chat( @@ -46,6 +47,7 @@ def test_system_prompt(session): ) print(result) + @pytest.mark.qualitative async def test_constraint_alora(session, backend): answer = session.instruct( @@ -66,6 +68,7 @@ async def test_constraint_alora(session, backend): 
await alora_output.avalue() assert alora_output.value in ["Y", "N"], alora_output + @pytest.mark.qualitative def test_constraint_lora_with_requirement(session, backend): answer = session.instruct( @@ -133,6 +136,7 @@ def test_llmaj_req_does_not_use_alora(session, backend): assert isinstance(val_result, ValidationResult) assert str(val_result.reason) not in ["Y", "N"] + @pytest.mark.qualitative def test_instruct(session): result = session.instruct("Compute 1+1.") @@ -150,6 +154,7 @@ def test_multiturn(session): words = session.instruct("Now list five English words that start with that letter.") print(words) + @pytest.mark.qualitative def test_chat(session): output_message = session.chat("What is 1+1?") @@ -188,12 +193,19 @@ class Email(pydantic.BaseModel): "The email address should be at example.com" ) + @pytest.mark.qualitative def test_generate_from_raw(session): - prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?", "what is 4+2+2?"] - - results = session.backend._generate_from_raw( - actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None + prompts = [ + "what is 1+1?", + "what is 2+2?", + "what is 3+3?", + "what is 4+4?", + "what is 4+2+2?", + ] + + results = session.backend.generate_from_raw( + actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx ) assert len(results) == len(prompts) @@ -207,10 +219,10 @@ class Answer(pydantic.BaseModel): name: str value: int - results = session.backend._generate_from_raw( + results = session.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], format=Answer, - generate_logs=None, + ctx=session.ctx, ) assert len(results) == len(prompts) @@ -223,11 +235,16 @@ class Answer(pydantic.BaseModel): f"formatting directive failed for {random_result.value}: {e.json()}" ) + @pytest.mark.qualitative async def test_async_parallel_requests(session): model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + mot1, _ = session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext(), model_options=model_opts + ) + mot2, _ = session.backend.generate_from_context( + CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts + ) m1_val = None m2_val = None @@ -244,19 +261,27 @@ async def test_async_parallel_requests(session): # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response # contains the full response. 
- assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + assert m1_final_val.startswith(m1_val), ( + "final val should contain the first streamed chunk" + ) + assert m2_final_val.startswith(m2_val), ( + "final val should contain the first streamed chunk" + ) assert m1_final_val == mot1.value assert m2_final_val == mot2.value + @pytest.mark.qualitative async def test_async_avalue(session): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + mot1, _ = session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext() + ) m1_final_val = await mot1.avalue() assert m1_final_val is not None assert m1_final_val == mot1.value + if __name__ == "__main__": import pytest diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py index 505aa1ea..30dff26a 100644 --- a/test/backends/test_openai_vllm/test_openai_vllm.py +++ b/test/backends/test_openai_vllm/test_openai_vllm.py @@ -3,7 +3,12 @@ from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext from mellea.backends.openai import OpenAIBackend from mellea.backends.aloras.openai.granite_aloras import add_granite_aloras -from mellea.stdlib.requirement import Requirement, ALoraRequirement, LLMaJRequirement, req +from mellea.stdlib.requirement import ( + Requirement, + ALoraRequirement, + LLMaJRequirement, + req, +) from mellea.backends.formatter import TemplateFormatter from mellea.backends.types import ModelOption @@ -99,8 +104,8 @@ class Email(pydantic.BaseModel): def test_generate_from_raw(self): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - results = self.m.backend._generate_from_raw( - actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None + results = self.m.backend.generate_from_raw( + actions=[CBlock(value=prompt) for prompt in prompts], ctx=self.m.ctx ) assert len(results) == len(prompts) @@ -112,10 +117,10 @@ class Answer(pydantic.BaseModel): name: str value: int - results = self.m.backend._generate_from_raw( + results = self.m.backend.generate_from_raw( actions=[CBlock(value=prompt) for prompt in prompts], format=Answer, - generate_logs=None, + ctx=self.m.ctx, ) assert len(results) == len(prompts) @@ -124,9 +129,9 @@ class Answer(pydantic.BaseModel): try: answer = Answer.model_validate_json(random_result.value) except pydantic.ValidationError as e: - assert ( - False - ), f"formatting directive failed for {random_result.value}: {e.json()}" + assert False, ( + f"formatting directive failed for {random_result.value}: {e.json()}" + ) class TestOpenAIALoraStuff: @@ -153,7 +158,9 @@ def test_constraint_alora(self): answer = self.m.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) - alora_output = self.backend.get_aloras()[0].generate_using_strings( + alora_output = self.backend.get_aloras()[ + 0 + ].generate_using_strings( input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa", response=str(answer), constraint="The answer mention that there is a b in the middle of one of the strings but not the other.", @@ -168,7 +175,9 @@ def test_constraint_lora_with_requirement(self): "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( - ALoraRequirement("The answer should mention that there is a b in the 
middle of one of the strings but not the other."), + ALoraRequirement( + "The answer should mention that there is a b in the middle of one of the strings but not the other." + ) ) assert len(validation_outputs) == 1 val_result = validation_outputs[0] @@ -182,7 +191,9 @@ def test_constraint_lora_override(self): "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa" ) validation_outputs = self.m.validate( - LLMaJRequirement("The answer should mention that there is a b in the middle of one of the strings but not the other."), + LLMaJRequirement( + "The answer should mention that there is a b in the middle of one of the strings but not the other." + ) ) assert len(validation_outputs) == 1 val_result = validation_outputs[0] @@ -199,7 +210,7 @@ def test_constraint_lora_override_does_not_override_alora(self): validation_outputs = self.m.validate( ALoraRequirement( "The answer should mention that there is a b in the middle of one of the strings but not the other." - ), + ) ) assert len(validation_outputs) == 1 non_alora_output = validation_outputs[0] @@ -216,7 +227,7 @@ def test_llmaj_req_does_not_use_alora(self): validation_outputs = self.m.validate( LLMaJRequirement( "The answer should mention that there is a b in the middle of one of the strings but not the other." - ), + ) ) assert len(validation_outputs) == 1 non_alora_output = validation_outputs[0] @@ -245,8 +256,7 @@ def test_format(self): class Person(pydantic.BaseModel): name: str email_address: Annotated[ - str, - pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com"), + str, pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com") ] class Email(pydantic.BaseModel): diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index a0c43fd1..0d160b89 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -19,9 +19,9 @@ def backend(): pytest.skip("Skipping watsonx tests.") else: return WatsonxAIBackend( - model_id="ibm/granite-3-3-8b-instruct", - formatter=TemplateFormatter(model_id="ibm-granite/granite-3.3-8b-instruct"), - ) + model_id="ibm/granite-3-3-8b-instruct", + formatter=TemplateFormatter(model_id="ibm-granite/granite-3.3-8b-instruct"), + ) @pytest.fixture(scope="function") @@ -34,11 +34,31 @@ def session(backend: WatsonxAIBackend): yield session session.reset() + @pytest.mark.qualitative def test_filter_chat_completions_kwargs(backend: WatsonxAIBackend): """Detect changes to the WatsonxAI TextChatParameters.""" - - known_keys = ['frequency_penalty', 'logprobs', 'top_logprobs', 'presence_penalty', 'response_format', 'temperature', 'max_tokens', 'max_completion_tokens', 'time_limit', 'top_p', 'n', 'logit_bias', 'seed', 'stop', 'guided_choice', 'guided_regex', 'guided_grammar', 'guided_json'] + + known_keys = [ + "frequency_penalty", + "logprobs", + "top_logprobs", + "presence_penalty", + "response_format", + "temperature", + "max_tokens", + "max_completion_tokens", + "time_limit", + "top_p", + "n", + "logit_bias", + "seed", + "stop", + "guided_choice", + "guided_regex", + "guided_grammar", + "guided_json", + ] test_dict = {key: 1 for key in known_keys} # Make sure keys that we think should be in the TextChatParameters are there. 
@@ -52,18 +72,21 @@ def test_filter_chat_completions_kwargs(backend: WatsonxAIBackend): filtered_incorrect_dict = backend.filter_chat_completions_kwargs(incorrect_dict) assert "random" not in filtered_incorrect_dict + @pytest.mark.qualitative def test_instruct(session: MelleaSession): result = session.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore + @pytest.mark.qualitative def test_multiturn(session: MelleaSession): session.instruct("What is the capital of France?") answer = session.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore + @pytest.mark.qualitative def test_chat(session): output_message = session.chat("What is 1+1?") @@ -71,6 +94,7 @@ def test_chat(session): f"Expected a message with content containing 2 but found {output_message}" ) + @pytest.mark.qualitative def test_format(session: MelleaSession): class Person(pydantic.BaseModel): @@ -109,17 +133,22 @@ class Email(pydantic.BaseModel): def test_generate_from_raw(session: MelleaSession): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] - results = session.backend._generate_from_raw( - actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None + results = session.backend.generate_from_raw( + actions=[CBlock(value=prompt) for prompt in prompts], ctx=session.ctx ) assert len(results) == len(prompts) + @pytest.mark.qualitative async def test_async_parallel_requests(session): model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + mot1, _ = session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext(), model_options=model_opts + ) + mot2, _ = session.backend.generate_from_context( + CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts + ) m1_val = None m2_val = None @@ -136,39 +165,52 @@ async def test_async_parallel_requests(session): # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response # contains the full response. 
- assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + assert m1_final_val.startswith(m1_val), ( + "final val should contain the first streamed chunk" + ) + assert m2_final_val.startswith(m2_val), ( + "final val should contain the first streamed chunk" + ) assert m1_final_val == mot1.value assert m2_final_val == mot2.value + @pytest.mark.qualitative async def test_async_avalue(session): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + mot1, _ = session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext() + ) m1_final_val = await mot1.avalue() assert m1_final_val is not None assert m1_final_val == mot1.value + def test_client_cache(backend): first_client = backend._model async def get_client_async(): return backend._model - + second_client = asyncio.run(get_client_async()) items_in_cache = backend._client_cache.cache.values() - assert len(items_in_cache) == 2, "should be two clients in the cache since _async_client was called from two event loops" + assert len(items_in_cache) == 2, ( + "should be two clients in the cache since _async_client was called from two event loops" + ) assert first_client in items_in_cache assert second_client in items_in_cache third_client = backend._model - assert third_client == first_client, "clients in sync code should be the same if haven't been pushed out of the cache" + assert third_client == first_client, ( + "clients in sync code should be the same if haven't been pushed out of the cache" + ) fourth_client = asyncio.run(get_client_async()) assert fourth_client in backend._client_cache.cache.values() assert len(backend._client_cache.cache.values()) == 2 + if __name__ == "__main__": import pytest From bb03db977e24647d2ad0fb56c847d0d1c2dca317 Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Mon, 3 Nov 2025 10:47:18 -0800 Subject: [PATCH 5/7] feat: making generate_from_raw public in vllm --- mellea/backends/vllm.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/mellea/backends/vllm.py b/mellea/backends/vllm.py index b561baf0..51dcde74 100644 --- a/mellea/backends/vllm.py +++ b/mellea/backends/vllm.py @@ -393,15 +393,21 @@ async def post_processing( mot._generate_log = generate_log - def _generate_from_raw( + def generate_from_raw( self, actions: list[Component | CBlock], + ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, + tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" + if tool_calls: + FancyLogger.get_logger().warning( + "The completion endpoint does not support tool calling at the moment." 
+ ) + model_options = self._simplify_and_merge(model_options) prompts = [self.formatter.print(action) for action in actions] @@ -447,25 +453,21 @@ async def generate_all(prompts): for i, result in enumerate(results): self.formatter.parse(actions[i], result) - - if generate_logs is not None: - assert isinstance(generate_logs, list) date = datetime.datetime.now() - for i in range(len(prompts)): - generate_log = GenerateLog() - generate_log.prompt = prompts[i] - generate_log.backend = f"vllm::{self.model_id!s}" - generate_log.model_options = model_options - generate_log.date = date - generate_log.model_output = decoded_results - generate_log.extra = { - "format": format, - "seed": model_options.get(ModelOption.SEED, None), - } - generate_log.action = actions[i] - generate_log.result = results[i] - generate_logs.append(generate_log) + generate_log = GenerateLog() + generate_log.prompt = prompts[i] + generate_log.backend = f"vllm::{self.model_id!s}" + generate_log.model_options = model_options + generate_log.date = date + generate_log.model_output = decoded_results + generate_log.extra = { + "format": format, + "seed": model_options.get(ModelOption.SEED, None), + } + generate_log.action = actions[i] + generate_log.result = results[i] + result._generate_log = generate_log return results From 0226721519ec82ee2c49f184bedb89a67ee00652 Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Mon, 3 Nov 2025 10:48:49 -0800 Subject: [PATCH 6/7] adding warning to hf backend --- mellea/backends/huggingface.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 83bb0643..21175c10 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -507,6 +507,11 @@ def generate_from_raw( tool_calls: bool = False, ) -> list[ModelOutputThunk]: """Generate using the completions api. Gives the input provided to the model without templating.""" + if tool_calls: + FancyLogger.get_logger().warning( + "The raw endpoint does not support tool calling at the moment." 
+ ) + model_opts = self._simplify_and_merge(model_options) seed = model_opts.get(ModelOption.SEED, None) if seed is not None: From cfc9d6c7d8edc857b788007abbeada0d8c343c82 Mon Sep 17 00:00:00 2001 From: Avinash Balakrishnan Date: Mon, 3 Nov 2025 11:25:06 -0800 Subject: [PATCH 7/7] attempt at standardizing usage metrics --- mellea/backends/huggingface.py | 7 +++++-- mellea/backends/ollama.py | 6 ++++++ mellea/backends/watsonx.py | 13 +++++++++++-- 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 21175c10..c5d9b0db 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -569,12 +569,15 @@ def generate_from_raw( results = [] for i, decoded_result in enumerate(decoded_results): + n_prompt_tokens = inputs["input_ids"][i].size(0) # type: ignore + n_completion_tokens = len(sequences_to_decode[i]) result = ModelOutputThunk( value=decoded_result, meta={ "usage": { - "prompt_tokens": inputs["input_ids"][i].size(0), # type: ignore - "completion_tokens": len(sequences_to_decode[i]), + "prompt_tokens": n_prompt_tokens, # type: ignore + "completion_tokens": n_completion_tokens, + "total_tokens": n_prompt_tokens + n_completion_tokens, } }, ) diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 7d55b081..02d6d620 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -448,6 +448,12 @@ async def get_response(): "usage": { "completion_tokens": response.eval_count, "prompt_tokens": response.prompt_eval_count, + "total_tokens": ( + response.prompt_eval_count + response.eval_count + if response.prompt_eval_count is not None + and response.eval_count is not None + else None + ), }, }, ) diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 4137a383..61192f0c 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -509,9 +509,18 @@ def generate_from_raw( date = datetime.datetime.now() for i, response in enumerate(responses): + output = response["results"][0] result = ModelOutputThunk( - value=response["results"][0]["generated_text"], - meta={"oai_completion_response": response["results"][0]}, + value=output["generated_text"], + meta={ + "oai_completion_response": response["results"][0], + "usage": { + "prompt_tokens": output.get("input_token_count", 0), + "completion_tokens": output.get("generated_token_count", 0), + "total_tokens": output.get("input_token_count", 0) + + output.get("generated_token_count", 0), + }, + }, ) self.formatter.parse(actions[i], result)
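
Usage sketch (reviewer note, not part of the patch series): the snippet below shows how the now-public generate_from_raw is expected to be called after these changes, and how the usage metadata standardized in PATCH 7/7 can be read back. Backend construction is elided (any of the patched backends should work, e.g. OpenAIBackend or WatsonxAIBackend built as in the test fixtures above), and the _meta attribute name is assumed from the OpenAI backend changes; treat this as an illustration under those assumptions, not as code introduced by the series.

from mellea.stdlib.base import CBlock, SimpleContext


def raw_batch(backend, prompts: list[str]):
    """Send a batch of raw (untemplated) prompts through a patched backend."""
    results = backend.generate_from_raw(
        actions=[CBlock(value=p) for p in prompts],
        ctx=SimpleContext(),  # accepted for signature parity; currently unused on this path
        tool_calls=False,     # backends only warn and ignore tool calls here
    )
    for result in results:
        # PATCH 7/7 standardizes token accounting under meta["usage"]; the
        # _meta attribute name is assumed from the OpenAI backend code above.
        usage = (getattr(result, "_meta", None) or {}).get("usage") or {}
        print(
            result.value,
            usage.get("prompt_tokens"),
            usage.get("completion_tokens"),
            usage.get("total_tokens"),
        )
    return results


# Example: raw_batch(backend, ["what is 1+1?", "what is 2+2?"])

Passing ctx keeps the call shape uniform with generate_from_context even though the raw path ignores it for now, and tool_calls mirrors the warn-and-ignore behavior added to each backend in this series.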