Merged · 17 commits
8 changes: 5 additions & 3 deletions mellea/backends/__init__.py
@@ -58,19 +58,21 @@ def generate_from_context(
...

@abc.abstractmethod
def _generate_from_raw(
def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
Contributor:
I don't understand why we are passing in a context if we are not using it. I think we should just leave it out of the function definition if we aren't going to use it for generation or append to it after the generation. I see that the docstring documents that it's not used, but I think that just adds extra confusion.

It was also my impression from the team meeting yesterday that we don't have a clear idea going forward of what we want context to eventually do here.

I think the one reason for keeping it would be to eventually utilize it. However, that is functionally an API change if we were to change the behavior of context. As a result, I think it would make sense to just break the function definition in that case.

Contributor (Author):
I agree we should remove both tools and ctx, but I only kept them because, in the original issue, the motivation was to keep the signature consistent with generate_from_context.
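
For illustration, a minimal sketch (not code from this PR) of what the trimmed abstract signature could look like if ctx and tool_calls were dropped as discussed; the Backend class name is a stand-in, and the mellea types are written as string annotations so the snippet stands alone:

```python
import abc


class Backend(abc.ABC):  # stand-in for the real abstract backend class
    @abc.abstractmethod
    def generate_from_raw(
        self,
        actions: "list[Component | CBlock]",
        *,
        format: "type[BaseModelSubclass] | None" = None,
        model_options: dict | None = None,
        generate_logs: "list[GenerateLog] | None" = None,
    ) -> "list[ModelOutputThunk]":
        """Generate from raw inputs; no context, templating, or tool calls."""
        ...
```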

*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generates a model output from the provided input. Does not use context or templates.

Args:
actions: list of actions to generate responses for. Each action is separate.
ctx: context passed to generation. Currently not used in generate_from_raw.
format: A response format to use for structured outputs / constrained decoding. Note: some backends do not support this parameter. They will log warnings and continue to generate.
model_options: Any model options to upsert into the defaults for this call.
generate_logs: a `GenerateLog` instance to add log information to.
tool_calls: Always set to false unless supported by backend.
"""
54 changes: 33 additions & 21 deletions mellea/backends/huggingface.py
@@ -497,15 +497,21 @@ async def post_processing(

mot._generate_log = generate_log

def _generate_from_raw(
def generate_from_raw(
Contributor:
This implementation of generate_from_raw keeps the current implementation's decision to be synchronous. Are we okay with not allowing async batching operations (and it may be that these batching interfaces don't support async; I haven't looked)?

Contributor (Author):
I think support for async is inconsistent here. Given that generate_from_raw is based on the /completions endpoint, which is deprecated, I think it's going to be hard to maintain that for all backends here.
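
If async access to the synchronous batched path is ever wanted, one option (a sketch under the assumption that the blocking call stays as-is, not something this PR implements) is to push it onto a worker thread:

```python
import asyncio


async def generate_from_raw_async(backend, actions, ctx, **kwargs):
    # Sketch: run the synchronous, batched generate_from_raw on a worker
    # thread so it does not block the event loop. `backend` is whatever
    # concrete backend instance is in play; keyword-only arguments such as
    # format and model_options are forwarded unchanged.
    return await asyncio.to_thread(backend.generate_from_raw, actions, ctx, **kwargs)
```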

self,
actions: list[Component | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generate using the completions api. Gives the input provided to the model without templating."""
if tool_calls:
FancyLogger.get_logger().warning(
"The raw endpoint does not support tool calling at the moment."
)

model_opts = self._simplify_and_merge(model_options)
seed = model_opts.get(ModelOption.SEED, None)
if seed is not None:
@@ -561,28 +567,34 @@ def _generate_from_raw(
sequences_to_decode, skip_special_tokens=True
)

results = [
ModelOutputThunk(value=decoded_result) for decoded_result in decoded_results
]
results = []
for i, decoded_result in enumerate(decoded_results):
n_prompt_tokens = inputs["input_ids"][i].size(0) # type: ignore
n_completion_tokens = len(sequences_to_decode[i])
result = ModelOutputThunk(
value=decoded_result,
meta={
"usage": {
"prompt_tokens": n_prompt_tokens, # type: ignore
"completion_tokens": n_completion_tokens,
"total_tokens": n_prompt_tokens + n_completion_tokens,
}
},
)

for i, result in enumerate(results):
self.formatter.parse(actions[i], result)

if generate_logs is not None:
assert isinstance(generate_logs, list)
date = datetime.datetime.now()

for i in range(len(prompts)):
generate_log = GenerateLog()
generate_log.prompt = prompts[i]
generate_log.backend = f"hf::{self.model_id!s}"
generate_log.model_options = model_opts
generate_log.date = date
generate_log.model_output = decoded_results
generate_log.extra = {"format": format, "seed": seed}
generate_log.action = actions[i]
generate_log.result = results[i]
generate_logs.append(generate_log)
generate_log = GenerateLog()
generate_log.prompt = self.formatter.print(actions[i])
generate_log.backend = f"hf::{self.model_id!s}"
generate_log.model_options = model_opts
generate_log.date = datetime.datetime.now()
generate_log.model_output = decoded_result
generate_log.extra = {"format": format, "seed": seed}
generate_log.action = actions[i]

result._generate_log = generate_log
results.append(result)

return results

5 changes: 3 additions & 2 deletions mellea/backends/litellm.py
@@ -475,13 +475,14 @@ def _extract_tools(
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
return tools

def _generate_from_raw(
def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generate using the completions api. Gives the input provided to the model without templating."""
raise NotImplementedError("This method is not implemented yet.")
Comment on lines 487 to 488
Contributor:
I think if we are promoting this to be a supported API, we should add support for LiteLLM.

Contributor (Author):
Fair, but I might make a separate PR for this just to unblock this issue.

62 changes: 38 additions & 24 deletions mellea/backends/ollama.py
@@ -385,19 +385,24 @@ def generate_from_chat_context(

return output

def _generate_from_raw(
def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generate using the generate api. Gives the input provided to the model without templating."""
if len(actions) > 1:
FancyLogger.get_logger().info(
"Ollama doesn't support batching; will attempt to process concurrently."
)
if tool_calls:
FancyLogger.get_logger().warning(
"The completion endpoint does not support tool calling at the moment."
)

model_opts = self._simplify_and_merge(model_options)

@@ -438,32 +443,41 @@ async def get_response():
else:
result = ModelOutputThunk(
value=response.response,
meta={"generate_response": response.model_dump()},
meta={
"generate_response": response.model_dump(),
"usage": {
"completion_tokens": response.eval_count,
"prompt_tokens": response.prompt_eval_count,
"total_tokens": (
response.prompt_eval_count + response.eval_count
if response.prompt_eval_count is not None
and response.eval_count is not None
else None
),
},
},
)

self.formatter.parse(actions[i], result)
results.append(result)

if generate_logs is not None:
# noinspection DuplicatedCode
assert isinstance(generate_logs, list)
generate_log = GenerateLog()
generate_log.prompt = prompts[i]
generate_log.backend = f"ollama::{self.model_id!s}"
generate_log.date = date
generate_log.model_options = model_opts
generate_log.model_output = result.value
generate_log.extra = {
"format": format,
"thinking": model_opts.get(ModelOption.THINKING, None),
"seed": model_opts.get(ModelOption.SEED, None),
}
generate_log.action = actions[i]
generate_log.result = result

if error:
generate_log.extra["error"] = error
generate_logs.append(generate_log)
generate_log = GenerateLog()
generate_log.prompt = prompts[i]
generate_log.backend = f"ollama::{self.model_id!s}"
generate_log.date = date
generate_log.model_options = model_opts
generate_log.model_output = result.value
generate_log.extra = {
"format": format,
"thinking": model_opts.get(ModelOption.THINKING, None),
"seed": model_opts.get(ModelOption.SEED, None),
}
generate_log.action = actions[i]

if error:
generate_log.extra["error"] = error
result._generate_log = generate_log

results.append(result)

return results

66 changes: 37 additions & 29 deletions mellea/backends/openai.py
@@ -8,7 +8,7 @@
import json
from collections.abc import Callable, Coroutine
from enum import Enum
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, overload
from urllib.parse import urlparse

import openai
@@ -646,13 +646,14 @@ async def post_processing(
generate_log.result = mot
mot._generate_log = generate_log

def _generate_from_raw(
def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generate using the completions api. Gives the input provided to the model without templating."""
extra_body = {}
@@ -664,6 +665,10 @@ def _generate_from_raw(

# Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests.
extra_body["guided_json"] = format.model_json_schema()
if tool_calls:
FancyLogger.get_logger().warning(
"The completion endpoint does not support tool calling at the moment."
)

model_opts = self._simplify_and_merge(model_options, is_chat_context=False)

@@ -689,32 +694,35 @@ def _generate_from_raw(
# Necessary for type checker.
assert isinstance(completion_response, Completion)

results = [
ModelOutputThunk(
value=response.text,
meta={"oai_completion_response": response.model_dump()},
)
for response in completion_response.choices
]

for i, result in enumerate(results):
self.formatter.parse(actions[i], result)

if generate_logs is not None:
assert isinstance(generate_logs, list)
date = datetime.datetime.now()

for i in range(len(prompts)):
generate_log = GenerateLog()
generate_log.prompt = prompts[i]
generate_log.backend = f"openai::{self.model_id!s}"
generate_log.model_options = model_opts
generate_log.date = date
generate_log.model_output = completion_response
generate_log.extra = {"seed": model_opts.get("seed", None)}
generate_log.action = actions[i]
generate_log.result = results[i]
generate_logs.append(generate_log)
results = []
for response, action, prompt in zip(
completion_response.choices, actions, prompts
):
output = ModelOutputThunk(None)
output.value = response.text
output._context = None # There is no context for generate_from_raw for now
output._action = action
output._model_options = model_opts
output._meta = {
"oai_completion_response": response.model_dump(),
"usage": completion_response.usage.model_dump()
if completion_response.usage
else None,
}

Comment on lines +706 to +712
Contributor:
It looks like LiteLLM has standardized on:

  "usage": {
    "prompt_tokens": 18,
    "completion_tokens": 25,
    "total_tokens": 43
  }

Even if OpenAI doesn't give us back correct values, it might still be worthwhile to standardize these fields in ModelOutputThunks and just not set them for OpenAI. Otherwise, anyone that has to use these values needs to be aware of every backend (instead of just checking if the values are None).

Contributor (Author):
This is similar to what OpenAI returns. I think it should be part of the .usage object.
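
One way to get that consistency (illustrative only; normalize_usage is a hypothetical helper, not part of this PR) is to map whatever each backend reports onto the three common fields and leave missing values as None:

```python
def normalize_usage(raw: dict | None) -> dict:
    """Map a backend-specific usage payload onto the common token fields."""
    raw = raw or {}
    prompt = raw.get("prompt_tokens")
    completion = raw.get("completion_tokens")
    total = raw.get("total_tokens")
    if total is None and prompt is not None and completion is not None:
        total = prompt + completion
    return {
        "prompt_tokens": prompt,
        "completion_tokens": completion,
        "total_tokens": total,
    }


# e.g. Ollama-style counts, which omit total_tokens:
print(normalize_usage({"prompt_tokens": 12, "completion_tokens": 34}))
# -> {'prompt_tokens': 12, 'completion_tokens': 34, 'total_tokens': 46}
```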

self.formatter.parse(action, output)

generate_log = GenerateLog()
generate_log.prompt = prompt
generate_log.backend = f"openai::{self.model_id!s}"
generate_log.model_options = model_opts
generate_log.date = datetime.datetime.now()
generate_log.model_output = completion_response
generate_log.extra = {"seed": model_opts.get("seed", None)}
generate_log.action = action
output._generate_log = generate_log

results.append(output)

return results

40 changes: 21 additions & 19 deletions mellea/backends/vllm.py
@@ -393,15 +393,21 @@ async def post_processing(

mot._generate_log = generate_log

def _generate_from_raw(
def generate_from_raw(
self,
actions: list[Component | CBlock],
ctx: Context,
*,
format: type[BaseModelSubclass] | None = None,
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
) -> list[ModelOutputThunk]:
"""Generate using the completions api. Gives the input provided to the model without templating."""
if tool_calls:
FancyLogger.get_logger().warning(
"The completion endpoint does not support tool calling at the moment."
)

model_options = self._simplify_and_merge(model_options)

prompts = [self.formatter.print(action) for action in actions]
@@ -447,25 +453,21 @@ async def generate_all(prompts):

for i, result in enumerate(results):
self.formatter.parse(actions[i], result)

if generate_logs is not None:
assert isinstance(generate_logs, list)
date = datetime.datetime.now()

for i in range(len(prompts)):
generate_log = GenerateLog()
generate_log.prompt = prompts[i]
generate_log.backend = f"vllm::{self.model_id!s}"
generate_log.model_options = model_options
generate_log.date = date
generate_log.model_output = decoded_results
generate_log.extra = {
"format": format,
"seed": model_options.get(ModelOption.SEED, None),
}
generate_log.action = actions[i]
generate_log.result = results[i]
generate_logs.append(generate_log)
generate_log = GenerateLog()
generate_log.prompt = prompts[i]
generate_log.backend = f"vllm::{self.model_id!s}"
generate_log.model_options = model_options
generate_log.date = date
generate_log.model_output = decoded_results
generate_log.extra = {
"format": format,
"seed": model_options.get(ModelOption.SEED, None),
}
generate_log.action = actions[i]
generate_log.result = results[i]
result._generate_log = generate_log

return results
