diff --git a/python/instrumentation/openinference-instrumentation-mistralai/README.md b/python/instrumentation/openinference-instrumentation-mistralai/README.md index 1624095ed3..cf490b77dd 100644 --- a/python/instrumentation/openinference-instrumentation-mistralai/README.md +++ b/python/instrumentation/openinference-instrumentation-mistralai/README.md @@ -76,6 +76,74 @@ Now simply run the python file and observe the traces in Phoenix. python your_file.py ``` +## OCR and Input Image Tracing + +The MistralAI instrumentation automatically traces input images and documents passed to the OCR API, following OpenInference semantic conventions. This includes: + +### Supported Input Types + +- **HTTP Image URLs**: `https://example.com/image.jpg` +- **Base64 Images**: `data:image/jpeg;base64,{base64_data}` +- **PDF URLs**: `https://example.com/document.pdf` +- **Base64 PDFs**: `data:application/pdf;base64,{base64_data}` + +### Trace Attributes + +For **image inputs**, the instrumentation creates: +- `input.message_content.type`: `"image"` +- `input.message_content.image.image.url`: The image URL or base64 data URL +- `input.message_content.image.metadata`: JSON metadata including source, encoding type, and MIME type + +For **document inputs**, the instrumentation creates: +- `input.message_content.type`: `"document"` +- `input.document.url`: The document URL or base64 data URL +- `input.document.metadata`: JSON metadata including source, encoding type, and MIME type + +### Example Usage + +```python +import base64 +import os +from mistralai import Mistral +from openinference.instrumentation.mistralai import MistralAIInstrumentor + +# Set up instrumentation +MistralAIInstrumentor().instrument() + +client = Mistral(api_key=os.environ["MISTRAL_API_KEY"]) + +# OCR with HTTP image URL +response = client.ocr.process( + model="mistral-ocr-latest", + document={ + "type": "image_url", + "image_url": "https://example.com/receipt.png" + }, + include_image_base64=True +) + +# OCR with base64 image +with open("image.jpg", "rb") as f: + base64_image = base64.b64encode(f.read()).decode('utf-8') + +response = client.ocr.process( + model="mistral-ocr-latest", + document={ + "type": "image_url", + "image_url": f"data:image/jpeg;base64,{base64_image}" + }, + include_image_base64=True +) +``` + +### Privacy and Configuration + +Input image tracing works seamlessly with [TraceConfig](https://github.com/Arize-ai/openinference/tree/main/python/openinference-instrumentation#tracing-configuration) for: + +- **Image size limits**: Control maximum base64 image length with `base64_image_max_length` +- **Privacy controls**: Hide input images with `hide_inputs` or `hide_input_images` +- **MIME type detection**: Automatic detection and proper formatting of image data URLs + ## More Info * [More info on OpenInference and Phoenix](https://docs.arize.com/phoenix) diff --git a/python/instrumentation/openinference-instrumentation-mistralai/examples/ocr.py b/python/instrumentation/openinference-instrumentation-mistralai/examples/ocr.py new file mode 100644 index 0000000000..2e5ed18948 --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-mistralai/examples/ocr.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 + +import os + +from dotenv import load_dotenv +from mistralai import Mistral +from phoenix.otel import register + +from openinference.instrumentation.mistralai import MistralAIInstrumentor + +load_dotenv() + +tracer = register( + project_name="mistral-ocr", + endpoint=os.getenv("PHOENIX_COLLECTOR_ENDPOINT"), +) + +# Initialize instrumentation +MistralAIInstrumentor().instrument(tracer_provider=tracer) + +# Initialize client +client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY")) + + +def test_ocr_with_working_image(): + """Test OCR with a working image URL that should display in Phoenix""" + + # Using a reliable image URL - this is a simple diagram/chart + image_url = "https://upload.wikimedia.org/wikipedia/commons/d/d1/Ai_lizard.png" + try: + print("šŸ” Testing OCR with working image URL...") + print(f"Image URL: {image_url}") + + ocr_response = client.ocr.process( + model="mistral-ocr-latest", + document={ + "type": "image_url", + "image_url": image_url, + }, + include_image_base64=True, + ) + + print("āœ… OCR completed successfully!") + + if hasattr(ocr_response, "pages") and ocr_response.pages: + print(f"šŸ“„ Pages processed: {len(ocr_response.pages)}") + + for i, page in enumerate(ocr_response.pages): + print(f"\n--- Page {i + 1} ---") + if hasattr(page, "markdown") and page.markdown: + print("šŸ“ Markdown content:") + print( + page.markdown[:200] + "..." if len(page.markdown) > 200 else page.markdown + ) + + if hasattr(page, "images") and page.images: + print(f"šŸ–¼ļø Extracted images: {len(page.images)}") + for j, img in enumerate(page.images): + if hasattr(img, "id"): + print(f" - Image {j + 1}: {img.id}") + + print("\nšŸ”— View traces in your Phoenix project") + + except Exception as e: + print(f"āŒ Error: {e}") + print("Make sure:") + print("1. MISTRAL_API_KEY environment variable is set") + print("2. You have sufficient credits") + print("3. The image URL is accessible") + + +if __name__ == "__main__": + test_ocr_with_working_image() diff --git a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/__init__.py b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/__init__.py index bcefd0b490..8afcf3b133 100644 --- a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/__init__.py +++ b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/__init__.py @@ -14,6 +14,10 @@ _AsyncStreamChatWrapper, _SyncChatWrapper, ) +from openinference.instrumentation.mistralai._ocr_wrapper import ( + _AsyncOCRWrapper, + _SyncOCRWrapper, +) from openinference.instrumentation.mistralai.package import _instruments from openinference.instrumentation.mistralai.version import __version__ @@ -38,6 +42,8 @@ class MistralAIInstrumentor(BaseInstrumentor): # type: ignore "_original_sync_stream_agent_method", "_original_async_agent_method", "_original_async_stream_agent_method", + "_original_sync_ocr_method", + "_original_async_ocr_method", ) def instrumentation_dependencies(self) -> Collection[str]: @@ -59,6 +65,12 @@ def _instrument(self, **kwargs: Any) -> None: import mistralai from mistralai.agents import Agents from mistralai.chat import Chat + + Ocr: Any = None + try: + from mistralai.ocr import Ocr + except ImportError: + print("Outdated version of mistralai: currently version does not support Ocr") except ImportError as err: raise Exception( "Could not import mistralai. Please install with `pip install mistralai`." @@ -72,6 +84,10 @@ def _instrument(self, **kwargs: Any) -> None: self._original_sync_stream_agent_method = Agents.stream self._original_async_agent_method = Agents.complete_async self._original_async_stream_agent_method = Agents.stream_async + if Ocr is not None: + self._original_sync_ocr_method = Ocr.process + self._original_async_ocr_method = Ocr.process_async + wrap_function_wrapper( module="mistralai.chat", name="Chat.complete", @@ -120,6 +136,20 @@ def _instrument(self, **kwargs: Any) -> None: wrapper=_AsyncStreamChatWrapper("MistralAsyncClient.agents", self._tracer, mistralai), ) + # Instrument OCR methods + if Ocr is not None: + wrap_function_wrapper( + module="mistralai.ocr", + name="Ocr.process", + wrapper=_SyncOCRWrapper("MistralClient.ocr", self._tracer, mistralai), + ) + + wrap_function_wrapper( + module="mistralai.ocr", + name="Ocr.process_async", + wrapper=_AsyncOCRWrapper("MistralAsyncClient.ocr", self._tracer, mistralai), + ) + def _uninstrument(self, **kwargs: Any) -> None: from mistralai.agents import Agents from mistralai.chat import Chat @@ -132,3 +162,11 @@ def _uninstrument(self, **kwargs: Any) -> None: Agents.stream = self._original_sync_stream_agent_method # type: ignore Agents.complete_async = self._original_async_agent_method # type: ignore Agents.stream_async = self._original_async_stream_agent_method # type: ignore + try: + from mistralai.ocr import Ocr + + Ocr.process = self._original_sync_ocr_method # type: ignore + Ocr.process_async = self._original_async_ocr_method # type: ignore + except ImportError: + # OCR module not available, nothing to uninstrument + pass diff --git a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_ocr_wrapper.py b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_ocr_wrapper.py new file mode 100644 index 0000000000..d86fadb0fc --- /dev/null +++ b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_ocr_wrapper.py @@ -0,0 +1,305 @@ +import logging +from contextlib import contextmanager +from inspect import Signature, signature +from types import ModuleType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Iterator, + Mapping, + Tuple, +) + +from opentelemetry import context as context_api +from opentelemetry import trace as trace_api +from opentelemetry.context import _SUPPRESS_INSTRUMENTATION_KEY +from opentelemetry.trace import INVALID_SPAN +from opentelemetry.util.types import AttributeValue + +from openinference.instrumentation import get_attributes_from_context, safe_json_dumps +from openinference.instrumentation.mistralai._request_attributes_extractor import ( + _RequestAttributesExtractor, +) +from openinference.instrumentation.mistralai._response_attributes_extractor import ( + _OCRResponseAttributesExtractor, +) +from openinference.instrumentation.mistralai._utils import ( + _as_input_attributes, + _finish_tracing, + _io_value_and_type, +) +from openinference.instrumentation.mistralai._with_span import _WithSpan +from openinference.semconv.trace import ( + OpenInferenceMimeTypeValues, + OpenInferenceSpanKindValues, + SpanAttributes, +) + +if TYPE_CHECKING: + from mistralai import Mistral + +__all__ = ( + "_SyncOCRWrapper", + "_AsyncOCRWrapper", +) + +# Define OCR span kind since it's not in openinference-semantic-conventions yet +_OCR_SPAN_KIND = OpenInferenceSpanKindValues.LLM.value + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +class _OCRResponseAttributes: + __slots__ = ( + "_response", + "_request_parameters", + "_response_attributes_extractor", + ) + + def __init__( + self, + response: Any, + request_parameters: Mapping[str, Any], + response_attributes_extractor: _OCRResponseAttributesExtractor, + ) -> None: + self._response = response + self._request_parameters = request_parameters + self._response_attributes_extractor = response_attributes_extractor + + def get_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + if hasattr(self._response, "model_dump_json") and callable(self._response.model_dump_json): + try: + value = self._response.model_dump_json(exclude_unset=True) + assert isinstance(value, str) + yield SpanAttributes.OUTPUT_VALUE, value + yield SpanAttributes.OUTPUT_MIME_TYPE, OpenInferenceMimeTypeValues.JSON.value + except Exception: + logger.exception("Failed to get model dump json") + + def get_extra_attributes(self) -> Iterator[Tuple[str, AttributeValue]]: + yield from self._response_attributes_extractor.get_attributes_from_response( + response=self._response, + request_parameters=self._request_parameters, + ) + + +class _WithTracer: + def __init__(self, tracer: trace_api.Tracer) -> None: + self._tracer = tracer + + @contextmanager + def _start_as_current_span( + self, + span_name: str, + attributes: Iterable[Tuple[str, AttributeValue]], + context_attributes: Iterable[Tuple[str, AttributeValue]], + extra_attributes: Iterable[Tuple[str, AttributeValue]], + ) -> Iterator[_WithSpan]: + # Because OTEL has a default limit of 128 attributes, we split our + # attributes into two tiers, where "extra_attributes" are added first to + # ensure that the most important "attributes" are added last and are not + # dropped. + try: + span = self._tracer.start_span(name=span_name, attributes=dict(extra_attributes)) + except Exception: + logger.exception("Failed to start span") + span = INVALID_SPAN + with trace_api.use_span( + span, + end_on_exit=False, + record_exception=False, + set_status_on_exception=False, + ) as span: + yield _WithSpan( + span=span, + context_attributes=dict(context_attributes), + extra_attributes=dict(attributes), + ) + + +class _WithMistralAI: + __slots__ = ( + "_request_attributes_extractor", + "_response_attributes_extractor", + ) + + def __init__(self, mistralai: ModuleType) -> None: + self._request_attributes_extractor = _RequestAttributesExtractor(mistralai) + self._response_attributes_extractor = _OCRResponseAttributesExtractor() + + def _get_span_kind(self) -> str: + return _OCR_SPAN_KIND + + def _get_attributes_from_request( + self, + request_parameters: Dict[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + yield SpanAttributes.OPENINFERENCE_SPAN_KIND, self._get_span_kind() + try: + yield from _as_input_attributes(_io_value_and_type(request_parameters)) + except Exception: + logger.exception( + f"Failed to get input attributes from request parameters of " + f"type {type(request_parameters)}" + ) + + def _get_extra_attributes_from_request( + self, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + try: + yield from self._request_attributes_extractor.get_attributes_from_ocr_request( + request_parameters=request_parameters, + ) + except Exception: + logger.exception( + f"Failed to get extra attributes from request options of " + f"type {type(request_parameters)}" + ) + + def _parse_args( + self, + signature: Signature, + mistral_client: "Mistral", + *args: Tuple[Any], + **kwargs: Mapping[str, Any], + ) -> Dict[str, Any]: + """ + Serialize parameters to JSON. + """ + bound_signature = signature.bind(*args, **kwargs) + bound_signature.apply_defaults() + bound_arguments = bound_signature.arguments + request_data: Dict[str, Any] = {} + for key, value in bound_arguments.items(): + try: + if value is not None: + try: + # ensure the value is JSON-serializable + safe_json_dumps(value) + request_data[key] = value + except Exception: + request_data[key] = str(value) + except Exception: + request_data[key] = str(value) + return request_data + + def _finalize_response( + self, + response: Any, + with_span: _WithSpan, + request_parameters: Mapping[str, Any], + ) -> Any: + """ + Finish tracing for the OCR response. + """ + try: + _finish_tracing( + status=trace_api.Status(status_code=trace_api.StatusCode.OK), + with_span=with_span, + has_attributes=_OCRResponseAttributes( + request_parameters=request_parameters, + response=response, + response_attributes_extractor=self._response_attributes_extractor, + ), + ) + except Exception: + logger.exception(f"Failed to finish tracing for response of type {type(response)}") + with_span.finish_tracing() + return response + + +class _SyncOCRWrapper(_WithTracer, _WithMistralAI): + def __init__(self, span_name: str, tracer: trace_api.Tracer, mistralai: ModuleType): + _WithTracer.__init__(self, tracer) + _WithMistralAI.__init__(self, mistralai) + self._span_name = span_name + + def __call__( + self, + wrapped: Callable[..., Any], + instance: "Mistral", + args: Tuple[Any], + kwargs: Mapping[str, Any], + ) -> Any: + if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY): + return wrapped(*args, **kwargs) + try: + request_parameters = self._parse_args(signature(wrapped), instance, *args, **kwargs) + except Exception: + logger.exception("Failed to parse request args") + return wrapped(*args, **kwargs) + with self._start_as_current_span( + span_name=self._span_name, + attributes=self._get_attributes_from_request(request_parameters), + context_attributes=get_attributes_from_context(), + extra_attributes=self._get_extra_attributes_from_request( + request_parameters + ), # redundant under the current span type of LLM + ) as with_span: + try: + response = wrapped(*args, **kwargs) + except Exception as exception: + with_span.record_exception(exception) + status = trace_api.Status( + status_code=trace_api.StatusCode.ERROR, + # Follow the format in OTEL SDK for description, see: + # https://github.com/open-telemetry/opentelemetry-python/blob/2b9dcfc5d853d1c10176937a6bcaade54cda1a31/opentelemetry-api/src/opentelemetry/trace/__init__.py#L588 # noqa E501 + description=f"{type(exception).__name__}: {exception}", + ) + with_span.finish_tracing(status=status) + raise + return self._finalize_response( + response=response, + with_span=with_span, + request_parameters=request_parameters, + ) + + +class _AsyncOCRWrapper(_WithTracer, _WithMistralAI): + def __init__(self, span_name: str, tracer: trace_api.Tracer, mistralai: ModuleType): + _WithTracer.__init__(self, tracer) + _WithMistralAI.__init__(self, mistralai) + self._span_name = span_name + + async def __call__( + self, + wrapped: Callable[..., Any], + instance: "Mistral", + args: Tuple[Any], + kwargs: Mapping[str, Any], + ) -> Any: + if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY): + return await wrapped(*args, **kwargs) + try: + request_parameters = self._parse_args(signature(wrapped), instance, *args, **kwargs) + except Exception: + logger.exception("Failed to parse request args") + return await wrapped(*args, **kwargs) + with self._start_as_current_span( + span_name=self._span_name, + attributes=self._get_attributes_from_request(request_parameters), + context_attributes=get_attributes_from_context(), + extra_attributes=self._get_extra_attributes_from_request(request_parameters), + ) as with_span: + try: + response = await wrapped(*args, **kwargs) + except Exception as exception: + with_span.record_exception(exception) + status = trace_api.Status( + status_code=trace_api.StatusCode.ERROR, + # Follow the format in OTEL SDK for description, see: + # https://github.com/open-telemetry/opentelemetry-python/blob/2b9dcfc5d853d1c10176937a6bcaade54cda1a31/opentelemetry-api/src/opentelemetry/trace/__init__.py#L588 # noqa E501 + description=f"{type(exception).__name__}: {exception}", + ) + with_span.finish_tracing(status=status) + raise + return self._finalize_response( + response=response, + with_span=with_span, + request_parameters=request_parameters, + ) diff --git a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_request_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_request_attributes_extractor.py index 08b8697cdc..99be5b9759 100644 --- a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_request_attributes_extractor.py +++ b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_request_attributes_extractor.py @@ -13,9 +13,18 @@ from opentelemetry.util.types import AttributeValue from openinference.instrumentation import safe_json_dumps -from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes +from openinference.semconv.trace import ( + ImageAttributes, + MessageAttributes, + MessageContentAttributes, + SpanAttributes, + ToolCallAttributes, +) -__all__ = ("_RequestAttributesExtractor",) +__all__ = ( + "_RequestAttributesExtractor", + "_get_attributes_from_ocr_process_param", +) logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -40,6 +49,13 @@ def get_attributes_from_request( return yield from _get_attributes_from_chat_completion_create_param(request_parameters) + def get_attributes_from_ocr_request( + self, request_parameters: Mapping[str, Any] + ) -> Iterator[Tuple[str, AttributeValue]]: + if not isinstance(request_parameters, Mapping): + return + yield from _get_attributes_from_ocr_process_param(request_parameters) + def _get_attributes_from_chat_completion_create_param( params: Mapping[str, Any], @@ -59,6 +75,103 @@ def _get_attributes_from_chat_completion_create_param( yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value +def _is_base64_url(url: str) -> bool: + """Check if a URL is a base64 data URL.""" + return url.startswith("data:") and "base64" in url + + +def _get_attributes_from_ocr_process_param( + params: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + if not isinstance(params, Mapping): + return + + # Extract model information + model = params.get("model") + if model: + yield SpanAttributes.LLM_MODEL_NAME, model + + # Extract basic OCR parameters + invocation_params = dict(params) + # # Remove document from params as it might contain binary data + # invocation_params.pop("document", None) + yield SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_params) + + # Extract document/image input as LLM input message (like OpenAI chat completions) + document = params.get("document") + if document: + yield from _get_attributes_from_document_as_message(document) + + # Extract annotation format information + bbox_format = params.get("bbox_annotation_format") + if bbox_format: + yield "ocr.bbox_annotation_format", safe_json_dumps(bbox_format) + + doc_format = params.get("document_annotation_format") + if doc_format: + yield "ocr.document_annotation_format", safe_json_dumps(doc_format) + + +def _get_attributes_from_document_as_message( + document: Mapping[str, Any], +) -> Iterator[Tuple[str, AttributeValue]]: + """Convert document input to LLM input message format that Phoenix can display.""" + if not hasattr(document, "get"): + return + + doc_type = document.get("type") + + # Create a synthetic LLM input message for the document + # This follows the exact same pattern as OpenAI chat completions + message_index = 0 + content_index = 0 + + # Add role for the synthetic message + yield ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_ROLE}", + "user", + ) + + if doc_type == "image_url": + # Handle image inputs - follow OpenAI pattern exactly + if image_url := document.get("image_url"): + yield ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_TYPE}", + "image", + ) + yield ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_IMAGE}.{ImageAttributes.IMAGE_URL}", + image_url, + ) + + elif doc_type == "document_url": + # Handle document/PDF inputs + if document_url := document.get("document_url"): + # Determine if it's an image based on URL pattern + if ( + _is_base64_url(document_url) and not document_url.startswith("data:application/pdf") + ) or document_url.lower().endswith((".jpg", ".jpeg", ".png", ".gif", ".webp")): + # Treat as image for Phoenix display + yield ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_TYPE}", + "image", + ) + yield ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_IMAGE}.{ImageAttributes.IMAGE_URL}", + document_url, + ) + else: + # For PDFs and other documents, add as text content with URL reference + yield ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_TYPE}", + "text", + ) + yield ( + f"{SpanAttributes.LLM_INPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_TEXT}", + f"Document: {document_url}", + ) + + def _get_attributes_from_message_param( message: Mapping[str, Any], ) -> Iterator[Tuple[str, AttributeValue]]: diff --git a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_response_attributes_extractor.py b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_response_attributes_extractor.py index 9adcb8296e..84f3fee9bc 100644 --- a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_response_attributes_extractor.py +++ b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/_response_attributes_extractor.py @@ -11,20 +11,110 @@ from opentelemetry.util.types import AttributeValue from openinference.semconv.trace import ( + ImageAttributes, MessageAttributes, + MessageContentAttributes, SpanAttributes, ToolCallAttributes, ) if TYPE_CHECKING: from mistralai.models import ChatCompletionResponse + from mistralai.models.ocrresponse import OCRResponse -__all__ = ("_ResponseAttributesExtractor",) +__all__ = ( + "_ResponseAttributesExtractor", + "_OCRResponseAttributesExtractor", +) logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) +class _OCRResponseAttributesExtractor: + def get_attributes_from_response( + self, + response: Any, + request_parameters: Mapping[str, Any], + ) -> Iterator[Tuple[str, AttributeValue]]: + yield from _get_attributes_from_ocr_response(response) + + +def _get_attributes_from_ocr_usage( + usage_info: object, +) -> Iterator[Tuple[str, AttributeValue]]: + """Extract usage information from OCR usage info.""" + if (pages_processed := _get_attribute_or_value(usage_info, "pages_processed")) is not None: + yield "ocr.pages_processed", pages_processed + + if (doc_size_bytes := _get_attribute_or_value(usage_info, "doc_size_bytes")) is not None: + yield "ocr.document_size_bytes", doc_size_bytes + + +def _get_attributes_from_ocr_response( + response: "OCRResponse", +) -> Iterator[Tuple[str, AttributeValue]]: + # Extract model name + if model := getattr(response, "model", None): + yield SpanAttributes.LLM_MODEL_NAME, model + + # Extract usage information + if usage_info := getattr(response, "usage_info", None): + yield from _get_attributes_from_ocr_usage(usage_info) + + # Extract document annotation if present - this is the main output + if document_annotation := getattr(response, "document_annotation", None): + yield SpanAttributes.OUTPUT_VALUE, document_annotation + + # Structure OCR output as LLM output messages - one message per page for Phoenix display + if (pages := getattr(response, "pages", None)) and isinstance(pages, Iterable): + for page_index, page in enumerate(pages): + message_index = page_index # Each page gets its own message + content_index = 0 + + # Add role for this page's output message + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_ROLE}", + "assistant", + ) + + # Add markdown content for this page + if markdown := _get_attribute_or_value(page, "markdown"): + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_TYPE}", + "text", + ) + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_TEXT}", + markdown, + ) + content_index += 1 + + # Add extracted images from this page - keep it simple and robust + if (images := _get_attribute_or_value(page, "images")) and isinstance(images, Iterable): + for image in images: + if image_base64 := _get_attribute_or_value(image, "image_base64"): + # Add image content + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_TYPE}", + "image", + ) + yield ( + f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{message_index}.{MessageAttributes.MESSAGE_CONTENTS}.{content_index}.{MessageContentAttributes.MESSAGE_CONTENT_IMAGE}.{ImageAttributes.IMAGE_URL}", + image_base64, + ) + content_index += 1 + + # Keep basic structured data for retrieval context + for page_index, page in enumerate(pages): + if markdown := _get_attribute_or_value(page, "markdown"): + yield f"retrieval.documents.{page_index}.document.content", markdown + yield ( + f"retrieval.documents.{page_index}.document.metadata", + f'{{"type": "ocr_page", "page_index": {page_index}}}', + ) + + class _ResponseAttributesExtractor: def get_attributes_from_response( self, diff --git a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/version.py b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/version.py index 7b1e312007..3e8d9f9462 100644 --- a/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/version.py +++ b/python/instrumentation/openinference-instrumentation-mistralai/src/openinference/instrumentation/mistralai/version.py @@ -1 +1 @@ -__version__ = "1.3.3" +__version__ = "1.4.0" diff --git a/python/instrumentation/openinference-instrumentation-mistralai/test-requirements.txt b/python/instrumentation/openinference-instrumentation-mistralai/test-requirements.txt index 63c4b5189c..fdbaa995c4 100644 --- a/python/instrumentation/openinference-instrumentation-mistralai/test-requirements.txt +++ b/python/instrumentation/openinference-instrumentation-mistralai/test-requirements.txt @@ -1,4 +1,4 @@ -mistralai == 1.0.2 +mistralai == 1.8.1 opentelemetry-sdk opentelemetry-instrumentation-httpx respx diff --git a/python/instrumentation/openinference-instrumentation-mistralai/tests/openinference/instrumentation/mistralai/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-mistralai/tests/openinference/instrumentation/mistralai/test_instrumentor.py index 49a201f4b6..756704570f 100644 --- a/python/instrumentation/openinference-instrumentation-mistralai/tests/openinference/instrumentation/mistralai/test_instrumentor.py +++ b/python/instrumentation/openinference-instrumentation-mistralai/tests/openinference/instrumentation/mistralai/test_instrumentor.py @@ -32,6 +32,7 @@ from openinference.semconv.trace import ( EmbeddingAttributes, MessageAttributes, + MessageContentAttributes, OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, SpanAttributes, @@ -1386,3 +1387,298 @@ def instrument( USER_ID = SpanAttributes.USER_ID METADATA = SpanAttributes.METADATA TAG_TAGS = SpanAttributes.TAG_TAGS + + +@pytest.mark.parametrize("use_context_attributes", [False, True]) +def test_synchronous_ocr_emits_expected_span( + use_context_attributes: bool, + mistral_sync_client: Mistral, + in_memory_span_exporter: InMemorySpanExporter, + respx_mock: Any, + session_id: str, + user_id: str, + metadata: Dict[str, Any], + tags: List[str], + prompt_template: str, + prompt_template_version: str, + prompt_template_variables: Dict[str, Any], +) -> None: + """Test synchronous OCR functionality with proper span instrumentation.""" + # Mock OCR response + respx.post("https://api.mistral.ai/v1/ocr").mock( + return_value=Response( + 200, + json={ + "model": "mistral-ocr-2505-completion", + "usage_info": {"pages_processed": 2, "doc_size_bytes": 1024000}, + "document_annotation": "# Document Title\n\nThis is a sample document with structured content.\n\n## Section 1\n\nContent here with some **bold** text.\n\n| Column 1 | Column 2 |\n|----------|----------|\n| Data 1 | Data 2 |", # noqa: E501 + "pages": [ + { + "index": 0, + "markdown": "# Document Title\n\nThis is a sample document with structured content.\n\n## Section 1\n\nContent here with some **bold** text.\n\n![img-0.jpeg](img-0.jpeg)", # noqa: E501 + "images": [ + { + "id": "img-0.jpeg", + "top_left_x": 28, + "top_left_y": 107, + "bottom_right_x": 330, + "bottom_right_y": 278, + "image_base64": "...", # noqa: E501 + "image_annotation": None, + } + ], + "dimensions": {"dpi": 200, "height": 915, "width": 672}, # noqa: E501 + }, + { + "index": 1, + "markdown": "| Column 1 | Column 2 |\n|----------|----------|\n| Data 1 | Data 2 |", # noqa: E501 + "images": [], + "dimensions": {"dpi": 200, "height": 915, "width": 672}, # noqa: E501 + }, + ], + }, + ) + ) + + def mistral_ocr() -> Any: + return mistral_sync_client.ocr.process( + model="mistral-ocr-2505-completion", + document={"type": "document_url", "document_url": "https://example.com/sample.pdf"}, + include_image_base64=True, + ) + + if use_context_attributes: + with using_attributes( + session_id=session_id, + user_id=user_id, + metadata=metadata, + tags=tags, + prompt_template=prompt_template, + prompt_template_version=prompt_template_version, + prompt_template_variables=prompt_template_variables, + ): + response = mistral_ocr() + else: + response = mistral_ocr() + + # Verify response structure + assert hasattr(response, "model") + assert response.model == "mistral-ocr-2505-completion" + assert hasattr(response, "pages") + assert len(response.pages) == 2 + + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + assert span.status.is_ok + assert not span.status.description + assert len(span.events) == 0 + + attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) + + # Check span kind is LLM (OCR uses LLM span kind for now) + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == OpenInferenceSpanKindValues.LLM.value + + # Check input attributes + assert isinstance(attributes.pop(INPUT_VALUE), str) + assert ( + OpenInferenceMimeTypeValues(attributes.pop(INPUT_MIME_TYPE)) + == OpenInferenceMimeTypeValues.JSON + ) + + # Check model name + assert attributes.pop(LLM_MODEL_NAME) == "mistral-ocr-2505-completion" + + # Check OCR-specific attributes (using actual field names from OCRUsageInfo) + assert attributes.pop("ocr.pages_processed") == 2 + assert attributes.pop("ocr.document_size_bytes") == 1024000 + + # Check output structure - document_annotation is the main output + assert isinstance(attributes.pop(OUTPUT_VALUE), str) + + # Check that pages are structured as LLM output messages + assert attributes.pop(f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}") == "assistant" + assert ( + attributes.pop( + f"{LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENTS}.0.{MessageContentAttributes.MESSAGE_CONTENT_TYPE}" + ) + == "text" + ) + assert "Document Title" in str( + attributes.pop( + f"{LLM_OUTPUT_MESSAGES}.0.{MessageAttributes.MESSAGE_CONTENTS}.0.{MessageContentAttributes.MESSAGE_CONTENT_TEXT}" + ) + ) + + # Check second page message + assert attributes.pop(f"{LLM_OUTPUT_MESSAGES}.1.{MESSAGE_ROLE}") == "assistant" + assert ( + attributes.pop( + f"{LLM_OUTPUT_MESSAGES}.1.{MessageAttributes.MESSAGE_CONTENTS}.0.{MessageContentAttributes.MESSAGE_CONTENT_TYPE}" + ) + == "text" + ) + assert "Column 1" in str( + attributes.pop( + f"{LLM_OUTPUT_MESSAGES}.1.{MessageAttributes.MESSAGE_CONTENTS}.0.{MessageContentAttributes.MESSAGE_CONTENT_TEXT}" + ) + ) + + # Check retrieval document structure + assert "Document Title" in str(attributes.pop("retrieval.documents.0.document.content")) + assert '"type": "ocr_page"' in str(attributes.pop("retrieval.documents.0.document.metadata")) + assert "Column 1" in str(attributes.pop("retrieval.documents.1.document.content")) + assert '"page_index": 1' in str(attributes.pop("retrieval.documents.1.document.metadata")) + + # Check context attributes if used + if use_context_attributes: + _check_context_attributes( + attributes, + session_id, + user_id, + metadata, + tags, + prompt_template, + prompt_template_version, + prompt_template_variables, + ) + + +@pytest.mark.parametrize("use_context_attributes", [False, True]) +def test_synchronous_ocr_with_error_emits_span_with_exception( + use_context_attributes: bool, + mistral_sync_client: Mistral, + in_memory_span_exporter: InMemorySpanExporter, + respx_mock: Any, + session_id: str, + user_id: str, + metadata: Dict[str, Any], + tags: List[str], + prompt_template: str, + prompt_template_version: str, + prompt_template_variables: Dict[str, Any], +) -> None: + """Test that OCR errors are properly instrumented.""" + # Mock an error response + respx.post("https://api.mistral.ai/v1/ocr").mock( + return_value=Response( + 400, + json={ + "error": {"type": "invalid_request_error", "message": "Unsupported document format"} + }, + ) + ) + + def mistral_ocr_with_error() -> Any: + return mistral_sync_client.ocr.process( + model="mistral-ocr-2505-completion", + document={"type": "document_url", "document_url": "https://example.com/invalid.xyz"}, + ) + + # Test that the function raises an exception + with pytest.raises(Exception): + if use_context_attributes: + with using_attributes( + session_id=session_id, + user_id=user_id, + metadata=metadata, + tags=tags, + prompt_template=prompt_template, + prompt_template_version=prompt_template_version, + prompt_template_variables=prompt_template_variables, + ): + mistral_ocr_with_error() + else: + mistral_ocr_with_error() + + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + + # Check that span recorded the error + assert not span.status.is_ok + assert span.status.description is not None + assert len(span.events) == 1 + + # Check that exception event was recorded + exception_event = span.events[0] + assert exception_event.name == "exception" + + attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) + + # Check basic span attributes + assert attributes.pop(OPENINFERENCE_SPAN_KIND) == OpenInferenceSpanKindValues.LLM.value + assert isinstance(attributes.pop(INPUT_VALUE), str) + + # Check context attributes if used + if use_context_attributes: + _check_context_attributes( + attributes, + session_id, + user_id, + metadata, + tags, + prompt_template, + prompt_template_version, + prompt_template_variables, + ) + + +@pytest.mark.parametrize("document_type", ["document_url", "image_url"]) +def test_ocr_document_types( + document_type: str, + mistral_sync_client: Mistral, + in_memory_span_exporter: InMemorySpanExporter, + respx_mock: Any, +) -> None: + """Test OCR with different document types (PDF vs image).""" + # Mock OCR response + respx.post("https://api.mistral.ai/v1/ocr").mock( + return_value=Response( + 200, + json={ + "model": "mistral-ocr-2505-completion", + "usage_info": {"pages_processed": 1, "doc_size_bytes": 256000}, + "document_annotation": f"# {document_type.title()} Content\n\nProcessed from {document_type}", # noqa: E501 + "pages": [ + { + "index": 0, + "markdown": f"# {document_type.title()} Content\n\nProcessed from {document_type}", # noqa: E501 + "images": [], + "dimensions": {"dpi": 200, "height": 600, "width": 800}, + } + ], + }, + ) + ) + + if document_type == "document_url": + document_spec: Any = { + "type": "document_url", + "document_url": "https://example.com/test.pdf", + } + else: # image_url + document_spec = { + "type": "image_url", + "image_url": "", # noqa: E501 + } + + response = mistral_sync_client.ocr.process( + model="mistral-ocr-2505-completion", document=document_spec + ) + + # Verify response + assert hasattr(response, "model") + assert response.model == "mistral-ocr-2505-completion" + + spans = in_memory_span_exporter.get_finished_spans() + assert len(spans) == 1 + span = spans[0] + assert span.status.is_ok + + attributes = dict(cast(Mapping[str, AttributeValue], span.attributes)) + + # Check that input contains the document type + input_value = attributes.get(INPUT_VALUE) + assert input_value is not None + assert document_type in str(input_value)