@@ -17,6 +17,7 @@
     Sequence,
     Tuple,
     Type,
+    TYPE_CHECKING,
     Union,
 )
 
@@ -64,6 +65,8 @@
 
 logger = logging.getLogger(__name__)
 
+if TYPE_CHECKING:
+    from litellm import ModelResponseStream, Usage
 
 class ChatLiteLLMException(Exception):
     """Error with the `LiteLLM I/O` library"""
@@ -341,9 +344,30 @@ async def acompletion_with_retry(
341344 """Use tenacity to retry the async completion call."""
342345 retry_decorator = _create_retry_decorator (self , run_manager = run_manager )
343346
+        # Enable ephemeral prompt caching of the last system message by
+        # default; these kwargs are forwarded to `litellm.acompletion()`.
+        #
+        # See: https://docs.litellm.ai/docs/tutorials/prompt_caching
+        cache_control_kwargs = {
+            "cache_control_injection_points": [
+                {"location": "message", "role": "system"}
+            ]
+        }
+
+        # Disable ephemeral prompt caching on Amazon Bedrock when the
+        # InvokeModel API is used instead of the Converse API. This works
+        # around an upstream LiteLLM bug that has not yet been patched.
+        #
+        # See: github.com/BerriAI/litellm/issues/17479
+        if self.model.startswith("bedrock/") and not self.model.startswith("bedrock/converse/"):
+            cache_control_kwargs = {}
+
         @retry_decorator
         async def _completion_with_retry(**kwargs: Any) -> Any:
-            return await self.client.acompletion(**kwargs)
+            return await self.client.acompletion(
+                **kwargs,
+                **cache_control_kwargs,
+            )
 
         return await _completion_with_retry(**kwargs)
 
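For context on the new `cache_control_kwargs`, here is a rough standalone sketch (not part of the diff) of what the forwarded kwarg amounts to once it reaches `litellm.acompletion()`, following the LiteLLM prompt-caching tutorial linked above. The model name and messages are placeholders, and the second call shows the roughly equivalent manual `cache_control` markup.

```python
import asyncio

import litellm


async def main() -> None:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]

    # With the injection point, LiteLLM marks the system message with
    # `cache_control` before dispatching the request...
    await litellm.acompletion(
        model="anthropic/claude-3-5-sonnet-20241022",  # placeholder model
        messages=messages,
        cache_control_injection_points=[
            {"location": "message", "role": "system"},
        ],
    )

    # ...which is roughly equivalent to marking the message manually:
    await litellm.acompletion(
        model="anthropic/claude-3-5-sonnet-20241022",
        messages=[
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are a helpful assistant.",
                        "cache_control": {"type": "ephemeral"},
                    }
                ],
            },
            {"role": "user", "content": "Hello!"},
        ],
    )


if __name__ == "__main__":
    asyncio.run(main())
```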
@@ -456,30 +480,10 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-        params["stream_options"] = self.stream_options
-        default_chunk_class = AIMessageChunk
-        for chunk in self.completion_with_retry(
-            messages=message_dicts, run_manager=run_manager, **params
-        ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
-                continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
-            yield cg_chunk
+        # This implementation was removed to minimize duplication with
+        # `_astream()`. If a synchronous path is needed again, `_astream()`
+        # can be consumed from a worker-thread event loop (see the sketch
+        # after this hunk).
+        raise NotImplementedError()
 
     async def _astream(
         self,
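The comment in the new `_stream()` body suggests driving `_astream()` from a worker thread if a synchronous path is ever needed again. Below is a minimal sketch of that bridge; the helper name `iter_over_async` is hypothetical, and for brevity the sketch does not re-raise errors from the worker thread in the consumer.

```python
import asyncio
import threading
from queue import Queue
from typing import Any, AsyncIterator, Iterator

_SENTINEL = object()


def iter_over_async(agen: AsyncIterator[Any]) -> Iterator[Any]:
    """Consume an async iterator synchronously via a background event loop."""
    queue: "Queue[Any]" = Queue()

    async def _drain() -> None:
        try:
            async for item in agen:
                queue.put(item)
        finally:
            # Signal completion; exceptions raised by `agen` are not
            # propagated to the consumer in this simplified sketch.
            queue.put(_SENTINEL)

    thread = threading.Thread(target=lambda: asyncio.run(_drain()), daemon=True)
    thread.start()
    while (item := queue.get()) is not _SENTINEL:
        yield item
    thread.join()
```

A future `_stream()` could then yield from `iter_over_async(self._astream(...))`, at the cost of one thread per call and with the caveat that the sync and async paths use different callback-manager types.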
@@ -491,25 +495,41 @@ async def _astream(
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
         params["stream_options"] = self.stream_options
+
+        # Message-chunk class used when converting the next delta; updated
+        # after each conversion so successive chunks keep a consistent type
+        # (it is unclear whether this is strictly required).
         default_chunk_class = AIMessageChunk
-        async for chunk in await self.acompletion_with_retry(
+
+        async for _untyped_chunk in await self.acompletion_with_retry(
             messages=message_dicts, run_manager=run_manager, **params
         ):
-            usage_metadata = None
-            if not isinstance(chunk, dict):
-                chunk = chunk.model_dump()
-            if "usage" in chunk and chunk["usage"]:
-                usage_metadata = _create_usage_metadata(chunk["usage"])
-            if len(chunk["choices"]) == 0:
+            # Each streamed item from LiteLLM is a `ModelResponseStream`.
+            litellm_chunk: ModelResponseStream = _untyped_chunk
+            # Usage metadata, if the provider attached it to this chunk.
+            litellm_usage: Usage | None = getattr(litellm_chunk, "usage", None)
+
+            # Skip chunks that carry no choices.
+            if len(litellm_chunk.choices) == 0:
                 continue
-            delta = chunk["choices"][0]["delta"]
-            chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
-            if usage_metadata and isinstance(chunk, AIMessageChunk):
-                chunk.usage_metadata = usage_metadata
-            default_chunk_class = chunk.__class__
-            cg_chunk = ChatGenerationChunk(message=chunk)
+
+            # Extract the delta from the chunk.
+            delta = litellm_chunk.choices[0].delta
+
+            # Convert the LiteLLM delta (`litellm.Delta`) into a LangChain
+            # message chunk (`BaseMessageChunk`).
+            message_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
+
+            # Attach usage metadata when the provider reported it.
+            if litellm_usage and isinstance(message_chunk, AIMessageChunk):
+                message_chunk.usage_metadata = _create_usage_metadata(litellm_usage)
+
+            # Successive deltas default to the same chunk type.
+            default_chunk_class = message_chunk.__class__
+
+            cg_chunk = ChatGenerationChunk(message=message_chunk)
             if run_manager:
-                await run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
+                await run_manager.on_llm_new_token(message_chunk.content, chunk=cg_chunk)
             yield cg_chunk
 
     async def _agenerate(
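To see the streaming path end to end, a hedged usage sketch: it assumes `ChatLiteLLM` accepts the `stream_options` field referenced above as a constructor argument and that the provider reports usage on its final chunk; the model name is a placeholder.

```python
import asyncio

from langchain_community.chat_models import ChatLiteLLM


async def main() -> None:
    llm = ChatLiteLLM(
        model="anthropic/claude-3-5-sonnet-20241022",  # placeholder model
        stream_options={"include_usage": True},  # assumed field, see above
    )

    aggregate = None
    async for chunk in llm.astream("Write a haiku about caching."):
        # Each `chunk` is produced by `_astream()` above; AIMessageChunk
        # instances support `+` for aggregation.
        aggregate = chunk if aggregate is None else aggregate + chunk

    print(aggregate.content)
    # Includes cache_read / cache_creation / audio / reasoning details when
    # the provider reported them.
    print(aggregate.usage_metadata)


if __name__ == "__main__":
    asyncio.run(main())
```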
@@ -612,11 +632,32 @@ def _llm_type(self) -> str:
         return "litellm-chat"
 
 
-def _create_usage_metadata(token_usage: Mapping[str, Any]) -> UsageMetadata:
-    input_tokens = token_usage.get("prompt_tokens", 0)
-    output_tokens = token_usage.get("completion_tokens", 0)
+def _create_usage_metadata(usage: "Usage") -> UsageMetadata:
+    """Convert a LiteLLM usage object (`litellm.Usage`) into a LangChain usage
+    metadata object (`langchain_core.messages.ai.UsageMetadata`)."""
+    # The token-detail objects may be None for providers that omit them.
+    prompt_details = usage.prompt_tokens_details
+    completion_details = usage.completion_tokens_details
+
+    input_tokens = usage.prompt_tokens or 0
+    input_audio_tokens = (prompt_details and prompt_details.audio_tokens) or 0
+    output_tokens = usage.completion_tokens or 0
+    output_audio_tokens = (completion_details and completion_details.audio_tokens) or 0
+    output_reasoning_tokens = (completion_details and completion_details.reasoning_tokens) or 0
+    total_tokens = input_tokens + output_tokens
+    cache_creation_tokens = (prompt_details and prompt_details.cache_creation_tokens) or 0
+    cache_read_tokens = (prompt_details and prompt_details.cached_tokens) or 0
     return UsageMetadata(
         input_tokens=input_tokens,
         output_tokens=output_tokens,
-        total_tokens=input_tokens + output_tokens,
+        total_tokens=total_tokens,
+        input_token_details={
+            "cache_creation": cache_creation_tokens,
+            "cache_read": cache_read_tokens,
+            "audio": input_audio_tokens,
+        },
+        output_token_details={
+            "audio": output_audio_tokens,
+            "reasoning": output_reasoning_tokens,
+        },
     )
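As a quick check of the field mapping, a sketch using `SimpleNamespace` stand-ins for the LiteLLM objects so nothing provider-specific is required; the numbers are made up, and the import of the private helper assumes this module path.

```python
from types import SimpleNamespace

# Private helper defined in this module (import path assumed for the example).
from langchain_community.chat_models.litellm import _create_usage_metadata

fake_usage = SimpleNamespace(
    prompt_tokens=1200,
    completion_tokens=200,
    prompt_tokens_details=SimpleNamespace(
        audio_tokens=0, cached_tokens=1000, cache_creation_tokens=150
    ),
    completion_tokens_details=SimpleNamespace(audio_tokens=0, reasoning_tokens=40),
)

print(_create_usage_metadata(fake_usage))
# Expected shape (UsageMetadata is a TypedDict, i.e. a plain dict at runtime):
# {'input_tokens': 1200, 'output_tokens': 200, 'total_tokens': 1400,
#  'input_token_details': {'cache_creation': 150, 'cache_read': 1000, 'audio': 0},
#  'output_token_details': {'audio': 0, 'reasoning': 40}}
```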