From 00420525857b817303a91711f98f828654475785 Mon Sep 17 00:00:00 2001 From: Yann Stoneman Date: Tue, 16 Apr 2024 02:23:00 +0000 Subject: [PATCH 1/6] start adding cohere rerank support for cross-encoder --- bin/config.ts | 5 +++ cli/magic-config.ts | 4 ++ docs/documentation/inference-script.md | 16 ++++++++ .../sagemaker-rag-models/model/inference.py | 37 +++++++++++-------- .../python-sdk/python/genai_core/clients.py | 10 +++++ .../python/genai_core/cross_encoder.py | 23 ++++++++++++ lib/shared/types.ts | 2 +- 7 files changed, 80 insertions(+), 17 deletions(-) diff --git a/bin/config.ts b/bin/config.ts index 50d14bbff..aa5464fc7 100644 --- a/bin/config.ts +++ b/bin/config.ts @@ -84,6 +84,11 @@ export function getConfig(): SystemConfig { name: "cross-encoder/ms-marco-MiniLM-L-12-v2", default: true, }, + { + provider: "cohere", + name: "rerank-english-v3.0", + default: false, + }, ], }, }; diff --git a/cli/magic-config.ts b/cli/magic-config.ts index 2b272c3f0..e6fd41f0a 100644 --- a/cli/magic-config.ts +++ b/cli/magic-config.ts @@ -770,6 +770,10 @@ async function processCreateOptions(options: any): Promise { name: "cross-encoder/ms-marco-MiniLM-L-12-v2", default: true, }; + config.rag.crossEncoderModels[1] = { + provider: "cohere", + name: "rerank-english-v3.0", + }; config.rag.embeddingsModels = embeddingModels; config.rag.embeddingsModels.forEach((m: any) => { if (m.name === models.defaultEmbedding) { diff --git a/docs/documentation/inference-script.md b/docs/documentation/inference-script.md index b373e7a88..a06544997 100644 --- a/docs/documentation/inference-script.md +++ b/docs/documentation/inference-script.md @@ -30,3 +30,19 @@ The API is JSON body based: "passages": ["I love Paris", "I love London"] } ``` + +## Cohere Rerank 3 + +To use the Cohere Rerank 3 model, get an API key from Cohere, and include the following in the JSON request body: + +```json +{ + "type": "cross-encoder", + "model": "rerank-english-v3.0", + "input": "What is the capital of the 
United States?", + "passages": [ + "Carson City is the capital city of the American state of Nevada.", + "Washington, D.C. is the capital of the United States.", + ... + ] +} \ No newline at end of file diff --git a/lib/rag-engines/sagemaker-rag-models/model/inference.py b/lib/rag-engines/sagemaker-rag-models/model/inference.py index 546ae3358..bbfad81d1 100644 --- a/lib/rag-engines/sagemaker-rag-models/model/inference.py +++ b/lib/rag-engines/sagemaker-rag-models/model/inference.py @@ -27,7 +27,7 @@ "intfloat/multilingual-e5-large", "sentence-transformers/all-MiniLM-L6-v2", ] -cross_encoder_models = ["cross-encoder/ms-marco-MiniLM-L-12-v2"] +cross_encoder_models = ["cross-encoder/ms-marco-MiniLM-L-12-v2", "rerank-english-v3.0"] def process_model_list(model_list): @@ -130,21 +130,26 @@ def predict_fn(input_object, config): passages = input_object["passages"] data = [[current_input, passage] for passage in passages] - with torch.inference_mode(): - features = current_tokenizer( - data, padding=True, truncation=True, return_tensors="pt" - ) - - features = features.to(device) - - scores = current_model(**features).logits.cpu().numpy() - ret_value = list( - map( - lambda val: val[-1] if isinstance(val, list) else val, - scores.tolist(), + if current_model_id == "rerank-english-v3.0": + # Use Cohere Rerank 3 API; results come back sorted by relevance, so re-order by index to align scores with passages + co = cohere.Client(os.environ["COHERE_API_KEY"]) + results = co.rerank(query=current_input, documents=passages, top_n=len(passages), model='rerank-english-v3.0') + ret_value = [result.relevance_score for result in sorted(results.results, key=lambda r: r.index)] + else: + with torch.inference_mode(): + features = current_tokenizer( + data, padding=True, truncation=True, return_tensors="pt" ) - ) - - return ret_value + + features = features.to(device) + + scores = current_model(**features).logits.cpu().numpy() + ret_value = list( + map( + lambda val: val[-1] if isinstance(val, list) else val, + scores.tolist(), + ) + ) + return ret_value return [] diff --git 
a/lib/shared/layers/python-sdk/python/genai_core/clients.py b/lib/shared/layers/python-sdk/python/genai_core/clients.py index 15c3d2a8e..456871aa9 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/clients.py +++ b/lib/shared/layers/python-sdk/python/genai_core/clients.py @@ -1,4 +1,5 @@ import boto3 +import cohere import openai import genai_core.types import genai_core.parameters @@ -52,3 +53,12 @@ def get_bedrock_client(service_name="bedrock-runtime"): bedrock_config_data["aws_session_token"] = credentials["SessionToken"] return boto3.client(**bedrock_config_data) + +def get_cohere_client(): + api_key = genai_core.parameters.get_external_api_key("COHERE_API_KEY") + if not api_key: + return None + + cohere_client = cohere.Client(api_key) + + return cohere_client diff --git a/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py b/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py index 67bae7da5..a74d841b2 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py +++ b/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py @@ -18,6 +18,8 @@ def rank_passages( if model.provider == "sagemaker": return _rank_passages_sagemaker(model, input, passages) + elif model.provider == "cohere": + return _rank_passages_cohere(model, input, passages) raise genai_core.typesCommonError(f"Unknown provider") @@ -66,3 +68,24 @@ def _rank_passages_sagemaker( ret_value = json.loads(response["Body"].read().decode()) return ret_value + +def _rank_passages_cohere( + model: genai_core.types.CrossEncoderModel, input: str, passages: List[str] +): + cohere_client = genai_core.clients.get_cohere_client() + if not cohere_client: + raise genai_core.types.CommonError("Cohere API key not set") + + response = cohere_client.rerank( + query=input, + documents=passages, + model=model.name, + ) + # rerank() returns results sorted by relevance; map each back via its index so scores stay paired with the right passage, in passage order + return [ + genai_core.types.RankedPassage( + passage=passages[result.index], + score=result.relevance_score, + ) + for result in sorted(response.results, key=lambda r: r.index) + 
] diff --git a/lib/shared/types.ts b/lib/shared/types.ts index 423247c48..02c6fb26c 100644 --- a/lib/shared/types.ts +++ b/lib/shared/types.ts @@ -1,6 +1,6 @@ import * as sagemaker from "aws-cdk-lib/aws-sagemaker"; -export type ModelProvider = "sagemaker" | "bedrock" | "openai"; +export type ModelProvider = "sagemaker" | "bedrock" | "openai" | "cohere"; export enum SupportedSageMakerModels { FalconLite = "FalconLite [ml.g5.12xlarge]", From 3dff4449651e2e71087cdd4987bc0c305c76adeb Mon Sep 17 00:00:00 2001 From: root Date: Tue, 16 Apr 2024 15:26:17 +0000 Subject: [PATCH 2/6] add missing cohere module to requirements --- lib/shared/layers/common/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/shared/layers/common/requirements.txt b/lib/shared/layers/common/requirements.txt index 8a798c0da..1aa5855f9 100644 --- a/lib/shared/layers/common/requirements.txt +++ b/lib/shared/layers/common/requirements.txt @@ -11,6 +11,7 @@ pgvector==0.2.2 pydantic==2.3.0 urllib3<2 openai==0.28.1 +cohere==5.3.0 beautifulsoup4==4.12.2 requests==2.31.0 attrs==23.1.0 From 4b1d193f4dd2d2e179ab98841207b2829ca9d262 Mon Sep 17 00:00:00 2001 From: Yann Stoneman Date: Sun, 21 Apr 2024 18:33:00 -0400 Subject: [PATCH 3/6] crossencodermodels providers for config props --- lib/rag-engines/sagemaker-rag-models/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/rag-engines/sagemaker-rag-models/index.ts b/lib/rag-engines/sagemaker-rag-models/index.ts index bfc185777..a7bec0c9a 100644 --- a/lib/rag-engines/sagemaker-rag-models/index.ts +++ b/lib/rag-engines/sagemaker-rag-models/index.ts @@ -22,7 +22,7 @@ export class SageMakerRagModels extends Construct { .map((c) => c.name); const sageMakerCrossEncoderModelIds = props.config.rag.crossEncoderModels - .filter((c) => c.provider === "sagemaker") + .filter((c) => c.provider === "sagemaker" || c.provider === "cohere") .map((c) => c.name); const model = new SageMakerModel(this, "Model", { From 
57dc67cefd9952297b21d601a4b5b6fae65009b7 Mon Sep 17 00:00:00 2001 From: Yann Stoneman Date: Thu, 9 May 2024 13:24:45 -0400 Subject: [PATCH 4/6] Resolve IAM error blocking deployment --- lib/model-interfaces/idefics/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/model-interfaces/idefics/index.ts b/lib/model-interfaces/idefics/index.ts index 8bb894053..eb7e671a0 100644 --- a/lib/model-interfaces/idefics/index.ts +++ b/lib/model-interfaces/idefics/index.ts @@ -129,7 +129,7 @@ export class IdeficsInterface extends Construct { new iam.PolicyStatement({ actions: ["kms:Decrypt", "kms:ReEncryptFrom"], effect: iam.Effect.ALLOW, - resources: ["arn:aws:kms:*"], + resources: ["*"], }) ); From 69cdc882696e257ae46dadff9d7c7aff846fa295 Mon Sep 17 00:00:00 2001 From: Yann Stoneman Date: Thu, 9 May 2024 13:27:04 -0400 Subject: [PATCH 5/6] Remove cohere rerank from *sagemaker* cross-encoders --- lib/rag-engines/sagemaker-rag-models/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/rag-engines/sagemaker-rag-models/index.ts b/lib/rag-engines/sagemaker-rag-models/index.ts index a7bec0c9a..bfc185777 100644 --- a/lib/rag-engines/sagemaker-rag-models/index.ts +++ b/lib/rag-engines/sagemaker-rag-models/index.ts @@ -22,7 +22,7 @@ export class SageMakerRagModels extends Construct { .map((c) => c.name); const sageMakerCrossEncoderModelIds = props.config.rag.crossEncoderModels - .filter((c) => c.provider === "sagemaker" || c.provider === "cohere") + .filter((c) => c.provider === "sagemaker") .map((c) => c.name); const model = new SageMakerModel(this, "Model", { From 8fad5328930b135564a726993f7304010abd43ef Mon Sep 17 00:00:00 2001 From: Yann Stoneman Date: Thu, 9 May 2024 13:29:15 -0400 Subject: [PATCH 6/6] Fix schema for cohere rerank --- .../layers/python-sdk/python/genai_core/cross_encoder.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py 
b/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py index a74d841b2..053bf1a81 100644 --- a/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py +++ b/lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py @@ -31,6 +31,10 @@ def get_cross_encoder_models(): if not SAGEMAKER_RAG_MODELS_ENDPOINT: models = list(filter(lambda x: x["provider"] != "sagemaker", models)) + for model in models: + if 'default' not in model: + model['default'] = False + return models