feat: making generate_from_raw public #219
File 1: Hugging Face (hf) backend

```diff
@@ -497,15 +497,21 @@ async def post_processing(
        mot._generate_log = generate_log

-    def _generate_from_raw(
+    def generate_from_raw(
```
Contributor:
This implementation of generate_from_raw continues the current implementation's decision to be synchronous. Are we okay with not allowing async batching operations (and it might be the case that these batching interfaces don't support async, I haven't looked)?

Contributor (author):
I think support of async is inconsistent here. Given that …
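While the PR keeps generate_from_raw synchronous, async callers can still avoid blocking the event loop. The sketch below is a hypothetical illustration, not part of this PR; backend, actions, and ctx are assumed to be objects matching the signature shown in the diff.

```python
import asyncio

# Hypothetical helper: offload the synchronous, batched generate_from_raw call
# to a worker thread so an async caller is not blocked while the model runs.
async def generate_raw_async(backend, actions, ctx, **kwargs):
    return await asyncio.to_thread(backend.generate_from_raw, actions, ctx, **kwargs)
```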
```diff
        self,
        actions: list[Component | CBlock],
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ) -> list[ModelOutputThunk]:
        """Generate using the completions api. Gives the input provided to the model without templating."""
+        if tool_calls:
+            FancyLogger.get_logger().warning(
+                "The raw endpoint does not support tool calling at the moment."
+            )

        model_opts = self._simplify_and_merge(model_options)
        seed = model_opts.get(ModelOption.SEED, None)
        if seed is not None:
```
```diff
@@ -561,28 +567,34 @@ def _generate_from_raw(
            sequences_to_decode, skip_special_tokens=True
        )

-        results = [
-            ModelOutputThunk(value=decoded_result) for decoded_result in decoded_results
-        ]
+        results = []
+        for i, decoded_result in enumerate(decoded_results):
+            n_prompt_tokens = inputs["input_ids"][i].size(0)  # type: ignore
+            n_completion_tokens = len(sequences_to_decode[i])
+            result = ModelOutputThunk(
+                value=decoded_result,
+                meta={
+                    "usage": {
+                        "prompt_tokens": n_prompt_tokens,  # type: ignore
+                        "completion_tokens": n_completion_tokens,
+                        "total_tokens": n_prompt_tokens + n_completion_tokens,
+                    }
+                },
+            )

-        for i, result in enumerate(results):
            self.formatter.parse(actions[i], result)

-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            date = datetime.datetime.now()

-            for i in range(len(prompts)):
-                generate_log = GenerateLog()
-                generate_log.prompt = prompts[i]
-                generate_log.backend = f"hf::{self.model_id!s}"
-                generate_log.model_options = model_opts
-                generate_log.date = date
-                generate_log.model_output = decoded_results
-                generate_log.extra = {"format": format, "seed": seed}
-                generate_log.action = actions[i]
-                generate_log.result = results[i]
-                generate_logs.append(generate_log)
+            generate_log = GenerateLog()
+            generate_log.prompt = self.formatter.print(actions[i])
+            generate_log.backend = f"hf::{self.model_id!s}"
+            generate_log.model_options = model_opts
+            generate_log.date = datetime.datetime.now()
+            generate_log.model_output = decoded_result
+            generate_log.extra = {"format": format, "seed": seed}
+            generate_log.action = actions[i]

+            result._generate_log = generate_log
+            results.append(result)

        return results
```
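The practical effect of the new per-result construction is that each returned thunk now carries token accounting. A hedged usage sketch follows; the backend construction, the option key, and the exact metadata attribute are assumptions for illustration, not taken from this diff.

```python
# Hypothetical usage; `backend` is a Hugging Face backend instance and the
# inputs are raw CBlocks (generate_from_raw applies no chat templating).
results = backend.generate_from_raw(
    [prompt_block_a, prompt_block_b],
    ctx,
    model_options={"seed": 42},  # assumed option key
)

for r in results:
    usage = r.meta["usage"]  # token counts added by this PR for the hf backend
    print(r.value, usage["prompt_tokens"], usage["completion_tokens"], usage["total_tokens"])
```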
File 2: LiteLLM backend

```diff
@@ -475,13 +475,14 @@ def _extract_tools(
        FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")
        return tools

-    def _generate_from_raw(
+    def generate_from_raw(
        self,
        actions: list[Component | CBlock],
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ) -> list[ModelOutputThunk]:
        """Generate using the completions api. Gives the input provided to the model without templating."""
        raise NotImplementedError("This method is not implemented yet.")
```
Comment on lines 487 to 488:

Contributor:
I think if we are promoting this to be a supported API, we should add support for LiteLLM.

Contributor (author):
Fair, but I might make a separate PR for this just to unblock this issue.
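For reference, a very rough sketch of what the requested LiteLLM support could look like (not part of this PR): litellm exposes a text_completion entry point that mirrors the OpenAI completions API. The method body and response handling below are assumptions for illustration only; the formatter/ModelOutputThunk wiring mirrors the other backends in this diff.

```python
import litellm  # assumes the litellm package is installed

def generate_from_raw_sketch(self, actions, ctx, *, format=None,
                             model_options=None, generate_logs=None,
                             tool_calls=False):
    """Rough sketch only: batch raw completions through litellm.text_completion."""
    prompts = [self.formatter.print(action) for action in actions]

    results = []
    for action, prompt in zip(actions, prompts):
        # text_completion mirrors the OpenAI completions endpoint; structured
        # decoding and tool calling are intentionally ignored in this sketch.
        response = litellm.text_completion(model=self.model_id, prompt=prompt)
        output = ModelOutputThunk(value=response.choices[0].text)
        self.formatter.parse(action, output)
        results.append(output)
    return results
```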
File 3: OpenAI backend

```diff
@@ -8,7 +8,7 @@
import json
from collections.abc import Callable, Coroutine
from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, overload
from urllib.parse import urlparse

import openai
```

```diff
@@ -646,13 +646,14 @@ async def post_processing(
            generate_log.result = mot
            mot._generate_log = generate_log

-    def _generate_from_raw(
+    def generate_from_raw(
        self,
        actions: list[Component | CBlock],
        ctx: Context,
        *,
        format: type[BaseModelSubclass] | None = None,
        model_options: dict | None = None,
        generate_logs: list[GenerateLog] | None = None,
        tool_calls: bool = False,
    ) -> list[ModelOutputThunk]:
        """Generate using the completions api. Gives the input provided to the model without templating."""
        extra_body = {}
```

```diff
@@ -664,6 +665,10 @@ def _generate_from_raw(
            # Some versions (like vllm's version) of the OpenAI API support structured decoding for completions requests.
            extra_body["guided_json"] = format.model_json_schema()
+        if tool_calls:
+            FancyLogger.get_logger().warning(
+                "The completion endpoint does not support tool calling at the moment."
+            )

        model_opts = self._simplify_and_merge(model_options, is_chat_context=False)
```
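The guided_json path above is pre-existing, but it is what makes the format argument useful on this endpoint. A hedged usage sketch (the backend construction is a placeholder; only vLLM-style OpenAI-compatible servers honor guided_json for completions):

```python
from pydantic import BaseModel

class Verdict(BaseModel):  # hypothetical schema for illustration
    answer: str
    confidence: float

# Hypothetical call; `backend` is an OpenAI-compatible backend instance.
# Passing format= sends the JSON schema as extra_body["guided_json"], which
# vLLM-style completion servers can use for structured decoding.
results = backend.generate_from_raw([raw_prompt_block], ctx, format=Verdict)
```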
```diff
@@ -689,32 +694,35 @@ def _generate_from_raw(
        # Necessary for type checker.
        assert isinstance(completion_response, Completion)

-        results = [
-            ModelOutputThunk(
-                value=response.text,
-                meta={"oai_completion_response": response.model_dump()},
-            )
-            for response in completion_response.choices
-        ]

-        for i, result in enumerate(results):
-            self.formatter.parse(actions[i], result)

-        if generate_logs is not None:
-            assert isinstance(generate_logs, list)
-            date = datetime.datetime.now()

-            for i in range(len(prompts)):
-                generate_log = GenerateLog()
-                generate_log.prompt = prompts[i]
-                generate_log.backend = f"openai::{self.model_id!s}"
-                generate_log.model_options = model_opts
-                generate_log.date = date
-                generate_log.model_output = completion_response
-                generate_log.extra = {"seed": model_opts.get("seed", None)}
-                generate_log.action = actions[i]
-                generate_log.result = results[i]
-                generate_logs.append(generate_log)
+        results = []
+        for response, action, prompt in zip(
+            completion_response.choices, actions, prompts
+        ):
+            output = ModelOutputThunk(None)
+            output.value = response.text
+            output._context = None  # There is no context for generate_from_raw for now
+            output._action = action
+            output._model_options = model_opts
+            output._meta = {
+                "oai_completion_response": response.model_dump(),
+                "usage": completion_response.usage.model_dump()
+                if completion_response.usage
+                else None,
+            }
```
Comment on lines +706 to +712:

Contributor:
It looks like LiteLLM has standardized on: … Even if OpenAI doesn't give us back correct values, it might still be worthwhile to standardize these fields in ModelOutputThunks and just not set them for OpenAI. Otherwise, anyone that has to use these values needs to be aware of every backend (instead of just checking if the values are None).

Contributor (author):
This is similar to what openai returns. I think it should be part of the …
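A hedged sketch of the standardization being suggested (names and placement are hypothetical, not from the codebase): normalized token-count fields that every backend either fills or leaves as None, so callers never need backend-specific knowledge.

```python
from dataclasses import dataclass

@dataclass
class Usage:
    """Hypothetical normalized usage record, following the
    prompt/completion/total convention used by OpenAI and LiteLLM."""
    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    total_tokens: int | None = None

def usage_of(thunk) -> Usage:
    # Hypothetical accessor: backends that cannot report counts simply leave
    # the fields as None, so callers only ever check for None.
    raw = (getattr(thunk, "_meta", None) or {}).get("usage") or {}
    return Usage(
        prompt_tokens=raw.get("prompt_tokens"),
        completion_tokens=raw.get("completion_tokens"),
        total_tokens=raw.get("total_tokens"),
    )
```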
```diff
+            self.formatter.parse(action, output)

+            generate_log = GenerateLog()
+            generate_log.prompt = prompt
+            generate_log.backend = f"openai::{self.model_id!s}"
+            generate_log.model_options = model_opts
+            generate_log.date = datetime.datetime.now()
+            generate_log.model_output = completion_response
+            generate_log.extra = {"seed": model_opts.get("seed", None)}
+            generate_log.action = action
+            output._generate_log = generate_log

+            results.append(output)

        return results
```
Review comment:
I don't understand why we are passing in a context if we are not using it. I think we should just leave it out of the function definition if we aren't going to use it for generation or append to it after the generation. I see that the docstring documents that it's not used, but I think that just adds extra confusion.
It was also my impression from the team meeting yesterday that we don't have a clear idea going forward of what we want context to eventually do here.
I think the one reason for keeping it would be to eventually utilize it. However, that is functionally an API change if we were to change the behavior of context. As a result, I think it would make sense to just break the function def in that instance.
Reply:
I agree we should remove both tools and ctx, but I only kept them in because, in the original issue, the motivation was to keep the signature consistent with generate_from_context.
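To make the trade-off concrete, here is a hypothetical sketch of the signature the reviewers are proposing; parameter names and types are the ones used in the diff, but this is not code from the PR.

```python
from __future__ import annotations

# Reviewers' proposal, sketched: drop the unused ctx (and the unsupported
# tool_calls flag) from the public signature and reintroduce them only if raw
# generation ever starts consuming context. Lazy annotations keep this sketch
# runnable without importing the project types named below.
def generate_from_raw(
    self,
    actions: list[Component | CBlock],
    *,
    format: type[BaseModelSubclass] | None = None,
    model_options: dict | None = None,
    generate_logs: list[GenerateLog] | None = None,
) -> list[ModelOutputThunk]:
    ...
```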