Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions lib/crewai/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ boto3 = [
google-genai = [
"google-genai>=1.2.0",
]
azure-ai-inference = [
"azure-ai-inference>=1.0.0b9",
]


[project.scripts]
Expand Down
2 changes: 1 addition & 1 deletion lib/crewai/src/crewai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ def _get_native_provider(cls, provider: str) -> type | None:
except ImportError:
return None

elif provider == "azure":
elif provider == "azure" or provider == "azure_openai":
try:
from crewai.llms.providers.azure.completion import AzureCompletion

Expand Down
102 changes: 79 additions & 23 deletions lib/crewai/src/crewai/llms/providers/azure/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@


try:
from azure.ai.inference import ChatCompletionsClient # type: ignore
from azure.ai.inference.models import ( # type: ignore
from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsToolCall,
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import AzureKeyCredential # type: ignore
from azure.core.exceptions import HttpResponseError # type: ignore
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import HttpResponseError
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM

Expand Down Expand Up @@ -80,7 +80,9 @@ def __init__(
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-02-01"
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
self.timeout = timeout
self.max_retries = max_retries

if not self.api_key:
raise ValueError(
Expand All @@ -91,10 +93,20 @@ def __init__(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)

self.client = ChatCompletionsClient(
endpoint=self.endpoint,
credential=AzureKeyCredential(self.api_key),
)
# Validate and potentially fix Azure OpenAI endpoint URL
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)

# Build client kwargs
client_kwargs = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
}

# Add api_version if specified (primarily for Azure OpenAI endpoints)
if self.api_version:
client_kwargs["api_version"] = self.api_version

self.client = ChatCompletionsClient(**client_kwargs)

self.top_p = top_p
self.frequency_penalty = frequency_penalty
Expand All @@ -106,6 +118,34 @@ def __init__(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
)

self.is_azure_openai_endpoint = (
"openai.azure.com" in self.endpoint
and "/openai/deployments/" in self.endpoint
)

def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
    """Validate and fix the Azure endpoint URL format.

    Azure OpenAI endpoints should be in the format:
    https://<resource-name>.openai.azure.com/openai/deployments/<deployment-name>

    Endpoints on other hosts, and Azure OpenAI endpoints that already
    contain a deployment segment, are returned unchanged. Otherwise the
    missing ``/openai/deployments/<deployment-name>`` suffix is appended,
    using the model name (minus a leading ``azure/`` provider prefix) as
    the deployment name.

    Args:
        endpoint: The endpoint URL
        model: The model/deployment name

    Returns:
        Validated and potentially corrected endpoint URL
    """
    deployments_path = "/openai/deployments"

    # Only Azure OpenAI resource hosts use the deployments URL layout.
    if "openai.azure.com" not in endpoint:
        return endpoint

    trimmed = endpoint.rstrip("/")

    # A deployment segment is already present after the deployments path.
    if f"{deployments_path}/" in trimmed:
        return endpoint

    # Strip only a leading provider prefix; str.replace would also remove
    # "azure/" occurring inside the deployment name itself.
    deployment_name = model.removeprefix("azure/")

    # The endpoint may already end with the deployments path but lack the
    # deployment name; avoid duplicating the path in that case.
    if not trimmed.endswith(deployments_path):
        trimmed = f"{trimmed}{deployments_path}"

    endpoint = f"{trimmed}/{deployment_name}"
    logging.info("Constructed Azure OpenAI endpoint URL: %s", endpoint)
    return endpoint

def call(
self,
messages: str | list[dict[str, str]],
Expand Down Expand Up @@ -158,7 +198,17 @@ def call(
)

except HttpResponseError as e:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
if e.status_code == 401:
error_msg = "Azure authentication failed. Check your API key."
elif e.status_code == 404:
error_msg = (
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
)
elif e.status_code == 429:
error_msg = "Azure API rate limit exceeded. Please retry later."
else:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"

logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
Expand Down Expand Up @@ -187,11 +237,15 @@ def _prepare_completion_params(
Parameters dictionary for Azure API
"""
params = {
"model": self.model,
"messages": messages,
"stream": self.stream,
}

# Only include model parameter for non-Azure OpenAI endpoints
# Azure OpenAI endpoints have the deployment name in the URL
if not self.is_azure_openai_endpoint:
params["model"] = self.model

# Add optional parameters if set
if self.temperature is not None:
params["temperature"] = self.temperature
Expand Down Expand Up @@ -250,26 +304,19 @@ def _format_messages_for_azure(
messages: Input messages

Returns:
List of dict objects
List of dict objects with 'role' and 'content' keys
"""
# Use base class formatting first
base_formatted = super()._format_messages(messages)

azure_messages = []

for message in base_formatted:
role = message.get("role")
role = message.get("role", "user") # Default to user if no role
content = message.get("content", "")

if role == "system":
azure_messages.append(dict(content=content))
elif role == "user":
azure_messages.append(dict(content=content))
elif role == "assistant":
azure_messages.append(dict(content=content))
else:
# Default to user message for unknown roles
azure_messages.append(dict(content=content))
# Azure AI Inference requires both 'role' and 'content'
azure_messages.append({"role": role, "content": content})

return azure_messages

Expand Down Expand Up @@ -339,6 +386,13 @@ def _handle_completion(
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e

error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise e

return content

def _handle_streaming_completion(
Expand Down Expand Up @@ -454,7 +508,9 @@ def get_context_window_size(self) -> int:
}

# Find the best match for the model name
for model_prefix, size in context_windows.items():
for model_prefix, size in sorted(
context_windows.items(), key=lambda x: len(x[0]), reverse=True
):
if self.model.startswith(model_prefix):
return int(size * CONTEXT_WINDOW_USAGE_RATIO)

Expand Down
Empty file.
3 changes: 3 additions & 0 deletions lib/crewai/tests/llms/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Azure LLM tests


Loading
Loading