Commit 1404b38

Enable ephemeral prompt caching with LangSmith metrics (#28)

* enable ephemeral prompt caching by default
* add detailed token usage info for LangServe UI
* disable prompt caching when using Bedrock Invoke API

1 parent 2959001

1 file changed: jupyter_ai_jupyternaut/jupyternaut/chat_models.py (84 additions, 43 deletions)
@@ -17,6 +17,7 @@
     Sequence,
     Tuple,
     Type,
+    TYPE_CHECKING,
     Union,
 )

@@ -64,6 +65,8 @@

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    from litellm import ModelResponseStream, Usage

 class ChatLiteLLMException(Exception):
     """Error with the `LiteLLM I/O` library"""
@@ -341,9 +344,30 @@ async def acompletion_with_retry(
         """Use tenacity to retry the async completion call."""
         retry_decorator = _create_retry_decorator(self, run_manager=run_manager)

+        # Enables ephemeral prompt caching of the last system message by
+        # default when passed to `litellm.acompletion()`.
+        #
+        # See: https://docs.litellm.ai/docs/tutorials/prompt_caching
+        cache_control_kwargs = {
+            "cache_control_injection_points": [
+                { "location": "message", "role": "system" }
+            ]
+        }
+
+        # Disable ephemeral prompt caching on Amazon Bedrock when the
+        # InvokeModel API is used instead of Converse API. This is motivated by
+        # an upstream bug in LiteLLM that has yet to be patched.
+        #
+        # See: github.com/BerriAI/litellm/issues/17479
+        if self.model.startswith("bedrock/") and not self.model.startswith("bedrock/converse/"):
+            cache_control_kwargs = {}
+
         @retry_decorator
         async def _completion_with_retry(**kwargs: Any) -> Any:
-            return await self.client.acompletion(**kwargs)
+            return await self.client.acompletion(
+                **kwargs,
+                **cache_control_kwargs,
+            )

         return await _completion_with_retry(**kwargs)

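For context, the injected kwargs above ask LiteLLM to tag the system message with an ephemeral `cache_control` block before the request is sent, as described in the linked tutorial. A minimal sketch of the same option used in a direct call, assuming `cache_control_injection_points` behaves as documented there (the model ID and prompts are placeholders):

import asyncio
import litellm

async def main() -> None:
    response = await litellm.acompletion(
        model="anthropic/claude-3-5-sonnet-20241022",  # placeholder model ID
        messages=[
            {"role": "system", "content": "You are Jupyternaut, a helpful assistant."},
            {"role": "user", "content": "Summarize this notebook."},
        ],
        # Ask LiteLLM to add {"type": "ephemeral"} cache_control to the system message.
        cache_control_injection_points=[
            {"location": "message", "role": "system"}
        ],
    )
    # Providers that support prompt caching report cache read/creation counts in usage.
    print(response.usage)

asyncio.run(main())

The prefix check in the hunk keeps these kwargs for non-Bedrock models and for `bedrock/converse/...` model IDs, and drops them for every other `bedrock/...` ID as a conservative workaround for the linked issue.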
@@ -456,30 +480,10 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-        params["stream_options"] = self.stream_options
-        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(
-            messages=message_dicts, run_manager=run_manager, **params
-        ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
-                continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
-            yield cg_chunk
+        # deleting this method minimizes code duplication.
+        # we can run `_astream()` in a `ThreadPoolExecutor` if we need to
+        # implement this method in the future.
+        raise NotImplementedError()

     async def _astream(
         self,
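The comment left in `_stream()` sketches the escape hatch: if a synchronous stream is ever needed, it can reuse `_astream()` instead of duplicating the loop. One possible bridge is sketched below; it drives the async generator on a private event loop and assumes the calling thread has no loop already running (otherwise the work would have to move to a worker thread, e.g. the `ThreadPoolExecutor` the comment mentions). The helper name and usage are hypothetical, not part of this commit.

import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")

def iterate_sync(agen: AsyncIterator[T]) -> Iterator[T]:
    """Drive an async iterator from synchronous code on a private event loop."""
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()

# Hypothetical use inside a future `_stream()` implementation:
#     yield from iterate_sync(self._astream(messages, stop, run_manager, **kwargs))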
@@ -491,25 +495,41 @@ async def _astream(
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
         params["stream_options"] = self.stream_options
+
+        # This local variable hints the type of successive chunks when a
+        # new chunk differs from the previous one in type.
+        # (unsure if this is required)
         default_chunk_class = AIMessageChunk
-        async for chunk in await self.acompletion_with_retry(
+
+        async for _untyped_chunk in await self.acompletion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
+            # LiteLLM chunk
+            litellm_chunk: ModelResponseStream = _untyped_chunk
+            # LiteLLM usage metadata
+            litellm_usage: Usage | None = getattr(litellm_chunk, 'usage', None)
+
+            # Continue (do nothing) if the chunk is empty
+            if len(litellm_chunk.choices) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
+
+            # Extract delta from chunk
+            delta = litellm_chunk.choices[0].delta
+
+            # Convert LiteLLM delta (litellm.Delta) to LangChain
+            # chunk (BaseMessageChunk)
+            message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+
+            # Append usage metadata if it exists
+            if litellm_usage and isinstance(message_chunk, AIMessageChunk):
+                message_chunk.usage_metadata = _create_usage_metadata(litellm_usage)
+
+            # Set type of successive chunks until a new chunk changes type
+            default_chunk_class = message_chunk.__class__
+
+            cg_chunk = ChatGenerationChunk(message=message_chunk)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+                await run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
             yield cg_chunk

     async def _agenerate(
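The per-chunk `usage_metadata` set above is what LangChain carries along when message chunks are merged, which is how the detailed token counts end up on the final `AIMessage` and become visible to tracing tools such as LangSmith. A small illustration with made-up chunks and token counts (not taken from this repository):

from langchain_core.messages import AIMessageChunk
from langchain_core.messages.ai import UsageMetadata

# Hypothetical chunks, shaped like the ones `_astream()` wraps in ChatGenerationChunk.
first = AIMessageChunk(content="Hello, ")
last = AIMessageChunk(
    content="world!",
    usage_metadata=UsageMetadata(input_tokens=12, output_tokens=4, total_tokens=16),
)

merged = first + last
print(merged.content)         # "Hello, world!"
print(merged.usage_metadata)  # usage carried over from the chunk that reported it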
@@ -612,11 +632,32 @@ def _llm_type(self) -> str:
         return "litellm-chat"


-def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
-    input_tokens = token_usage.get("prompt_tokens", 0)
-    output_tokens = token_usage.get("completion_tokens", 0)
+def _create_usage_metadata(usage: Usage) -> UsageMetadata:
+    """
+    Converts LiteLLM usage metadata object (`litellm.Usage`) into LangChain usage
+    metadata object (`langchain_core.messages.ai.UsageMetadata`).
+    """
+    input_tokens = usage.prompt_tokens or 0
+    input_audio_tokens = usage.prompt_tokens_details.audio_tokens or 0
+    output_tokens = usage.completion_tokens or 0
+    output_audio_tokens = usage.completion_tokens_details.audio_tokens or 0
+    output_reasoning_tokens = usage.completion_tokens_details.reasoning_tokens or 0
+    total_tokens = input_tokens + output_tokens
+
+    cache_creation_tokens = usage.prompt_tokens_details.cache_creation_tokens or 0
+    cache_read_tokens = usage.prompt_tokens_details.cached_tokens or 0
+
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
-        total_tokens=input_tokens + output_tokens,
+        total_tokens=total_tokens,
+        input_token_details={
+            "cache_creation": cache_creation_tokens,
+            "cache_read": cache_read_tokens,
+            "audio": input_audio_tokens,
+        },
+        output_token_details={
+            "audio": output_audio_tokens,
+            "reasoning": output_reasoning_tokens,
+        }
     )
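The detail dictionaries built here follow the `input_token_details` / `output_token_details` keys of LangChain's `UsageMetadata`, which is where cache and reasoning counts surface in LangSmith. As a sketch, for a hypothetical LiteLLM `Usage` reporting prompt_tokens=1200, completion_tokens=300, cached_tokens=1000, and cache_creation_tokens=0 (all numbers invented), the function above would produce a value shaped like:

from langchain_core.messages.ai import UsageMetadata

expected: UsageMetadata = {
    "input_tokens": 1200,
    "output_tokens": 300,
    "total_tokens": 1500,
    "input_token_details": {"cache_creation": 0, "cache_read": 1000, "audio": 0},
    "output_token_details": {"audio": 0, "reasoning": 0},
}
print(expected["input_token_details"]["cache_read"])  # 1000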
