Skip to content

Commit 9f672a4

Browse files
krrishdholakia and satendrakumar
authored and committed
Request Headers - support x-litellm-num-retries + Usage - support usage by model group (BerriAI#12890)
* feat(litellm_pre_call_utils.py): add num_retries to litellm data for backend call — allow user to pass in num retries via request headers
* test(test_litellm_pre_call_utils.py): add unit test
* docs(request_headers.md): document new request header
* fix(common_daily_activity.py): show spend breakdown by model group — partial fix for BerriAI#12887
* feat(new_usage.tsx): new tab switcher for viewing usage by model group vs. received model — closes BerriAI#12887
1 parent 9a3ea08 commit 9f672a4

File tree

8 files changed

+315
-137
lines changed

8 files changed

+315
-137
lines changed

docs/my-website/docs/proxy/request_headers.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ Special headers that are supported by LiteLLM.
1010

1111
`x-litellm-tags`: Optional[str]: A comma separated list (e.g. `tag1,tag2,tag3`) of tags to use for [tag-based routing](./tag_routing) **OR** [spend-tracking](./enterprise.md#tracking-spend-for-custom-tags).
1212

13+
`x-litellm-num-retries`: Optional[int]: The number of retries for the request.
14+
1315
## Anthropic Headers
1416

1517
`anthropic-version`: Optional[str]: The version of the Anthropic API to use.

litellm/proxy/_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2780,6 +2780,7 @@ class LitellmDataForBackendLLMCall(TypedDict, total=False):
27802780
organization: str
27812781
timeout: Optional[float]
27822782
user: Optional[str]
2783+
num_retries: Optional[int]
27832784

27842785

27852786
class JWTKeyItem(TypedDict, total=False):

litellm/proxy/litellm_pre_call_utils.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,16 @@ def _get_timeout_from_request(headers: dict) -> Optional[float]:
272272
return float(timeout_header)
273273
return None
274274

275+
@staticmethod
276+
def _get_num_retries_from_request(headers: dict) -> Optional[int]:
277+
"""
278+
Workaround for client request from Vercel's AI SDK.
279+
"""
280+
num_retries_header = headers.get("x-litellm-num-retries", None)
281+
if num_retries_header is not None:
282+
return int(num_retries_header)
283+
return None
284+
275285
@staticmethod
276286
def _get_forwardable_headers(
277287
headers: Union[Headers, dict],
@@ -407,6 +417,10 @@ def add_litellm_data_for_backend_llm_call(
407417
if timeout is not None:
408418
data["timeout"] = timeout
409419

420+
num_retries = LiteLLMProxyRequestSetup._get_num_retries_from_request(headers)
421+
if num_retries is not None:
422+
data["num_retries"] = num_retries
423+
410424
return data
411425

412426
@staticmethod
@@ -801,7 +815,10 @@ async def add_litellm_data_to_request( # noqa: PLR0915
801815
data[k] = v
802816

803817
# Add disabled callbacks from key metadata
804-
if user_api_key_dict.metadata and "litellm_disabled_callbacks" in user_api_key_dict.metadata:
818+
if (
819+
user_api_key_dict.metadata
820+
and "litellm_disabled_callbacks" in user_api_key_dict.metadata
821+
):
805822
disabled_callbacks = user_api_key_dict.metadata["litellm_disabled_callbacks"]
806823
if disabled_callbacks and isinstance(disabled_callbacks, list):
807824
data["litellm_disabled_callbacks"] = disabled_callbacks

litellm/proxy/management_endpoints/common_daily_activity.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,44 @@ def update_breakdown_metrics(
8080
)
8181
)
8282

83+
# Update model group breakdown
84+
if record.model_group and record.model_group not in breakdown.model_groups:
85+
breakdown.model_groups[record.model_group] = MetricWithMetadata(
86+
metrics=SpendMetrics(),
87+
metadata=model_metadata.get(record.model_group, {}),
88+
)
89+
if record.model_group:
90+
breakdown.model_groups[record.model_group].metrics = update_metrics(
91+
breakdown.model_groups[record.model_group].metrics, record
92+
)
93+
94+
# Update API key breakdown for this model
95+
if (
96+
record.api_key
97+
not in breakdown.model_groups[record.model_group].api_key_breakdown
98+
):
99+
breakdown.model_groups[record.model_group].api_key_breakdown[
100+
record.api_key
101+
] = KeyMetricWithMetadata(
102+
metrics=SpendMetrics(),
103+
metadata=KeyMetadata(
104+
key_alias=api_key_metadata.get(record.api_key, {}).get(
105+
"key_alias", None
106+
),
107+
team_id=api_key_metadata.get(record.api_key, {}).get(
108+
"team_id", None
109+
),
110+
),
111+
)
112+
breakdown.model_groups[record.model_group].api_key_breakdown[
113+
record.api_key
114+
].metrics = update_metrics(
115+
breakdown.model_groups[record.model_group]
116+
.api_key_breakdown[record.api_key]
117+
.metrics,
118+
record,
119+
)
120+
83121
if record.mcp_namespaced_tool_name:
84122
if record.mcp_namespaced_tool_name not in breakdown.mcp_servers:
85123
breakdown.mcp_servers[record.mcp_namespaced_tool_name] = MetricWithMetadata(
@@ -295,22 +333,6 @@ async def get_daily_activity(
295333
take=page_size,
296334
)
297335

298-
# # for 50% of the records, set the mcp_server_id to a random value
299-
# mcp_server_dict = {"Zapier_Gmail_MCP", "Stripe_MCP"}
300-
# import random
301-
302-
# for idx, record in enumerate(daily_spend_data):
303-
# record = LiteLLM_DailyUserSpend(**record.model_dump())
304-
# if random.random() < 0.5:
305-
# record.mcp_server_id = random.choice(list(mcp_server_dict))
306-
# record.model = None
307-
# record.model_group = None
308-
# record.prompt_tokens = 0
309-
# record.completion_tokens = 0
310-
# record.cache_read_input_tokens = 0
311-
# record.cache_creation_input_tokens = 0
312-
# daily_spend_data[idx] = record
313-
314336
# Get all unique API keys from the spend data
315337
api_keys = set()
316338
for record in daily_spend_data:

litellm/types/proxy/management_endpoints/common_daily_activity.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ class BreakdownMetrics(BaseModel):
6161
models: Dict[str, MetricWithMetadata] = Field(
6262
default_factory=dict
6363
) # model -> {metrics, metadata}
64+
model_groups: Dict[str, MetricWithMetadata] = Field(
65+
default_factory=dict
66+
) # model_group -> {metrics, metadata}
6467
providers: Dict[str, MetricWithMetadata] = Field(
6568
default_factory=dict
6669
) # provider -> {metrics, metadata}

0 commit comments

Comments
 (0)