Skip to content

Commit e84371e

Browse files
feat: add token limit chunks to embedding models (#670)
* feat: add token limit chunks to embedding models
* fix: linting, dry build and cdk-lib issues

---------

Co-authored-by: Maryam Khidir <mkhidir@amazon.de>
1 parent 880ca16 commit e84371e

File tree

28 files changed

+602
-521
lines changed

28 files changed

+602
-521
lines changed

.github/workflows/build.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,13 @@ jobs:
77
runs-on: ubuntu-latest
88
steps:
99
- uses: actions/checkout@v4
10-
- uses: actions/setup-node@v3
10+
- uses: actions/setup-node@v4
1111
with:
1212
node-version: "20"
13+
- name: Install latest CDK CLI
14+
run: |
15+
npm install -g aws-cdk@latest
16+
cdk --version
1317
- name: Formatting
1418
run: |
1519
npm ci
@@ -23,7 +27,7 @@ jobs:
2327
npm audit
2428
npm run build
2529
npm run test
26-
npx cdk synth
30+
cdk synth
2731
- name: PyTests
2832
# Suppression of pip audit failure until langchain is upgraded.
2933
run: |

.python-version

Lines changed: 0 additions & 1 deletion
This file was deleted.

NOTICE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The following Python packages may be included in this product:
66
- cfnresponse==1.1.2
77
- opensearch-py==2.3.1
88
- openai==0.28.0
9-
- requests==2.32.0
9+
- requests==2.32.4
1010
- huggingface-hub
1111
- hf-transfer
1212
- aws_xray_sdk==2.12.1
@@ -817,7 +817,7 @@ Agreement.
817817

818818
The following Python packages may be included in this product:
819819

820-
- urllib3<2
820+
- urllib3==2.5.0
821821

822822
These packages each contain the following license and notice below:
823823

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,16 @@ This blueprint deploys the complete AWS GenAI LLM Chatbot solution in your AWS a
2727
- AWS CLI configured with credentials
2828
- Node.js 18+ and npm
2929
- Python 3.8+
30+
- AWS CDK CLI version compatible with aws-cdk-lib 2.206.0 or later
31+
```bash
32+
# Install or update the CDK CLI globally
33+
npm install -g aws-cdk@latest
34+
35+
# Verify the installed version
36+
cdk --version
37+
```
38+
39+
> **Important**: The CDK CLI version must be compatible with the aws-cdk-lib version used in this project (currently 2.206.0). If you encounter a "Cloud assembly schema version mismatch" error during deployment, update your CDK CLI to the latest version using the command above.
3040
3141
### Deployment
3242

lib/authentication/lambda/addFederatedUserToUserGroup/index.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def add_user_to_group(cognito, username, group_name, user_pool_id):
7272

7373
def handler(event, context):
7474
print(f"Event received: {event}")
75-
75+
7676
# Handle different trigger types with different event structures
7777
if "request" in event and "userAttributes" in event["request"]:
7878
# POST_AUTHENTICATION trigger
@@ -88,18 +88,30 @@ def handler(event, context):
8888
new_group = user_attributes.get("custom:chatbot_role")
8989
user_pool_id = event["userPoolId"]
9090
trigger_type = "PRE_AUTHENTICATION"
91-
elif "request" in event and "userAttributes" in event["request"] and "validationData" in event["request"]:
91+
elif (
92+
"request" in event
93+
and "userAttributes" in event["request"]
94+
and "validationData" in event["request"]
95+
):
9296
# POST_CONFIRMATION trigger
9397
user_attributes = event["request"]["userAttributes"]
9498
username = user_attributes.get("sub") or user_attributes.get("username")
9599
new_group = user_attributes.get("custom:chatbot_role")
96100
user_pool_id = event["userPoolId"]
97101
trigger_type = "POST_CONFIRMATION"
98-
elif "request" in event and "userAttributes" in event["request"] and "validationData" not in event["request"]:
102+
elif (
103+
"request" in event
104+
and "userAttributes" in event["request"]
105+
and "validationData" not in event["request"]
106+
):
99107
# PRE_SIGN_UP trigger
100108
user_attributes = event["request"]["userAttributes"]
101109
# For Pre sign-up, username might be in different fields
102-
username = user_attributes.get("sub") or user_attributes.get("username") or user_attributes.get("email")
110+
username = (
111+
user_attributes.get("sub")
112+
or user_attributes.get("username")
113+
or user_attributes.get("email")
114+
)
103115
new_group = user_attributes.get("custom:chatbot_role")
104116
user_pool_id = event["userPoolId"]
105117
trigger_type = "PRE_SIGN_UP"
@@ -115,7 +127,7 @@ def handler(event, context):
115127

116128
# Get default group from environment variable or use 'user' as fallback
117129
default_group = os.environ.get("DEFAULT_USER_GROUP", "user")
118-
130+
119131
# If no custom:chatbot_role is provided, use default group
120132
if not new_group:
121133
new_group = default_group
@@ -125,18 +137,22 @@ def handler(event, context):
125137
if trigger_type == "PRE_SIGN_UP":
126138
print("Pre sign-up trigger - user will be created after this trigger completes")
127139
print(f"Will assign user to group: {new_group}")
128-
print("Note: Group assignment will happen in a separate trigger (POST_CONFIRMATION)")
129-
140+
print(
141+
"Note: Group assignment will happen in a separate \
142+
trigger (POST_CONFIRMATION)"
143+
)
144+
130145
# For Pre sign-up, we can only validate or modify the sign-up request
131146
# We cannot assign groups yet as the user doesn't exist
132-
# The group assignment will need to happen in POST_CONFIRMATION or PRE_AUTHENTICATION
133-
147+
# The group assignment will need to happen in
148+
# POST_CONFIRMATION or PRE_AUTHENTICATION
149+
134150
# You might want to add the group information to the user attributes
135151
# so it can be used later in POST_CONFIRMATION
136152
if "custom:chatbot_role" not in user_attributes:
137153
user_attributes["custom:chatbot_role"] = new_group
138154
print(f"Added custom:chatbot_role attribute: {new_group}")
139-
155+
140156
return event
141157

142158
# For other triggers (PRE_AUTHENTICATION, POST_AUTHENTICATION, POST_CONFIRMATION)

lib/chatbot-api/functions/api-handler/routes/models.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414

1515
@router.resolver(field_name="listModels")
@tracer.capture_method
@permissions.approved_roles(
    [permissions.ADMIN_ROLE, permissions.WORKSPACES_MANAGER_ROLE]
)
def models() -> list[dict[str, Any]]:
    # AppSync resolver for the "listModels" query. Access is restricted to
    # the admin and workspaces-manager roles; the actual model listing is
    # delegated to genai_core.models.list_models().
    return genai_core.models.list_models()

lib/sagemaker-model/hf-custom-script-model/samples/pipeline/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _forward(self, model_inputs, **kwargs):
3737
input_ids=input_ids.to(self.model.device),
3838
attention_mask=attention_mask.to(self.model.device),
3939
return_dict_in_generate=True,
40-
**kwargs
40+
**kwargs,
4141
)
4242

4343
return {"input_ids": input_ids, "outputs": outputs}
Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
boto3==1.28.57
1+
boto3==1.40.13
22
aws-lambda-powertools==2.25.1
33
numpy==1.26.4
44
cfnresponse==1.1.2
55
aws_requests_auth==0.4.3
66
requests-aws4auth==1.2.3
77
langchain==0.3.7
88
langchain-community==0.3.3
9-
opensearch-py==2.3.1
10-
psycopg2-binary==2.9.7
9+
opensearch-py==3.0.0
10+
psycopg2-binary==2.9.10
1111
pgvector==0.2.2
12-
urllib3<2
12+
urllib3==2.5.0
1313
openai==1.47.0
1414
beautifulsoup4==4.12.2
15-
requests==2.32.2
15+
requests==2.32.4
1616
attrs==23.1.0
1717
feedparser==6.0.11
1818
PyJWT==2.9.0

lib/shared/layers/common/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ langchain-text-splitters==0.3.5
1212
opensearch-py==2.4.2
1313
psycopg2-binary==2.9.7
1414
pgvector==0.2.2
15-
urllib3<2
15+
urllib3==2.5.0
1616
beautifulsoup4==4.12.2
17-
requests==2.32.0
17+
requests==2.32.4
1818
attrs==23.1.0
1919
feedparser==6.0.11
2020
defusedxml==0.7.1

lib/shared/layers/python-sdk/python/genai_core/embeddings.py

Lines changed: 75 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,85 @@
1818
logger = Logger()
1919

2020

21+
# Conservative per-provider input token limits for embedding models.
# Built once at import time instead of on every call. Sources:
# - https://docs.aws.amazon.com/bedrock/latest/userguide/titan-embedding-models.html
# - https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/embeddings
# - https://docs.cohere.com/v2/docs/cohere-embed
_PROVIDER_TOKEN_LIMITS = {
    Provider.AMAZON.value: 8000,  # Amazon Titan embedding models
    Provider.COHERE.value: 512,  # Cohere embed models
    Provider.OPENAI.value: 8191,  # OpenAI embedding models
    "default": 2500,  # Fallback (2500 tokens * ~4 chars/token = 10000 chars)
}


def get_model_token_limit(model_name):
    """Return the maximum input token count for an embedding model.

    The provider is inferred from the prefix before the first dot in the
    model name (Bedrock convention, e.g. "amazon.titan-embed-text-v2:0"
    yields "amazon"). Unrecognized providers fall back to a conservative
    default.

    :param model_name: Embedding model identifier string.
    :return: Token limit as an int.
    """
    # NOTE(review): model names without a dot (e.g. a plain OpenAI name such
    # as "text-embedding-3-small") return the full name from split(".")[0]
    # and therefore silently fall through to the default limit — confirm
    # the naming convention used by callers.
    model_provider = model_name.split(".")[0]
    return _PROVIDER_TOKEN_LIMITS.get(
        model_provider, _PROVIDER_TOKEN_LIMITS["default"]
    )
36+
37+
2138
def generate_embeddings(
    model: EmbeddingsModel, input: list[str], task: str = "store", batch_size: int = 50
) -> list[list[float]]:
    """Generate exactly one embedding vector per input string.

    Texts longer than the model's character budget (token limit * ~4
    chars/token, capped at the pre-existing 10000-char limit) are split
    into chunks; the chunk embeddings are element-wise averaged back into
    a single vector so the result stays parallel to ``input``.

    :param model: Embedding model descriptor (provider + name).
    :param input: Texts to embed. (The name shadows the builtin but is
        kept unchanged for interface compatibility.)
    :param task: Embedding task hint, forwarded to Bedrock models only.
    :param batch_size: Number of chunks sent per provider call.
    :return: List of embedding vectors, one per element of ``input``.
    :raises CommonError: For an unknown provider, or wrapping any other
        failure raised while generating embeddings.
    """
    try:
        # Model-specific character budget, assuming ~4 characters per token.
        token_limit = get_model_token_limit(model.name)
        char_limit = min(token_limit * 4, 10000)

        # Split over-long texts into chunks and record which chunk indices
        # belong to which original input so we can recombine afterwards.
        chunked_input = []
        chunk_mapping = []
        current_idx = 0
        for text in input:
            if len(text) <= char_limit:
                chunks = [text]
            else:
                chunks = [
                    text[i : i + char_limit]
                    for i in range(0, len(text), char_limit)
                ]
            chunk_mapping.append(
                list(range(current_idx, current_idx + len(chunks)))
            )
            current_idx += len(chunks)
            chunked_input.extend(chunks)

        # Dispatch chunks to the provider in batches.
        ret_value = []
        batch_split = [
            chunked_input[i : i + batch_size]
            for i in range(0, len(chunked_input), batch_size)
        ]
        for batch in batch_split:
            if model.provider == Provider.OPENAI.value:
                ret_value.extend(_generate_embeddings_openai(model, batch))
            elif model.provider == Provider.BEDROCK.value:
                ret_value.extend(_generate_embeddings_bedrock(model, batch, task))
            elif model.provider == Provider.SAGEMAKER.value:
                ret_value.extend(_generate_embeddings_sagemaker(model, batch))
            else:
                raise CommonError(f"Unknown provider: {model.provider}")

        # Recombine: single-chunk inputs pass through unchanged; multi-chunk
        # inputs get the element-wise mean of their chunk embeddings.
        final_embeddings = []
        for chunks_idx in chunk_mapping:
            if len(chunks_idx) == 1:
                final_embeddings.append(ret_value[chunks_idx[0]])
            else:
                chunk_embeddings = [ret_value[idx] for idx in chunks_idx]
                avg_embedding = [
                    sum(values) / len(values)
                    for values in zip(*chunk_embeddings)
                ]
                final_embeddings.append(avg_embedding)

        return final_embeddings
    except CommonError:
        # Fix: don't catch and double-wrap our own CommonError (e.g. the
        # "Unknown provider" case above) — re-raise it untouched.
        raise
    except Exception as e:
        logger.error(f"Error in generate_embeddings: {str(e)}")
        # Chain the original exception so the root cause stays visible.
        raise CommonError(f"Failed to generate embeddings: {str(e)}") from e
40100

41101

42102
def get_embeddings_models():

0 commit comments

Comments (0)