Improving the VoyageAI integration #3705
Open

fzowl wants to merge 3 commits into crewAIInc:main from voyage-ai:context_model_and_token_counting
@@ -1,12 +1,36 @@
"""VoyageAI embedding function implementation."""

from collections.abc import Callable, Generator
from typing import cast

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack

from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig

# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
    "voyage-context-3": 32_000,
    "voyage-3.5-lite": 1_000_000,
    "voyage-3.5": 320_000,
    "voyage-2": 320_000,
    "voyage-3-large": 120_000,
    "voyage-code-3": 120_000,
    "voyage-large-2-instruct": 120_000,
    "voyage-finance-2": 120_000,
    "voyage-multilingual-2": 120_000,
    "voyage-law-2": 120_000,
    "voyage-large-2": 120_000,
    "voyage-3": 120_000,
    "voyage-3-lite": 120_000,
    "voyage-code-2": 120_000,
    "voyage-3-m-exp": 120_000,
    "voyage-multimodal-3": 120_000,
}

# Batch size for embedding requests
BATCH_SIZE = 1000


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for VoyageAI models."""
@@ -46,17 +70,146 @@ def __call__(self, input: Documents) -> Embeddings:
        Returns:
            List of embedding vectors.
        """
        if isinstance(input, str):
            input = [input]

-       result = self._client.embed(
-           texts=input,
-           model=self._config.get("model", "voyage-2"),
-           input_type=self._config.get("input_type"),
-           truncation=self._config.get("truncation", True),
-           output_dtype=self._config.get("output_dtype"),
-           output_dimension=self._config.get("output_dimension"),
-       )
        # Early return for empty input
        if not input:
            return []

        # Use unified batching for all text inputs
        embeddings = self._embed_with_batching(list(input))

        return cast(Embeddings, embeddings)

    def _build_batches(self, texts: list[str]) -> Generator[list[str], None, None]:
        """
        Generate batches of texts based on token limits using a generator.

        Args:
            texts: List of texts to batch.

        Yields:
            Batches of texts as lists.
        """
        if not texts:
            return

        max_tokens_per_batch = self.get_token_limit()
        current_batch: list[str] = []
        current_batch_tokens = 0

        # Tokenize all texts in one API call
        all_token_lists = self._client.tokenize(texts, model=self._config["model"])
        token_counts = [len(tokens) for tokens in all_token_lists]

        for i, text in enumerate(texts):
            n_tokens = token_counts[i]

            # Check if adding this text would exceed limits
            if current_batch and (
                len(current_batch) >= BATCH_SIZE
                or (current_batch_tokens + n_tokens > max_tokens_per_batch)
            ):
                # Yield the current batch and start a new one
                yield current_batch
                current_batch = []
                current_batch_tokens = 0

            current_batch.append(text)
            current_batch_tokens += n_tokens

        # Yield the last batch (always has at least one text)
        if current_batch:
            yield current_batch

    def _get_embed_function(self) -> Callable[[list[str]], list[list[float]]]:
        """
        Get the appropriate embedding function based on model type.

        Returns:
            A callable that takes a batch of texts and returns embeddings.
        """
        model_name = self._config["model"]

        if self._is_context_model():

            def embed_batch(batch: list[str]) -> list[list[float]]:
                result = self._client.contextualized_embed(
                    inputs=[batch],
                    model=model_name,
                    input_type=self._config.get("input_type"),
                    output_dimension=self._config.get("output_dimension"),
                )
                return [list(emb) for emb in result.results[0].embeddings]

            return embed_batch

        def embed_batch_regular(batch: list[str]) -> list[list[float]]:
            result = self._client.embed(
                texts=batch,
                model=model_name,
                input_type=self._config.get("input_type"),
                truncation=self._config.get("truncation", True),
                output_dimension=self._config.get("output_dimension"),
            )
            return [list(emb) for emb in result.embeddings]

        return embed_batch_regular

    def _embed_with_batching(self, texts: list[str]) -> list[list[float]]:
        """
        Unified method to embed texts with automatic batching based on token limits.
        Works for regular and contextual models.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embeddings.
        """
        if not texts:
            return []

        # Get the appropriate embedding function for this model type
        embed_fn = self._get_embed_function()

        # Process each batch
        all_embeddings = []
        for batch in self._build_batches(texts):
            batch_embeddings = embed_fn(batch)
            all_embeddings.extend(batch_embeddings)

        return all_embeddings

    def count_tokens(self, texts: list[str]) -> list[int]:
        """
        Count tokens for the given texts.

        Args:
            texts: List of texts to count tokens for.

        Returns:
            List of token counts for each text.
        """
        if not texts:
            return []

        # Use the VoyageAI tokenize API to get token counts
        token_lists = self._client.tokenize(texts, model=self._config["model"])
        return [len(token_list) for token_list in token_lists]

    def get_token_limit(self) -> int:
        """
        Get the token limit for the current model.

        Returns:
            Token limit for the model, or default of 120_000 if not found.
        """
        model_name = self._config["model"]
        return VOYAGE_TOTAL_TOKEN_LIMITS.get(model_name, 120_000)

-       return cast(Embeddings, result.embeddings)
    def _is_context_model(self) -> bool:
        """Check if the model is a contextualized embedding model."""
        model_name = self._config["model"]
        return "context" in model_name
Review comment — Bug: Model Removal Breaks Existing Code Compatibility
Removing the default "voyage-2" for the model lookup breaks existing code: the new code indexes self._config["model"] directly, so configurations that previously fell back to "voyage-2" now fail with a KeyError.

Review comment — Bug: Early Return Alters Expected Error Handling
The __call__ method's new early return for empty input now returns an empty list. This bypasses the ValueError expected for empty inputs, which existing tests and ChromaDB's EmbeddingFunction validation rely on.
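A minimal sketch of one way to address both comments on the class from the diff above. The restored "voyage-2" fallback, the _model_name helper, and the error message text are assumptions drawn from the review comments, not from the PR itself.

# Sketch only -- methods on VoyageAIEmbeddingFunction as shown in the diff.

def _model_name(self) -> str:
    # Hypothetical helper restoring the old "voyage-2" fallback so configs
    # without an explicit "model" key keep working (see first comment above).
    return self._config.get("model", "voyage-2")

def __call__(self, input: Documents) -> Embeddings:
    if isinstance(input, str):
        input = [input]
    if not input:
        # Raise instead of returning [] so ChromaDB's EmbeddingFunction
        # validation still sees the error it expects (see second comment).
        raise ValueError("Expected at least one document to embed, got an empty input")
    return cast(Embeddings, self._embed_with_batching(list(input)))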