4 changes: 2 additions & 2 deletions docs/en/learn/llm-connections.mdx
@@ -10,7 +10,7 @@ mode: "wide"
CrewAI uses LiteLLM to connect to a wide variety of Language Models (LLMs). This integration provides extensive versatility, allowing you to use models from numerous providers with a simple, unified interface.

<Note>
By default, CrewAI uses the `gpt-4o-mini` model. This is determined by the `OPENAI_MODEL_NAME` environment variable, which defaults to "gpt-4o-mini" if not set.
You can easily configure your agents to use a different model or provider as described in this guide.
</Note>
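A minimal sketch of the override the note above describes, assuming the environment variable is set before the crew is created (the model name here is illustrative, not a recommendation):

```python
import os

# Override the default model for all agents that rely on the env-var default
os.environ["OPENAI_MODEL_NAME"] = "gpt-4o"
```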

@@ -190,7 +190,7 @@ For local models like those provided by Ollama:

You can change the base API URL for any LLM provider by setting the `base_url` parameter:

```python Code
llm = LLM(
model="custom-model-name",
base_url="https://api.your-provider.com/v1",
Expand Down
173 changes: 163 additions & 10 deletions src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py
@@ -1,12 +1,36 @@
"""VoyageAI embedding function implementation."""

from collections.abc import Callable, Generator
from typing import cast

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack

from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig

# Token limits for different VoyageAI models
VOYAGE_TOTAL_TOKEN_LIMITS = {
"voyage-context-3": 32_000,
"voyage-3.5-lite": 1_000_000,
"voyage-3.5": 320_000,
"voyage-2": 320_000,
"voyage-3-large": 120_000,
"voyage-code-3": 120_000,
"voyage-large-2-instruct": 120_000,
"voyage-finance-2": 120_000,
"voyage-multilingual-2": 120_000,
"voyage-law-2": 120_000,
"voyage-large-2": 120_000,
"voyage-3": 120_000,
"voyage-3-lite": 120_000,
"voyage-code-2": 120_000,
"voyage-3-m-exp": 120_000,
"voyage-multimodal-3": 120_000,
}

# Batch size for embedding requests
BATCH_SIZE = 1000


class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for VoyageAI models."""
@@ -46,17 +70,146 @@ def __call__(self, input: Documents) -> Embeddings:
        Returns:
            List of embedding vectors.
        """
        if isinstance(input, str):
            input = [input]

-        result = self._client.embed(
-            texts=input,
-            model=self._config.get("model", "voyage-2"),
-            input_type=self._config.get("input_type"),
-            truncation=self._config.get("truncation", True),
-            output_dtype=self._config.get("output_dtype"),
-            output_dimension=self._config.get("output_dimension"),
-        )
        # Early return for empty input
        if not input:
            return []
[Review comment] Bug: Early Return Alters Expected Error Handling

The __call__ method's new early return for empty input now returns an empty list. This bypasses the expected ValueError for empty inputs, which existing tests and ChromaDB's EmbeddingFunction validation rely on.
        # Use unified batching for all text inputs
        embeddings = self._embed_with_batching(list(input))

        return cast(Embeddings, embeddings)
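A minimal way to address the reviewer's point above is to fail loudly instead of silently returning an empty list — a sketch only, not the project's actual fix, and the error message is an assumption:

```python
# Sketch: guard at the top of __call__ instead of the silent early return
if not input:
    raise ValueError("Input documents must be a non-empty list of strings.")
```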

    def _build_batches(self, texts: list[str]) -> Generator[list[str], None, None]:
        """
        Generate batches of texts based on token limits using a generator.

        Args:
            texts: List of texts to batch.

        Yields:
            Batches of texts as lists.
        """
        if not texts:
            return

        max_tokens_per_batch = self.get_token_limit()
        current_batch: list[str] = []
        current_batch_tokens = 0

        # Tokenize all texts in one API call
        all_token_lists = self._client.tokenize(texts, model=self._config["model"])
        token_counts = [len(tokens) for tokens in all_token_lists]

        for i, text in enumerate(texts):
            n_tokens = token_counts[i]

            # Check if adding this text would exceed limits
            if current_batch and (
                len(current_batch) >= BATCH_SIZE
                or (current_batch_tokens + n_tokens > max_tokens_per_batch)
            ):
                # Yield the current batch and start a new one
                yield current_batch
                current_batch = []
                current_batch_tokens = 0

            current_batch.append(text)
            current_batch_tokens += n_tokens

        # Yield the last batch (always has at least one text)
        if current_batch:
            yield current_batch

    def _get_embed_function(self) -> Callable[[list[str]], list[list[float]]]:
        """
        Get the appropriate embedding function based on model type.

        Returns:
            A callable that takes a batch of texts and returns embeddings.
        """
        model_name = self._config["model"]

        if self._is_context_model():

            def embed_batch(batch: list[str]) -> list[list[float]]:
                result = self._client.contextualized_embed(
                    inputs=[batch],
                    model=model_name,
                    input_type=self._config.get("input_type"),
                    output_dimension=self._config.get("output_dimension"),
                )
                return [list(emb) for emb in result.results[0].embeddings]

            return embed_batch

        def embed_batch_regular(batch: list[str]) -> list[list[float]]:
            result = self._client.embed(
                texts=batch,
                model=model_name,
                input_type=self._config.get("input_type"),
                truncation=self._config.get("truncation", True),
                output_dimension=self._config.get("output_dimension"),
            )
            return [list(emb) for emb in result.embeddings]

        return embed_batch_regular

    def _embed_with_batching(self, texts: list[str]) -> list[list[float]]:
        """
        Unified method to embed texts with automatic batching based on token limits.
        Works for regular and contextual models.

        Args:
            texts: List of texts to embed.

        Returns:
            List of embeddings.
        """
        if not texts:
            return []

        # Get the appropriate embedding function for this model type
        embed_fn = self._get_embed_function()

        # Process each batch
        all_embeddings = []
        for batch in self._build_batches(texts):
            batch_embeddings = embed_fn(batch)
            all_embeddings.extend(batch_embeddings)

        return all_embeddings

    def count_tokens(self, texts: list[str]) -> list[int]:
        """
        Count tokens for the given texts.

        Args:
            texts: List of texts to count tokens for.

        Returns:
            List of token counts for each text.
        """
        if not texts:
            return []

        # Use the VoyageAI tokenize API to get token counts
        token_lists = self._client.tokenize(texts, model=self._config["model"])
        return [len(token_list) for token_list in token_lists]

    def get_token_limit(self) -> int:
        """
        Get the token limit for the current model.

        Returns:
            Token limit for the model, or default of 120_000 if not found.
        """
        model_name = self._config["model"]
        return VOYAGE_TOTAL_TOKEN_LIMITS.get(model_name, 120_000)

-        return cast(Embeddings, result.embeddings)
    def _is_context_model(self) -> bool:
        """Check if the model is a contextualized embedding model."""
        model_name = self._config["model"]
[Review comment] Bug: Model Removal Breaks Existing Code Compatibility

Removing the default "voyage-2" for the model field makes it required for VoyageAIProvider and causes KeyError in VoyageAIEmbeddingFunction methods that directly access self._config["model"]. This breaks backward compatibility for existing code.

return "context" in model_name
10 changes: 5 additions & 5 deletions src/crewai/rag/embeddings/providers/voyageai/types.py
@@ -1,20 +1,20 @@
"""Type definitions for VoyageAI embedding providers."""

-from typing import Annotated, Literal
from typing import Literal

from typing_extensions import Required, TypedDict


class VoyageAIProviderConfig(TypedDict, total=False):
    """Configuration for VoyageAI provider."""

-    api_key: str
-    model: Annotated[str, "voyage-2"]
    api_key: Required[str]
    model: Required[str]
    input_type: str
-    truncation: Annotated[bool, True]
    truncation: bool
    output_dtype: str
    output_dimension: int
-    max_retries: Annotated[int, 0]
    max_retries: int
    timeout: float


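A hedged illustration of what the stricter typing means for callers — both keys must now be supplied explicitly (the values are placeholders):

```python
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig

# api_key and model are now Required; omitting either fails type checking
config: VoyageAIProviderConfig = {
    "api_key": "pa-...",    # placeholder key
    "model": "voyage-3.5",  # no implicit "voyage-2" default anymore
}
```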
@@ -16,7 +16,6 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
description="Voyage AI embedding function class",
)
model: str = Field(
default="voyage-2",
description="Model to use for embeddings",
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
)