
Commit 342a056

Merge pull request #53 from amazon-science/feature/lambda_llm
[Feature] Adding Lambda LLMs
2 parents b586fe8 + d858c63

9 files changed: +369 -17 lines

pyproject.toml
Lines changed: 4 additions & 2 deletions

```diff
@@ -29,8 +29,10 @@ dependencies = [
     "langchain-community==0.3.20",
     "json-repair==0.40.0",
     "Jinja2==3.1.6",
-    "dspy==2.6.*",
-    "asteval==1.0.6"
+    "dspy==2.6.11",
+    "asteval==1.0.6",
+    "glom==24.11.0",
+    "aioboto3==14.1.0"
 ]

 [project.optional-dependencies]
```

src/fmcore/aws/factory/boto_factory.py
Lines changed: 51 additions & 12 deletions

```diff
@@ -1,4 +1,7 @@
+from datetime import timezone
 from typing import Dict
+
+import aioboto3
 import boto3
 from botocore.credentials import RefreshableCredentials
 from botocore.session import get_session
@@ -14,7 +17,7 @@ class BotoFactory:
     @classmethod
     def __get_refreshable_session(cls, role_arn: str, region: str, session_name: str) -> boto3.Session:
         """
-        Creates a Boto3 session with refreshable credentials for the assumed IAM role.
+        Creates a botocore session with refreshable credentials for the assumed IAM role.

         Args:
             role_arn (str): ARN of the IAM role to assume.
@@ -49,43 +52,79 @@ def refresh() -> dict:
         botocore_session._credentials = refreshable_credentials
         botocore_session.set_config_variable(AWSConstants.REGION, region)

-        return boto3.Session(botocore_session=botocore_session)
+        return botocore_session

     @classmethod
-    def __create_session(cls, *, role_arn: str, region: str, session_name: str) -> boto3.Session:
+    def __create_session(cls, *, role_arn: str = None, region: str, session_name: str) -> boto3.Session:
         """
         Creates a Boto3 session, either using role-based authentication or default credentials.

         Args:
             region (str): AWS region for the session.
-            role_arn (str): IAM role ARN to assume (if provided).
+            role_arn (str, optional): IAM role ARN to assume.
+            session_name (str): Name for the session.

         Returns:
             boto3.Session: A configured Boto3 session.
         """
-        return (
-            cls.__get_refreshable_session(role_arn=role_arn, region=region, session_name=session_name)
-            if role_arn
-            else boto3.Session(region_name=region)
+        if not role_arn:
+            return boto3.Session(region_name=region)
+
+        # Get a botocore session with refreshable credentials
+        botocore_session = cls.__get_refreshable_session(
+            role_arn=role_arn, region=region, session_name=session_name
         )

+        return boto3.Session(botocore_session=botocore_session)
+
     @classmethod
-    def get_client(cls, *, service_name: str, region: str, role_arn: str) -> boto3.client:
+    def get_client(cls, *, service_name: str, region: str, role_arn: str = None) -> boto3.client:
         """
         Retrieves a cached Boto3 client or creates a new one.

         Args:
             service_name (str): AWS service name (e.g., 's3', 'bedrock-runtime').
             region (str): AWS region for the client.
-            role_arn (str): IAM role ARN for authentication (optional).
+            role_arn (str, optional): IAM role ARN for authentication.

         Returns:
             boto3.client: A configured Boto3 client.
         """
-        key = f"{service_name}-{region}"
-        session = cls.__create_session(region=region, role_arn=role_arn, session_name=f"{key}-Session")
+        key = f"{service_name}-{region}-{role_arn or 'default'}"

         if key not in cls.__clients:
+            session = cls.__create_session(
+                region=region, role_arn=role_arn, session_name=f"{service_name}-Session"
+            )
             cls.__clients[key] = session.client(service_name, region_name=region)

         return cls.__clients[key]
+
+    @classmethod
+    def get_async_session(cls, *, service_name: str, region: str, role_arn: str = None) -> aioboto3.Session:
+        session_name: str = f"Async-{service_name}-Session"
+
+        def refresh():
+            sts_client = boto3.client("sts", region_name=region)
+            creds = sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)["Credentials"]
+            return {
+                "access_key": creds["AccessKeyId"],
+                "secret_key": creds["SecretAccessKey"],
+                "token": creds["SessionToken"],
+                "expiry_time": creds["Expiration"].astimezone(timezone.utc).isoformat(),
+            }
+
+        creds = RefreshableCredentials.create_from_metadata(
+            metadata=refresh(), refresh_using=refresh, method="sts-assume-role"
+        )
+
+        frozen = creds.get_frozen_credentials()
+
+        session = aioboto3.Session(
+            aws_access_key_id=frozen.access_key,
+            aws_secret_access_key=frozen.secret_key,
+            aws_session_token=frozen.token,
+            region_name=region,
+        )
+
+        return session
```
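Because `get_async_session` hands back an `aioboto3.Session` rather than a ready client, callers open a short-lived client per call. A minimal usage sketch under that contract (the region, role ARN, and function name below are placeholders, not values from the commit):

```python
import asyncio

from fmcore.aws.factory.boto_factory import BotoFactory


async def main():
    session = BotoFactory.get_async_session(
        service_name="lambda",
        region="us-east-1",  # placeholder region
        role_arn="arn:aws:iam::123456789012:role/ExampleRole",  # placeholder ARN
    )
    # aioboto3 clients are async context managers: open a fresh one per
    # call instead of caching it across calls.
    async with session.client("lambda") as lambda_client:
        response = await lambda_client.invoke(
            FunctionName="example-llm-function",  # placeholder function name
            InvocationType="RequestResponse",
            Payload=b"{}",
        )
        print(await response["Payload"].read())


asyncio.run(main())
```

One design consequence worth noting: the async session is built from credentials frozen at creation time, so a session that outlives the STS token would need to be recreated, whereas the synchronous path refreshes in place.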

src/fmcore/llm/__init__.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -1,3 +1,4 @@
 from fmcore.llm.base_llm import BaseLLM
 from fmcore.llm.bedrock_llm import BedrockLLM
+from fmcore.llm.lambda_llm import LambdaLLM
 from fmcore.llm.distributed_llm import DistributedLLM
```

src/fmcore/llm/bedrock_llm.py
Lines changed: 4 additions & 1 deletion

```diff
@@ -9,9 +9,10 @@
 from fmcore.llm.base_llm import BaseLLM
 from fmcore.llm.types.llm_types import LLMConfig
 from fmcore.utils.rate_limit_utils import RateLimiterUtils
+from fmcore.utils.retry_utils import RetryUtil


-class BedrockLLM(BaseLLM, BaseModel):
+class BedrockLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], BaseModel):
     """
     AWS Bedrock language model with built-in asynchronous rate limiting.

@@ -58,6 +59,7 @@ def invoke(self, messages: List[BaseMessage]) -> BaseMessage:
         """
         return self.client.invoke(input=messages)

+    @RetryUtil.with_backoff(lambda self: self.config.provider_params.retries)
     async def ainvoke(self, messages: List[BaseMessage]) -> BaseMessage:
         """
         Asynchronously invokes the model with rate limiting.
@@ -83,6 +85,7 @@ def stream(self, messages: List[BaseMessage]) -> Iterator[BaseMessageChunk]:
         """
         return self.client.stream(input=messages)

+    @RetryUtil.with_backoff(lambda self: self.config.provider_params.retries)
     async def astream(self, messages: List[BaseMessage]) -> AsyncIterator[BaseMessageChunk]:
         """
         Asynchronously streams response chunks from the model with rate limiting.
```
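`RetryUtil.with_backoff` is applied with a callable that resolves the retry settings from the instance at call time, so each LLM instance can carry its own policy. The utility's implementation is not part of this diff; a minimal sketch of the pattern, assuming a retry config object with a `max_retries` field (an assumption, not fmcore's actual API), might look like:

```python
import asyncio
import functools
import random
from typing import Any, Callable


class RetryUtil:
    @staticmethod
    def with_backoff(get_retry_config: Callable[[Any], Any]):
        """Retry an async method with exponential backoff plus jitter.

        `get_retry_config` receives `self`, so the policy is read from the
        instance's config at call time rather than at class definition time.
        """

        def decorator(func):
            @functools.wraps(func)
            async def wrapper(self, *args, **kwargs):
                max_retries = get_retry_config(self).max_retries  # assumed field
                for attempt in range(max_retries + 1):
                    try:
                        return await func(self, *args, **kwargs)
                    except Exception:
                        if attempt == max_retries:
                            raise
                        # Sleep 1s, 2s, 4s, ... plus jitter before retrying.
                        await asyncio.sleep(2**attempt + random.random())

            return wrapper

        return decorator
```

Passing a lambda instead of a bare config value is what lets the same decorator line work on both `BedrockLLM` and `LambdaLLM`, since the config is not available when the class body is evaluated.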

src/fmcore/llm/lambda_llm.py
Lines changed: 199 additions & 0 deletions

```diff
@@ -0,0 +1,199 @@
+import json
+from typing import List, Iterator, AsyncIterator, Dict
+
+import aioboto3
+from aiolimiter import AsyncLimiter
+from botocore.client import BaseClient
+from langchain_aws import ChatBedrockConverse
+from langchain_community.adapters.openai import convert_dict_to_message
+from pydantic import BaseModel
+from langchain_core.messages import (
+    BaseMessage,
+    BaseMessageChunk,
+    convert_to_openai_messages,
+)
+
+from fmcore.aws.factory.boto_factory import BotoFactory
+from fmcore.llm.base_llm import BaseLLM
+from fmcore.llm.types.llm_types import LLMConfig
+from fmcore.llm.types.provider_types import LambdaProviderParams
+from fmcore.utils.rate_limit_utils import RateLimiterUtils
+from fmcore.utils.retry_utils import RetryUtil
+
+
+class LambdaLLM(BaseLLM[List[BaseMessage], BaseMessage, BaseMessageChunk], BaseModel):
+    """
+    An LLM implementation that routes requests through an AWS Lambda function.
+
+    This class uses both synchronous and asynchronous boto3 Lambda clients to
+    interact with an LLM hosted via AWS Lambda. It includes automatic async rate
+    limiting and supports OpenAI-style message formatting.
+
+    Attributes:
+        sync_client (BaseClient): Boto3 synchronous client for AWS Lambda.
+        async_session (aioboto3.Session): aioboto3 asynchronous session for AWS Lambda.
+        rate_limiter (AsyncLimiter): Async limiter to enforce API rate limits.
+
+    Note:
+        The `async_client` is not stored directly because `aioboto3.client(...)` returns
+        an asynchronous context manager, which must be used with `async with` and cannot
+        be reused safely across calls. Instead, we store an `aioboto3.Session` instance
+        in `async_session`, from which a fresh client is created inside each `async with`
+        block.
+
+    """
+
+    aliases = ["LAMBDA"]
+
+    sync_client: BaseClient
+    async_session: aioboto3.Session  # Using a session, as aioboto3.client returns a context manager
+    rate_limiter: AsyncLimiter
+
+    @classmethod
+    def _get_instance(cls, *, llm_config: LLMConfig) -> "LambdaLLM":
+        """
+        Factory method to create an instance of LambdaLLM with the given configuration.
+
+        Args:
+            llm_config (LLMConfig): The LLM configuration, including model and provider details.
+
+        Returns:
+            LambdaLLM: A configured instance of the Lambda-backed LLM.
+        """
+        provider_params: LambdaProviderParams = llm_config.provider_params
+
+        sync_client = BotoFactory.get_client(
+            service_name="lambda",
+            region=provider_params.region,
+            role_arn=provider_params.role_arn,
+        )
+        async_session = BotoFactory.get_async_session(
+            service_name="lambda",
+            region=provider_params.region,
+            role_arn=provider_params.role_arn,
+        )
+
+        rate_limiter = RateLimiterUtils.create_async_rate_limiter(
+            rate_limit_config=provider_params.rate_limit
+        )
+
+        return LambdaLLM(
+            config=llm_config, sync_client=sync_client, async_session=async_session, rate_limiter=rate_limiter
+        )
+
+    def convert_messages_to_lambda_payload(self, messages: List[BaseMessage]) -> Dict:
+        """
+        Converts internal message objects to the payload format expected by the Lambda function.
+        All Lambdas are expected to accept messages in the OpenAI format.
+
+        Args:
+            messages (List[BaseMessage]): List of internal message objects.
+
+        Returns:
+            Dict: The payload dictionary to send to the Lambda function.
+        """
+        return {
+            "modelId": self.config.model_id,
+            "messages": convert_to_openai_messages(messages),
+            "model_params": self.config.model_params.model_dump(),
+        }
+
+    def convert_lambda_response_to_messages(self, response: Dict) -> BaseMessage:
+        """
+        Converts the raw Lambda function response into a BaseMessage.
+
+        This method expects the Lambda response to contain a 'Payload' key with a stream
+        of OpenAI-style messages (a list of dictionaries). It parses the stream, extracts
+        the first message, and converts it into a BaseMessage instance.
+
+        Args:
+            response (Dict): The response dictionary returned from the Lambda invocation.
+
+        Returns:
+            BaseMessage: The first parsed message from the response.
+        """
+        response_payload: List[Dict] = json.load(response["Payload"])
+        # The Lambda returns a list of messages in OpenAI format.
+        # Currently, we only expect a single response message,
+        # so we take the first item in the list.
+        return convert_dict_to_message(response_payload[0])
+
+    def invoke(self, messages: List[BaseMessage]) -> BaseMessage:
+        """
+        Synchronously invokes the Lambda function with the given messages.
+
+        Args:
+            messages (List[BaseMessage]): Input messages for the model.
+
+        Returns:
+            BaseMessage: Response message from the model.
+        """
+        payload = self.convert_messages_to_lambda_payload(messages)
+        response = self.sync_client.invoke(
+            FunctionName=self.config.provider_params.function_arn,
+            InvocationType="RequestResponse",
+            Payload=json.dumps(payload),
+        )
+        return self.convert_lambda_response_to_messages(response)
+
+    @RetryUtil.with_backoff(lambda self: self.config.provider_params.retries)
+    async def ainvoke(self, messages: List[BaseMessage]) -> BaseMessage:
+        """
+        Asynchronously invokes the Lambda function with rate limiting.
+
+        Args:
+            messages (List[BaseMessage]): Input messages for the model.
+
+        Returns:
+            BaseMessage: Response message from the model.
+        """
+        async with self.rate_limiter:
+            async with self.async_session.client("lambda") as lambda_client:
+                payload = self.convert_messages_to_lambda_payload(messages)
+                response = await lambda_client.invoke(
+                    FunctionName=self.config.provider_params.function_arn,
+                    InvocationType="RequestResponse",
+                    Payload=json.dumps(payload),
+                )
+                payload = await response["Payload"].read()
+                response_payload: List[Dict] = json.loads(payload.decode("utf-8"))
+                # The Lambda returns a list of messages in OpenAI format.
+                # Currently, we only expect a single response message,
+                # so we take the first item in the list.
+                return convert_dict_to_message(response_payload[0])
+
+    def stream(self, messages: List[BaseMessage]) -> Iterator[BaseMessageChunk]:
+        """
+        Not implemented. Streaming is not supported for LambdaLLM.
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError("Streaming is not implemented for LambdaLLM")
+
+    async def astream(self, messages: List[BaseMessage]) -> AsyncIterator[BaseMessageChunk]:
+        """
+        Not implemented. Asynchronous streaming is not supported for LambdaLLM.
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError("Streaming is not implemented for LambdaLLM")
+
+    def batch(self, messages: List[List[BaseMessage]]) -> List[BaseMessage]:
+        """
+        Not implemented. Batch processing is not supported for LambdaLLM.
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError("Batch processing is not implemented for LambdaLLM.")
+
+    async def abatch(self, messages: List[List[BaseMessage]]) -> List[BaseMessage]:
+        """
+        Not implemented. Asynchronous batch processing is not supported for LambdaLLM.
+
+        Raises:
+            NotImplementedError
+        """
+        raise NotImplementedError("Batch processing is not implemented for LambdaLLM.")
```
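The Lambda side of this contract is not included in the commit. A hypothetical handler compatible with `convert_messages_to_lambda_payload` and `convert_lambda_response_to_messages` (the backend call is stubbed; all names here are illustrative) could look like:

```python
def _stub_model_backend(model_id, messages, model_params):
    # Stand-in for whatever model this Lambda actually fronts.
    return f"echo from {model_id}: {messages[-1]['content']}"


def lambda_handler(event, context):
    """Hypothetical handler matching LambdaLLM's payload contract.

    `event` is the JSON payload built by convert_messages_to_lambda_payload:
        {"modelId": ..., "messages": [...], "model_params": {...}}
    """
    completion = _stub_model_backend(
        event["modelId"],
        event["messages"],       # OpenAI-style chat messages
        event["model_params"],   # e.g., temperature, max_tokens
    )
    # LambdaLLM expects a JSON list of OpenAI-style messages and reads
    # only the first element, so return a single assistant message.
    return [{"role": "assistant", "content": completion}]
```

Since the handler is invoked with `InvocationType="RequestResponse"`, Lambda serializes the returned list to JSON, which is exactly what `json.load(response["Payload"])` parses on the client side.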

src/fmcore/llm/types/provider_types.py
Lines changed: 20 additions & 0 deletions

```diff
@@ -52,3 +52,23 @@ class BedrockProviderParams(BaseProviderParams, AWSAccountMixin, RateLimiterMixi
     """

     aliases = ["BEDROCK"]
+
+
+class LambdaProviderParams(BaseProviderParams, AWSAccountMixin, RateLimiterMixin, RetryConfigMixin):
+    """
+    Configuration for a Lambda provider using AWS.
+
+    This class combines AWS account settings with request configuration parameters
+    (such as rate limits and retry policies) needed to interact with Lambda-hosted
+    LLM services. It mixes in AWS-specific account details, rate limiting, and retry
+    configurations to form a complete provider setup.
+
+    Mixes in:
+        AWSAccountMixin: Supplies AWS-specific account details (e.g., role ARN, region).
+        RateLimiterMixin: Supplies API rate limiting settings.
+        RetryConfigMixin: Supplies retry policy settings.
+    """
+
+    aliases = ["LAMBDA"]
+
+    function_arn: str
```
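Putting the pieces together, configuration might look like the sketch below. The `region`, `role_arn`, `function_arn`, `rate_limit`, and `retries` fields come from the mixins and this diff; everything else (the exact `LLMConfig` constructor shape, the `provider_type` key, and all concrete values) is an assumption for illustration:

```python
from langchain_core.messages import HumanMessage

from fmcore.llm import LambdaLLM
from fmcore.llm.types.llm_types import LLMConfig

# Hypothetical configuration; field names outside this diff are
# assumptions about fmcore's config schema.
llm_config = LLMConfig(
    provider_type="LAMBDA",  # matches LambdaProviderParams.aliases
    model_id="example-model-id",
    model_params={"temperature": 0.0, "max_tokens": 1024},
    provider_params={
        "region": "us-east-1",
        "role_arn": "arn:aws:iam::123456789012:role/ExampleRole",
        "function_arn": "arn:aws:lambda:us-east-1:123456789012:function:llm-proxy",
        "rate_limit": {"max_rate": 10, "time_period": 60},
        "retries": {"max_retries": 3},
    },
)

llm = LambdaLLM._get_instance(llm_config=llm_config)
response = llm.invoke([HumanMessage(content="Hello!")])
print(response.content)
```

`_get_instance` mirrors the factory method shown in the diff; the library's public entry point for constructing LLMs may differ.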
