
Commit 7eae224

feat: making generate_from_raw public (#219)
* feat: making generate_from_raw public in openai
* feat: making generate_from_raw public in ollama
* feat: making generate_from_raw public in hf, watsonx
* feat: making generate_from_raw public in vllm
* adding warning to hf backend
* attempt at standardizing usage metrics
1 parent 6aea9dc commit 7eae224

File tree

12 files changed: +301 −179 lines changed


mellea/backends/__init__.py

Lines changed: 5 additions & 3 deletions
@@ -58,19 +58,21 @@ def generate_from_context(
         ...
 
     @abc.abstractmethod
-    def _generate_from_raw(
+    def generate_from_raw(
         self,
         actions: list[Component | CBlock],
+        ctx: Context,
         *,
         format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
-        generate_logs: list[GenerateLog] | None = None,
+        tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generates a model output from the provided input. Does not use context or templates.
 
         Args:
             actions: list of actions to generate responses for. Each action is separate.
+            ctx: context passed to generation. Currently not used in generate_from_raw
             format: A response format to used for structured outputs / constrained decoding. Note: some backends do not support this parameter. They will log warnings and continue to generate.
             model_options: Any model options to upsert into the defaults for this call.
-            generate_logs: a `GenerateLog` instance to add log information to.
+            tool_calls: Always set to false unless supported by backend.
         """

mellea/backends/huggingface.py

Lines changed: 33 additions & 21 deletions
@@ -497,15 +497,21 @@ async def post_processing(
 
         mot._generate_log = generate_log
 
-    def _generate_from_raw(
+    def generate_from_raw(
         self,
         actions: list[Component | CBlock],
+        ctx: Context,
         *,
         format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
-        generate_logs: list[GenerateLog] | None = None,
+        tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
+        if tool_calls:
+            FancyLogger.get_logger().warning(
+                "The raw endpoint does not support tool calling at the moment."
+            )
+
         model_opts = self._simplify_and_merge(model_options)
         seed = model_opts.get(ModelOption.SEED, None)
         if seed is not None:
@@ -561,28 +567,34 @@ def _generate_from_raw(
             sequences_to_decode, skip_special_tokens=True
         )
 
-        results = [
-            ModelOutputThunk(value=decoded_result) for decoded_result in decoded_results
-        ]
+        results = []
+        for i, decoded_result in enumerate(decoded_results):
+            n_prompt_tokens = inputs["input_ids"][i].size(0)  # type: ignore
+            n_completion_tokens = len(sequences_to_decode[i])
+            result = ModelOutputThunk(
+                value=decoded_result,
+                meta={
+                    "usage": {
+                        "prompt_tokens": n_prompt_tokens,  # type: ignore
+                        "completion_tokens": n_completion_tokens,
+                        "total_tokens": n_prompt_tokens + n_completion_tokens,
+                    }
+                },
+            )
 
-        for i, result in enumerate(results):
             self.formatter.parse(actions[i], result)
 
-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            date = datetime.datetime.now()
-
-            for i in range(len(prompts)):
-                generate_log = GenerateLog()
-                generate_log.prompt = prompts[i]
-                generate_log.backend = f"hf::{self.model_id!s}"
-                generate_log.model_options = model_opts
-                generate_log.date = date
-                generate_log.model_output = decoded_results
-                generate_log.extra = {"format": format, "seed": seed}
-                generate_log.action = actions[i]
-                generate_log.result = results[i]
-                generate_logs.append(generate_log)
+            generate_log = GenerateLog()
+            generate_log.prompt = self.formatter.print(actions[i])
+            generate_log.backend = f"hf::{self.model_id!s}"
+            generate_log.model_options = model_opts
+            generate_log.date = datetime.datetime.now()
+            generate_log.model_output = decoded_result
+            generate_log.extra = {"format": format, "seed": seed}
+            generate_log.action = actions[i]
+
+            result._generate_log = generate_log
+            results.append(result)
 
         return results
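
The standardized `usage` entry carries the familiar prompt/completion/total token counts. A small standalone illustration of how the HF backend's counts combine (toy values, not library code):

```python
# Toy illustration of the usage dict each backend now attaches to a thunk's
# metadata: prompt tokens from the tokenized input row, completion tokens from
# the generated sequence, total as their sum.
prompt_ids = [101, 2023, 2003, 1037, 3231, 102]   # pretend input_ids row
generated_ids = [1996, 3437, 2003, 2871]          # pretend newly generated tokens

usage = {
    "prompt_tokens": len(prompt_ids),
    "completion_tokens": len(generated_ids),
    "total_tokens": len(prompt_ids) + len(generated_ids),
}
assert usage == {"prompt_tokens": 6, "completion_tokens": 4, "total_tokens": 10}
```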

mellea/backends/litellm.py

Lines changed: 3 additions & 2 deletions
@@ -475,13 +475,14 @@ def _extract_tools(
         FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
         return tools
 
-    def _generate_from_raw(
+    def generate_from_raw(
         self,
         actions: list[Component | CBlock],
+        ctx: Context,
         *,
         format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
-        generate_logs: list[GenerateLog] | None = None,
+        tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
         raise NotImplementedError("This method is not implemented yet.")

mellea/backends/ollama.py

Lines changed: 38 additions & 24 deletions
@@ -385,19 +385,24 @@ def generate_from_chat_context(
 
         return output
 
-    def _generate_from_raw(
+    def generate_from_raw(
         self,
         actions: list[Component | CBlock],
+        ctx: Context,
         *,
         format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
-        generate_logs: list[GenerateLog] | None = None,
+        tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the generate api. Gives the input provided to the model without templating."""
         if len(actions) > 1:
             FancyLogger.get_logger().info(
                 "Ollama doesn't support batching; will attempt to process concurrently."
             )
+        if tool_calls:
+            FancyLogger.get_logger().warning(
+                "The completion endpoint does not support tool calling at the moment."
+            )
 
         model_opts = self._simplify_and_merge(model_options)
 
@@ -438,32 +443,41 @@ async def get_response():
         else:
             result = ModelOutputThunk(
                 value=response.response,
-                meta={"generate_response": response.model_dump()},
+                meta={
+                    "generate_response": response.model_dump(),
+                    "usage": {
+                        "completion_tokens": response.eval_count,
+                        "prompt_tokens": response.prompt_eval_count,
+                        "total_tokens": (
+                            response.prompt_eval_count + response.eval_count
+                            if response.prompt_eval_count is not None
+                            and response.eval_count is not None
+                            else None
+                        ),
+                    },
+                },
             )
 
         self.formatter.parse(actions[i], result)
-        results.append(result)
 
-        if generate_logs is not None:
-            # noinspection DuplicatedCode
-            assert isinstance(generate_logs, list)
-            generate_log = GenerateLog()
-            generate_log.prompt = prompts[i]
-            generate_log.backend = f"ollama::{self.model_id!s}"
-            generate_log.date = date
-            generate_log.model_options = model_opts
-            generate_log.model_output = result.value
-            generate_log.extra = {
-                "format": format,
-                "thinking": model_opts.get(ModelOption.THINKING, None),
-                "seed": model_opts.get(ModelOption.SEED, None),
-            }
-            generate_log.action = actions[i]
-            generate_log.result = result
-
-            if error:
-                generate_log.extra["error"] = error
-            generate_logs.append(generate_log)
+        generate_log = GenerateLog()
+        generate_log.prompt = prompts[i]
+        generate_log.backend = f"ollama::{self.model_id!s}"
+        generate_log.date = date
+        generate_log.model_options = model_opts
+        generate_log.model_output = result.value
+        generate_log.extra = {
+            "format": format,
+            "thinking": model_opts.get(ModelOption.THINKING, None),
+            "seed": model_opts.get(ModelOption.SEED, None),
+        }
+        generate_log.action = actions[i]
+
+        if error:
+            generate_log.extra["error"] = error
+        result._generate_log = generate_log
+
+        results.append(result)
 
         return results
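
The Ollama path guards the total against missing counts, since `prompt_eval_count` or `eval_count` can be absent on a response. The same logic in isolation (a runnable sketch, not library code):

```python
# Total tokens are only computed when both counts are present; otherwise None,
# mirroring the conditional expression in the diff above.
def total_tokens(prompt_eval_count: int | None, eval_count: int | None) -> int | None:
    if prompt_eval_count is not None and eval_count is not None:
        return prompt_eval_count + eval_count
    return None

assert total_tokens(12, 30) == 42
assert total_tokens(None, 30) is None
```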

mellea/backends/openai.py

Lines changed: 37 additions & 29 deletions
@@ -8,7 +8,7 @@
 import json
 from collections.abc import Callable, Coroutine
 from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, overload
 from urllib.parse import urlparse
 
 import openai
@@ -646,13 +646,14 @@ async def post_processing(
         generate_log.result = mot
         mot._generate_log = generate_log
 
-    def _generate_from_raw(
+    def generate_from_raw(
        self,
        actions: list[Component | CBlock],
+       ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
-       generate_logs: list[GenerateLog] | None = None,
+       tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
         extra_body = {}
@@ -664,6 +665,10 @@ def _generate_from_raw(
 
             # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests.
             extra_body["guided_json"] = format.model_json_schema()
+        if tool_calls:
+            FancyLogger.get_logger().warning(
+                "The completion endpoint does not support tool calling at the moment."
+            )
 
         model_opts = self._simplify_and_merge(model_options, is_chat_context=False)
 
@@ -689,32 +694,35 @@ def _generate_from_raw(
         # Necessary for type checker.
         assert isinstance(completion_response, Completion)
 
-        results = [
-            ModelOutputThunk(
-                value=response.text,
-                meta={"oai_completion_response": response.model_dump()},
-            )
-            for response in completion_response.choices
-        ]
-
-        for i, result in enumerate(results):
-            self.formatter.parse(actions[i], result)
-
-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            date = datetime.datetime.now()
-
-            for i in range(len(prompts)):
-                generate_log = GenerateLog()
-                generate_log.prompt = prompts[i]
-                generate_log.backend = f"openai::{self.model_id!s}"
-                generate_log.model_options = model_opts
-                generate_log.date = date
-                generate_log.model_output = completion_response
-                generate_log.extra = {"seed": model_opts.get("seed", None)}
-                generate_log.action = actions[i]
-                generate_log.result = results[i]
-                generate_logs.append(generate_log)
+        results = []
+        for response, action, prompt in zip(
+            completion_response.choices, actions, prompts
+        ):
+            output = ModelOutputThunk(None)
+            output.value = response.text
+            output._context = None  # There is no context for generate_from_raw for now
+            output._action = action
+            output._model_options = model_opts
+            output._meta = {
+                "oai_completion_response": response.model_dump(),
+                "usage": completion_response.usage.model_dump()
+                if completion_response.usage
+                else None,
+            }
+
+            self.formatter.parse(action, output)
+
+            generate_log = GenerateLog()
+            generate_log.prompt = prompt
+            generate_log.backend = f"openai::{self.model_id!s}"
+            generate_log.model_options = model_opts
+            generate_log.date = datetime.datetime.now()
+            generate_log.model_output = completion_response
+            generate_log.extra = {"seed": model_opts.get("seed", None)}
+            generate_log.action = action
+            output._generate_log = generate_log
+
+            results.append(output)
 
         return results
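
Across backends, the removed `generate_logs` parameter is replaced by attaching a `GenerateLog` to each returned thunk. A sketch of how caller code changes (the `backend`, `actions`, and `ctx` names, and reading the private `_generate_log` attribute, are assumptions for illustration):

```python
# Old pattern (removed): collect logs by passing a list into the call.
#   logs: list[GenerateLog] = []
#   results = backend._generate_from_raw(actions, generate_logs=logs)
#
# New pattern: each thunk carries its own log, set by the backend in this commit.
results = backend.generate_from_raw(actions, ctx)
for thunk in results:
    log = thunk._generate_log                 # populated per result
    if log is not None:
        print(log.backend, log.date, log.extra.get("seed"))
```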

mellea/backends/vllm.py

Lines changed: 21 additions & 19 deletions
@@ -393,15 +393,21 @@ async def post_processing(
 
         mot._generate_log = generate_log
 
-    def _generate_from_raw(
+    def generate_from_raw(
         self,
         actions: list[Component | CBlock],
+        ctx: Context,
         *,
         format: type[BaseModelSubclass] | None = None,
         model_options: dict | None = None,
-        generate_logs: list[GenerateLog] | None = None,
+        tool_calls: bool = False,
     ) -> list[ModelOutputThunk]:
         """Generate using the completions api. Gives the input provided to the model without templating."""
+        if tool_calls:
+            FancyLogger.get_logger().warning(
+                "The completion endpoint does not support tool calling at the moment."
+            )
+
         model_options = self._simplify_and_merge(model_options)
 
         prompts = [self.formatter.print(action) for action in actions]
@@ -447,25 +453,21 @@ async def generate_all(prompts):
 
         for i, result in enumerate(results):
             self.formatter.parse(actions[i], result)
-
-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
             date = datetime.datetime.now()
 
-            for i in range(len(prompts)):
-                generate_log = GenerateLog()
-                generate_log.prompt = prompts[i]
-                generate_log.backend = f"vllm::{self.model_id!s}"
-                generate_log.model_options = model_options
-                generate_log.date = date
-                generate_log.model_output = decoded_results
-                generate_log.extra = {
-                    "format": format,
-                    "seed": model_options.get(ModelOption.SEED, None),
-                }
-                generate_log.action = actions[i]
-                generate_log.result = results[i]
-                generate_logs.append(generate_log)
+            generate_log = GenerateLog()
+            generate_log.prompt = prompts[i]
+            generate_log.backend = f"vllm::{self.model_id!s}"
+            generate_log.model_options = model_options
+            generate_log.date = date
+            generate_log.model_output = decoded_results
+            generate_log.extra = {
+                "format": format,
+                "seed": model_options.get(ModelOption.SEED, None),
+            }
+            generate_log.action = actions[i]
+            generate_log.result = results[i]
+            result._generate_log = generate_log
 
         return results

0 commit comments