diff --git a/docs/en/learn/llm-connections.mdx b/docs/en/learn/llm-connections.mdx index daedc21a27..7af38b71c4 100644 --- a/docs/en/learn/llm-connections.mdx +++ b/docs/en/learn/llm-connections.mdx @@ -10,7 +10,7 @@ mode: "wide" CrewAI uses LiteLLM to connect to a wide variety of Language Models (LLMs). This integration provides extensive versatility, allowing you to use models from numerous providers with a simple, unified interface. - By default, CrewAI uses the `gpt-4o-mini` model. This is determined by the `OPENAI_MODEL_NAME` environment variable, which defaults to "gpt-4o-mini" if not set. + By default, CrewAI uses the `gpt-4o-mini` model. This is determined by the `OPENAI_MODEL_NAME` environment variable, which defaults to "gpt-4o-mini" if not set. You can easily configure your agents to use a different model or provider as described in this guide. @@ -190,7 +190,7 @@ For local models like those provided by Ollama: You can change the base API URL for any LLM provider by setting the `base_url` parameter: -```python Code +```python Code llm = LLM( model="custom-model-name", base_url="https://api.your-provider.com/v1", diff --git a/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py b/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py index f7d7f71039..8449a07c55 100644 --- a/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py +++ b/src/crewai/rag/embeddings/providers/voyageai/embedding_callable.py @@ -1,5 +1,6 @@ """VoyageAI embedding function implementation.""" +from collections.abc import Callable, Generator from typing import cast from chromadb.api.types import Documents, EmbeddingFunction, Embeddings @@ -7,6 +8,29 @@ from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig +# Token limits for different VoyageAI models +VOYAGE_TOTAL_TOKEN_LIMITS = { + "voyage-context-3": 32_000, + "voyage-3.5-lite": 1_000_000, + "voyage-3.5": 320_000, + "voyage-2": 320_000, + "voyage-3-large": 120_000, + "voyage-code-3": 120_000, + "voyage-large-2-instruct": 120_000, + "voyage-finance-2": 120_000, + "voyage-multilingual-2": 120_000, + "voyage-law-2": 120_000, + "voyage-large-2": 120_000, + "voyage-3": 120_000, + "voyage-3-lite": 120_000, + "voyage-code-2": 120_000, + "voyage-3-m-exp": 120_000, + "voyage-multimodal-3": 120_000, +} + +# Batch size for embedding requests +BATCH_SIZE = 1000 + class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]): """Embedding function for VoyageAI models.""" @@ -46,17 +70,146 @@ def __call__(self, input: Documents) -> Embeddings: Returns: List of embedding vectors. """ - if isinstance(input, str): input = [input] - result = self._client.embed( - texts=input, - model=self._config.get("model", "voyage-2"), - input_type=self._config.get("input_type"), - truncation=self._config.get("truncation", True), - output_dtype=self._config.get("output_dtype"), - output_dimension=self._config.get("output_dimension"), - ) + # Early return for empty input + if not input: + return [] + + # Use unified batching for all text inputs + embeddings = self._embed_with_batching(list(input)) + + return cast(Embeddings, embeddings) + + def _build_batches(self, texts: list[str]) -> Generator[list[str], None, None]: + """ + Generate batches of texts based on token limits using a generator. + + Args: + texts: List of texts to batch. + + Yields: + Batches of texts as lists. + """ + if not texts: + return + + max_tokens_per_batch = self.get_token_limit() + current_batch: list[str] = [] + current_batch_tokens = 0 + + # Tokenize all texts in one API call + all_token_lists = self._client.tokenize(texts, model=self._config["model"]) + token_counts = [len(tokens) for tokens in all_token_lists] + + for i, text in enumerate(texts): + n_tokens = token_counts[i] + + # Check if adding this text would exceed limits + if current_batch and ( + len(current_batch) >= BATCH_SIZE + or (current_batch_tokens + n_tokens > max_tokens_per_batch) + ): + # Yield the current batch and start a new one + yield current_batch + current_batch = [] + current_batch_tokens = 0 + + current_batch.append(text) + current_batch_tokens += n_tokens + + # Yield the last batch (always has at least one text) + if current_batch: + yield current_batch + + def _get_embed_function(self) -> Callable[[list[str]], list[list[float]]]: + """ + Get the appropriate embedding function based on model type. + + Returns: + A callable that takes a batch of texts and returns embeddings. + """ + model_name = self._config["model"] + + if self._is_context_model(): + + def embed_batch(batch: list[str]) -> list[list[float]]: + result = self._client.contextualized_embed( + inputs=[batch], + model=model_name, + input_type=self._config.get("input_type"), + output_dimension=self._config.get("output_dimension"), + ) + return [list(emb) for emb in result.results[0].embeddings] + + return embed_batch + + def embed_batch_regular(batch: list[str]) -> list[list[float]]: + result = self._client.embed( + texts=batch, + model=model_name, + input_type=self._config.get("input_type"), + truncation=self._config.get("truncation", True), + output_dimension=self._config.get("output_dimension"), + ) + return [list(emb) for emb in result.embeddings] + + return embed_batch_regular + + def _embed_with_batching(self, texts: list[str]) -> list[list[float]]: + """ + Unified method to embed texts with automatic batching based on token limits. + Works for regular and contextual models. + + Args: + texts: List of texts to embed. + + Returns: + List of embeddings. + """ + if not texts: + return [] + + # Get the appropriate embedding function for this model type + embed_fn = self._get_embed_function() + + # Process each batch + all_embeddings = [] + for batch in self._build_batches(texts): + batch_embeddings = embed_fn(batch) + all_embeddings.extend(batch_embeddings) + + return all_embeddings + + def count_tokens(self, texts: list[str]) -> list[int]: + """ + Count tokens for the given texts. + + Args: + texts: List of texts to count tokens for. + + Returns: + List of token counts for each text. + """ + if not texts: + return [] + + # Use the VoyageAI tokenize API to get token counts + token_lists = self._client.tokenize(texts, model=self._config["model"]) + return [len(token_list) for token_list in token_lists] + + def get_token_limit(self) -> int: + """ + Get the token limit for the current model. + + Returns: + Token limit for the model, or default of 120_000 if not found. + """ + model_name = self._config["model"] + return VOYAGE_TOTAL_TOKEN_LIMITS.get(model_name, 120_000) - return cast(Embeddings, result.embeddings) + def _is_context_model(self) -> bool: + """Check if the model is a contextualized embedding model.""" + model_name = self._config["model"] + return "context" in model_name diff --git a/src/crewai/rag/embeddings/providers/voyageai/types.py b/src/crewai/rag/embeddings/providers/voyageai/types.py index 92579ba5d0..395778f765 100644 --- a/src/crewai/rag/embeddings/providers/voyageai/types.py +++ b/src/crewai/rag/embeddings/providers/voyageai/types.py @@ -1,6 +1,6 @@ """Type definitions for VoyageAI embedding providers.""" -from typing import Annotated, Literal +from typing import Literal from typing_extensions import Required, TypedDict @@ -8,13 +8,13 @@ class VoyageAIProviderConfig(TypedDict, total=False): """Configuration for VoyageAI provider.""" - api_key: str - model: Annotated[str, "voyage-2"] + api_key: Required[str] + model: Required[str] input_type: str - truncation: Annotated[bool, True] + truncation: bool output_dtype: str output_dimension: int - max_retries: Annotated[int, 0] + max_retries: int timeout: float diff --git a/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py b/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py index 133b02db7b..62a8680b10 100644 --- a/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py +++ b/src/crewai/rag/embeddings/providers/voyageai/voyageai_provider.py @@ -16,7 +16,6 @@ class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]): description="Voyage AI embedding function class", ) model: str = Field( - default="voyage-2", description="Model to use for embeddings", validation_alias="EMBEDDINGS_VOYAGEAI_MODEL", ) diff --git a/tests/rag/embeddings/test_voyageai_embedding.py b/tests/rag/embeddings/test_voyageai_embedding.py new file mode 100644 index 0000000000..6150380425 --- /dev/null +++ b/tests/rag/embeddings/test_voyageai_embedding.py @@ -0,0 +1,266 @@ +"""Tests for VoyageAI embedding function.""" + +import os + +import pytest + +from crewai.rag.embeddings.providers.voyageai.embedding_callable import ( + VoyageAIEmbeddingFunction, +) + +voyageai = pytest.importorskip("voyageai", reason="voyageai not installed") + + +def test_basic_embedding() -> None: + """Test basic embedding generation.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + embeddings = ef(["hello world"]) + assert embeddings is not None + assert len(embeddings) == 1 + assert len(embeddings[0]) > 0 + + +def test_with_embedding_dimensions() -> None: + """Test embedding generation with custom dimensions.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + output_dimension=2048, + ) + embeddings = ef(["hello world"]) + assert embeddings is not None + assert len(embeddings) == 1 + assert len(embeddings[0]) == 2048 + + +def test_with_contextual_embedding() -> None: + """Test contextual embedding generation.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-context-3", + output_dimension=2048, + ) + embeddings = ef(["hello world", "in chroma"]) + assert embeddings is not None + assert len(embeddings) == 2 + assert len(embeddings[0]) == 2048 + + +def test_count_tokens() -> None: + """Test token counting functionality.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + texts = ["hello world", "this is a longer text with more tokens"] + token_counts = ef.count_tokens(texts) + assert len(token_counts) == 2 + assert token_counts[0] > 0 + assert token_counts[1] > token_counts[0] # Longer text should have more tokens + + +def test_count_tokens_empty_list() -> None: + """Test token counting with empty list.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + token_counts = ef.count_tokens([]) + assert token_counts == [] + + +def test_count_tokens_single_text() -> None: + """Test token counting with single text.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-2", + ) + token_counts = ef.count_tokens(["hello"]) + assert len(token_counts) == 1 + assert token_counts[0] > 0 + + +def test_get_token_limit() -> None: + """Test getting token limit for different models.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + + # Test voyage-2 model + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-2", + ) + assert ef.get_token_limit() == 320_000 + + # Test context model + ef_context = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-context-3", + ) + assert ef_context.get_token_limit() == 32_000 + + # Test voyage-3-large model + ef_large = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3-large", + ) + assert ef_large.get_token_limit() == 120_000 + + +def test_batching_with_multiple_texts() -> None: + """Test that batching works with multiple texts.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + texts = ["text1", "text2", "text3", "text4", "text5"] + embeddings = ef(texts) + assert len(embeddings) == 5 + assert all(len(emb) > 0 for emb in embeddings) + + +def test_build_batches() -> None: + """Test the _build_batches method.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-2", + ) + texts = ["short", "text", "here", "now"] + batches = list(ef._build_batches(texts)) + # Should create at least one batch + assert len(batches) >= 1 + # Total texts should be preserved + total_texts = sum(len(batch) for batch in batches) + assert total_texts == len(texts) + + +def test_batching_with_large_texts() -> None: + """Test batching with texts that may exceed token limits.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + # Create long texts + long_text = "This is a long text with many words. " * 100 + texts = [long_text, long_text, long_text] + embeddings = ef(texts) + assert len(embeddings) == 3 + assert all(len(emb) > 0 for emb in embeddings) + + +def test_contextual_batching() -> None: + """Test that contextual models support batching.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-context-3", + ) + texts = ["text1", "text2", "text3", "text4"] + embeddings = ef(texts) + assert len(embeddings) == 4 + assert all(len(emb) > 0 for emb in embeddings) + + +def test_contextual_build_batches() -> None: + """Test that contextual models use _build_batches correctly.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-context-3", + ) + texts = ["short", "text", "here", "now", "more"] + batches = list(ef._build_batches(texts)) + # Should create at least one batch + assert len(batches) >= 1 + # Total texts should be preserved + total_texts = sum(len(batch) for batch in batches) + assert total_texts == len(texts) + + +def test_contextual_with_large_batch() -> None: + """Test contextual model with large batch that should be split.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-context-3", + ) + # Create many texts + texts = [f"Document number {i} with some content" for i in range(15)] + embeddings = ef(texts) + assert len(embeddings) == 15 + assert all(len(emb) > 0 for emb in embeddings) + + +def test_empty_input() -> None: + """Test with empty input.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + # ChromaDB's EmbeddingFunction validates that embeddings are non-empty + with pytest.raises(ValueError, match="Expected Embeddings to be non-empty"): + ef([]) + + +def test_single_string_input() -> None: + """Test with single string input (not in a list).""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + ef = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + embeddings = ef("hello world") + assert len(embeddings) == 1 + assert len(embeddings[0]) > 0 + + +def test_is_context_model() -> None: + """Test the _is_context_model helper method.""" + if os.environ.get("VOYAGE_API_KEY") is None: + pytest.skip("VOYAGE_API_KEY not set") + + # Test with context model + ef_context = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-context-3", + ) + assert ef_context._is_context_model() is True + + # Test with regular model + ef_regular = VoyageAIEmbeddingFunction( + api_key=os.environ["VOYAGE_API_KEY"], + model="voyage-3.5", + ) + assert ef_regular._is_context_model() is False + + +def test_name() -> None: + """Test the static name method.""" + assert VoyageAIEmbeddingFunction.name() == "voyageai"