diff --git a/backend/adapter_processor_v2/adapter_processor.py b/backend/adapter_processor_v2/adapter_processor.py index 31ad4a383a..1f6564acfa 100644 --- a/backend/adapter_processor_v2/adapter_processor.py +++ b/backend/adapter_processor_v2/adapter_processor.py @@ -6,8 +6,13 @@ from cryptography.fernet import Fernet from django.conf import settings from django.core.exceptions import ObjectDoesNotExist -from platform_settings_v2.platform_auth_service import PlatformAuthenticationService -from tenant_account_v2.organization_member_service import OrganizationMemberService +from platform_settings_v2.platform_auth_service import ( + PlatformAuthenticationService, +) +from rest_framework.exceptions import ValidationError +from tenant_account_v2.organization_member_service import ( + OrganizationMemberService, +) from adapter_processor_v2.constants import AdapterKeys, AllowedDomains from adapter_processor_v2.exceptions import ( @@ -27,7 +32,9 @@ logger = logging.getLogger(__name__) try: - from plugins.subscription.time_trials.subscription_adapter import add_unstract_key + from plugins.subscription.time_trials.subscription_adapter import ( + add_unstract_key, + ) except ImportError: add_unstract_key = None @@ -92,8 +99,8 @@ def get_adapter_data_with_key(adapter_id: str, key_value: str) -> Any: def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool: logger.info(f"Testing adapter: {adapter_id}") try: - adapter_class = Adapterkit().get_adapter_class_by_adapter_id(adapter_id) - + # Defensive copy; don't mutate caller dict + adapter_metadata = dict(adapter_metadata) if adapter_metadata.pop(AdapterKeys.ADAPTER_TYPE) == AdapterKeys.X2TEXT: if ( adapter_metadata.get(AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY) @@ -107,12 +114,22 @@ def test_adapter(adapter_id: str, adapter_metadata: dict[str, Any]) -> bool: platform_key.key ) - adapter_instance = adapter_class(adapter_metadata) + # Validate URLs for this adapter configuration + try: + adapter_instance = AdapterProcessor.validate_adapter_urls( + adapter_id, adapter_metadata + ) + except Exception as e: + # Format error message similar to test adapter API + adapter_name = adapter_metadata.get(AdapterKeys.ADAPTER_NAME, "adapter") + error_detail = f"Error testing '{adapter_name}'. {e!s}" + raise ValidationError(error_detail) from e test_result: bool = adapter_instance.test_connection() return test_result except SdkError as e: raise TestAdapterError( - e, adapter_name=adapter_metadata[AdapterKeys.ADAPTER_NAME] + e, + adapter_name=adapter_metadata.get(AdapterKeys.ADAPTER_NAME, "adapter"), ) @staticmethod @@ -130,6 +147,39 @@ def update_adapter_metadata(adapter_metadata_b: Any, **kwargs) -> Any: return adapter_metadata_b return adapter_metadata_b + @staticmethod + def validate_adapter_urls(adapter_id: str, adapter_metadata: dict) -> Adapter: + """Validate URLs for an adapter configuration without full connection test. + + This method only validates URLs for security (SSRF protection) without + attempting actual network connections. + + Args: + adapter_id: The adapter ID (e.g., "postgres|70ab6cc2...") + adapter_metadata: The adapter configuration metadata + + Returns: + Adapter: The adapter instance if validation passes + + Raises: + AdapterError: If URL validation fails due to security violations + """ + try: + # Get the adapter class + adapterkit = Adapterkit() + adapter_class = adapterkit.get_adapter_class_by_adapter_id(adapter_id) + + # Create a temporary instance just to validate URLs + # Pass validate_urls=True to trigger URL validation + return adapter_class(adapter_metadata, validate_urls=True) + + except Exception as e: + logger.error( + f"URL validation failed for adapter {adapter_id}: {str(e)}", + exc_info=True, + ) + raise + @staticmethod def __fetch_adapters_by_key_value(key: str, value: Any) -> Adapter: """Fetches a list of adapters that have an attribute matching key and diff --git a/backend/adapter_processor_v2/views.py b/backend/adapter_processor_v2/views.py index 36a2db0116..209dd0fb0f 100644 --- a/backend/adapter_processor_v2/views.py +++ b/backend/adapter_processor_v2/views.py @@ -1,7 +1,10 @@ +import json import logging import uuid from typing import Any +from cryptography.fernet import Fernet +from django.conf import settings from django.db import IntegrityError from django.db.models import ProtectedError, QuerySet from django.http import HttpRequest @@ -14,12 +17,15 @@ ) from rest_framework import status from rest_framework.decorators import action +from rest_framework.exceptions import ValidationError from rest_framework.request import Request from rest_framework.response import Response from rest_framework.serializers import ModelSerializer from rest_framework.versioning import URLPathVersioning from rest_framework.viewsets import GenericViewSet, ModelViewSet -from tenant_account_v2.organization_member_service import OrganizationMemberService +from tenant_account_v2.organization_member_service import ( + OrganizationMemberService, +) from utils.filtering import FilterHelper from adapter_processor_v2.adapter_processor import AdapterProcessor @@ -166,75 +172,140 @@ def get_serializer_class( return AdapterListSerializer return AdapterInstanceSerializer - def create(self, request: Any) -> Response: - serializer = self.get_serializer(data=request.data) + def _decrypt_and_validate_metadata(self, adapter_metadata_b: bytes) -> dict[str, Any]: + """Decrypt adapter metadata and validate its format.""" + if not adapter_metadata_b: + raise ValidationError("Missing adapter metadata for validation.") - use_platform_unstract_key = False - adapter_metadata = request.data.get(AdapterKeys.ADAPTER_METADATA) - if adapter_metadata and adapter_metadata.get( - AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY, False - ): - use_platform_unstract_key = True - - serializer.is_valid(raise_exception=True) try: - adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) + fernet = Fernet(settings.ENCRYPTION_KEY.encode("utf-8")) + decrypted_json = fernet.decrypt(adapter_metadata_b) + decrypted_metadata = json.loads(decrypted_json.decode("utf-8")) - if adapter_type == AdapterKeys.X2TEXT and use_platform_unstract_key: - adapter_metadata_b = serializer.validated_data.get( - AdapterKeys.ADAPTER_METADATA_B - ) - adapter_metadata_b = AdapterProcessor.update_adapter_metadata( - adapter_metadata_b - ) - # Update the validated data with the new adapter_metadata - serializer.validated_data[AdapterKeys.ADAPTER_METADATA_B] = ( - adapter_metadata_b + if not isinstance(decrypted_metadata, dict): + raise ValidationError( + "Invalid adapter metadata format: expected JSON object." ) + return decrypted_metadata + except Exception as e: + raise ValidationError("Invalid adapter metadata.") from e + + def _validate_adapter_urls( + self, adapter_id: str, decrypted_metadata: dict[str, Any] + ) -> None: + """Validate URLs for adapter configuration.""" + try: + AdapterProcessor.validate_adapter_urls(adapter_id, decrypted_metadata) + except Exception as e: + adapter_name = decrypted_metadata.get(AdapterKeys.ADAPTER_NAME, "adapter") + error_detail = f"Error testing '{adapter_name}'. {e!s}" + raise ValidationError(error_detail) from e + + def _check_platform_key_usage(self, request_data: dict[str, Any]) -> bool: + """Check if platform unstract key should be used.""" + adapter_metadata = request_data.get(AdapterKeys.ADAPTER_METADATA) + return bool( + adapter_metadata + and adapter_metadata.get(AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY, False) + ) - instance = serializer.save() - organization_member = OrganizationMemberService.get_user_by_id( - request.user.id + def _update_metadata_for_platform_key( + self, + serializer_validated_data: dict[str, Any], + adapter_type: str, + is_paid_subscription: bool = False, + ) -> None: + """Update adapter metadata when using platform key.""" + if adapter_type == AdapterKeys.X2TEXT: + adapter_metadata_b = serializer_validated_data.get( + AdapterKeys.ADAPTER_METADATA_B ) - - # Check to see if there is a default configured - # for this adapter_type and for the current user - ( - user_default_adapter, - created, - ) = UserDefaultAdapter.objects.get_or_create( - organization_member=organization_member + updated_metadata_b = AdapterProcessor.update_adapter_metadata( + adapter_metadata_b, is_paid_subscription=is_paid_subscription ) + serializer_validated_data[AdapterKeys.ADAPTER_METADATA_B] = updated_metadata_b + + def _set_default_adapter_if_needed( + self, adapter_instance: AdapterInstance, adapter_type: str, user_id: int + ) -> None: + """Set adapter as default if no default exists for this type.""" + organization_member = OrganizationMemberService.get_user_by_id(user_id) + user_default_adapter, _ = UserDefaultAdapter.objects.get_or_create( + organization_member=organization_member + ) - if (adapter_type == AdapterKeys.LLM) and ( - not user_default_adapter.default_llm_adapter - ): - user_default_adapter.default_llm_adapter = instance + # Map adapter types to their default fields + adapter_type_mapping = { + AdapterKeys.LLM: "default_llm_adapter", + AdapterKeys.EMBEDDING: "default_embedding_adapter", + AdapterKeys.VECTOR_DB: "default_vector_db_adapter", + AdapterKeys.X2TEXT: "default_x2text_adapter", + } + + if adapter_type in adapter_type_mapping: + field_name = adapter_type_mapping[adapter_type] + if not getattr(user_default_adapter, field_name): + setattr(user_default_adapter, field_name, adapter_instance) + user_default_adapter.organization_member = organization_member + user_default_adapter.save() + + def _validate_update_metadata( + self, + serializer_validated_data: dict[str, Any], + current_adapter: AdapterInstance, + ) -> tuple[str | None, dict[str, Any] | None]: + """Validate metadata for update operations.""" + if AdapterKeys.ADAPTER_METADATA_B not in serializer_validated_data: + return None, None + + adapter_id = ( + serializer_validated_data.get(AdapterKeys.ADAPTER_ID) + or current_adapter.adapter_id + ) + adapter_metadata_b = serializer_validated_data.get(AdapterKeys.ADAPTER_METADATA_B) - elif (adapter_type == AdapterKeys.EMBEDDING) and ( - not user_default_adapter.default_embedding_adapter - ): - user_default_adapter.default_embedding_adapter = instance - elif (adapter_type == AdapterKeys.VECTOR_DB) and ( - not user_default_adapter.default_vector_db_adapter - ): - user_default_adapter.default_vector_db_adapter = instance - elif (adapter_type == AdapterKeys.X2TEXT) and ( - not user_default_adapter.default_x2text_adapter - ): - user_default_adapter.default_x2text_adapter = instance + if not adapter_id or not adapter_metadata_b: + raise ValidationError("Missing adapter metadata for validation.") - organization_member = OrganizationMemberService.get_user_by_id( - request.user.id - ) - user_default_adapter.organization_member = organization_member + decrypted_metadata = self._decrypt_and_validate_metadata(adapter_metadata_b) + self._validate_adapter_urls(adapter_id, decrypted_metadata) - user_default_adapter.save() + return adapter_id, decrypted_metadata + + def create(self, request: Any) -> Response: + serializer = self.get_serializer(data=request.data) + use_platform_unstract_key = self._check_platform_key_usage(request.data) + + serializer.is_valid(raise_exception=True) + + # Extract and validate metadata + adapter_id = serializer.validated_data.get(AdapterKeys.ADAPTER_ID) + adapter_metadata_b = serializer.validated_data.get(AdapterKeys.ADAPTER_METADATA_B) + decrypted_metadata = self._decrypt_and_validate_metadata(adapter_metadata_b) + + # Validate URLs for security (pre-mutation) + self._validate_adapter_urls(adapter_id, decrypted_metadata) + + try: + adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) + + # Update metadata if using platform key + if use_platform_unstract_key: + self._update_metadata_for_platform_key( + serializer.validated_data, adapter_type + ) + + # Save the adapter instance + instance = serializer.save() + + # Set as default adapter if needed + self._set_default_adapter_if_needed(instance, adapter_type, request.user.id) except IntegrityError: raise DuplicateAdapterNameError( name=serializer.validated_data.get(AdapterKeys.ADAPTER_NAME) ) + headers = self.get_success_headers(serializer.data) return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers) @@ -346,42 +417,30 @@ def list_of_shared_users(self, request: HttpRequest, pk: Any = None) -> Response def update( self, request: Request, *args: tuple[Any], **kwargs: dict[str, Any] ) -> Response: - # Check if adapter metadata is being updated and contains the platform key flag - use_platform_unstract_key = False - adapter_metadata = request.data.get(AdapterKeys.ADAPTER_METADATA) + use_platform_unstract_key = self._check_platform_key_usage(request.data) + adapter = self.get_object() - if adapter_metadata and adapter_metadata.get( - AdapterKeys.PLATFORM_PROVIDED_UNSTRACT_KEY, False - ): - use_platform_unstract_key = True - logger.error(f"Platform key flag detected: {use_platform_unstract_key}") + # Get serializer and validate data + serializer = self.get_serializer(adapter, data=request.data, partial=True) + serializer.is_valid(raise_exception=True) - # Get the adapter instance for update - adapter = self.get_object() + # Validate metadata if being updated + _, _ = self._validate_update_metadata(serializer.validated_data, adapter) + # Handle platform key updates if use_platform_unstract_key: logger.error("Processing adapter with platform key") - serializer = self.get_serializer(adapter, data=request.data, partial=True) - serializer.is_valid(raise_exception=True) - - # Get adapter_type from validated data (consistent with create method) adapter_type = serializer.validated_data.get(AdapterKeys.ADAPTER_TYPE) logger.error(f"Adapter type from validated data: {adapter_type}") - if adapter_type == AdapterKeys.X2TEXT: - logger.error("Processing X2TEXT adapter with platform key") - adapter_metadata_b = serializer.validated_data.get( - AdapterKeys.ADAPTER_METADATA_B - ) - adapter_metadata_b = AdapterProcessor.update_adapter_metadata( - adapter_metadata_b, is_paid_subscription=True - ) - # Update the validated data with the new adapter_metadata - serializer.validated_data[AdapterKeys.ADAPTER_METADATA_B] = ( - adapter_metadata_b - ) + # Update metadata for platform key usage + self._update_metadata_for_platform_key( + serializer.validated_data, + adapter_type, + is_paid_subscription=True, + ) - # Save the instance with updated metadata + # Save and return updated instance serializer.save() return Response(serializer.data) diff --git a/backend/notification_v2/provider/webhook/webhook.py b/backend/notification_v2/provider/webhook/webhook.py index 37fb7431bf..7934a7abcc 100644 --- a/backend/notification_v2/provider/webhook/webhook.py +++ b/backend/notification_v2/provider/webhook/webhook.py @@ -7,6 +7,7 @@ from backend.celery_service import app as celery_app from notification_v2.enums import AuthorizationType from notification_v2.provider.notification_provider import NotificationProvider +from unstract.sdk.adapters.url_validator import URLValidator logger = logging.getLogger(__name__) @@ -51,6 +52,12 @@ def validate(self): """ if not self.notification.url: raise ValueError("Webhook URL is required.") + + # Validate webhook URL for security + is_valid, error_message = URLValidator.validate_url(self.notification.url) + logger.info("Validating webhook URL.") + if not is_valid: + raise ValueError(f"Webhook URL validation failed: {error_message}") if not self.payload: raise ValueError("Payload is required.") return super().validate() diff --git a/backend/sample.env b/backend/sample.env index dcd79b5a16..0fb453d633 100644 --- a/backend/sample.env +++ b/backend/sample.env @@ -200,3 +200,8 @@ RUNNER_POLLING_INTERVAL_SECONDS=2 # Default: 1800 seconds (30 minutes) # Examples: 900 (15 min), 1800 (30 min), 3600 (60 min) MIN_SCHEDULE_INTERVAL_SECONDS=1800 + +# Whitelisted adapter URLs to allow user to connect to locally hosted adapters. +# Whitelisting 10.68.0.10 to allow frictionless adapter connection to +# managed Postgres for VectorDB +WHITELISTED_ENDPOINTS="10.68.0.10" diff --git a/prompt-service/sample.env b/prompt-service/sample.env index e26e6cbcd2..f70d895086 100644 --- a/prompt-service/sample.env +++ b/prompt-service/sample.env @@ -64,3 +64,7 @@ ADAPTER_LLMW_STATUS_RETRIES=5 # Rentroll Service RENTROLL_SERVICE_HOST=http://unstract-rentroll-service RENTROLL_SERVICE_PORT=5003 + +# Whitelisted adapter URLs to allow user to connect to locally hosted adapters. +# Whitelisting 10.68.0.10 to allow URLs in variable replacement and postprocessor hooks +WHITELISTED_ENDPOINTS="10.68.0.10" diff --git a/prompt-service/src/unstract/prompt_service/controllers/answer_prompt.py b/prompt-service/src/unstract/prompt_service/controllers/answer_prompt.py index 35ad2d5c06..02cb631c5f 100644 --- a/prompt-service/src/unstract/prompt_service/controllers/answer_prompt.py +++ b/prompt-service/src/unstract/prompt_service/controllers/answer_prompt.py @@ -11,7 +11,9 @@ from unstract.prompt_service.exceptions import BadRequest from unstract.prompt_service.helpers.auth import AuthHelper from unstract.prompt_service.helpers.plugin import PluginManager -from unstract.prompt_service.helpers.prompt_ide_base_tool import PromptServiceBaseTool +from unstract.prompt_service.helpers.prompt_ide_base_tool import ( + PromptServiceBaseTool, +) from unstract.prompt_service.helpers.usage import UsageHelper from unstract.prompt_service.services.answer_prompt import AnswerPromptService from unstract.prompt_service.services.rentrolls_extractor.interface import ( @@ -86,15 +88,33 @@ def prompt_processor() -> Any: app.logger.info(f"[{tool_id}] chunk size: {chunk_size}") util = PromptServiceBaseTool(platform_key=platform_key) index = Index(tool=util, run_id=run_id, capture_metrics=True) - if VariableReplacementService.is_variables_present(prompt_text=prompt_text): - prompt_text = VariableReplacementService.replace_variables_in_prompt( - prompt=output, - structured_output=structured_output, - log_events_id=log_events_id, - tool_id=tool_id, - prompt_name=prompt_name, - doc_name=doc_name, + try: + if VariableReplacementService.is_variables_present(prompt_text=prompt_text): + prompt_text = VariableReplacementService.replace_variables_in_prompt( + prompt=output, + structured_output=structured_output, + log_events_id=log_events_id, + tool_id=tool_id, + prompt_name=prompt_name, + doc_name=doc_name, + ) + except BadRequest as e: + app.logger.error( + f"[{tool_id}] Error during variable replacement: {e}", + exc_info=True, + ) + publish_log( + log_events_id, + { + "tool_id": tool_id, + "prompt_key": prompt_name, + "doc_name": doc_name, + }, + LogLevel.ERROR, + RunLevel.RUN, + f"Error during variable replacement: {e}", ) + raise app.logger.info(f"[{tool_id}] Executing prompt: '{prompt_name}'") publish_log( @@ -243,7 +263,9 @@ def prompt_processor() -> Any: # Track token usage by sending to the audit service try: - from unstract.sdk.utils.token_counter import TokenCounter + from unstract.sdk.utils.token_counter import ( + TokenCounter, + ) # Get metrics from the extraction result metrics = extraction_result.get("metrics", {}) diff --git a/prompt-service/src/unstract/prompt_service/helpers/variable_replacement.py b/prompt-service/src/unstract/prompt_service/helpers/variable_replacement.py index 8b10092522..87c7d7a8fe 100644 --- a/prompt-service/src/unstract/prompt_service/helpers/variable_replacement.py +++ b/prompt-service/src/unstract/prompt_service/helpers/variable_replacement.py @@ -6,7 +6,9 @@ from flask import current_app as app from unstract.prompt_service.constants import VariableConstants, VariableType +from unstract.prompt_service.exceptions import BadRequest from unstract.prompt_service.utils.request import HTTPMethod, make_http_request +from unstract.sdk.adapters.url_validator import URLValidator class VariableReplacementHelper: @@ -22,7 +24,9 @@ def replace_static_variable( static_variable_marker_string = "".join(["{{", variable, "}}"]) replaced_prompt: str = VariableReplacementHelper.replace_generic_string_value( - prompt=prompt, variable=static_variable_marker_string, value=output_value + prompt=prompt, + variable=static_variable_marker_string, + value=output_value, ) return replaced_prompt @@ -71,7 +75,12 @@ def identify_variable_type(variable: str) -> VariableType: def replace_dynamic_variable( prompt: str, variable: str, structured_output: dict[str, Any] ) -> str: - url = re.search(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable).group(0) + url_match = re.search(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable) + if not url_match: + app.logger.error(f"No URL found in dynamic variable: {variable}") + return prompt + + url = url_match.group(0) data = re.findall(VariableConstants.DYNAMIC_VARIABLE_DATA_REGEX, variable)[0] output_value = VariableReplacementHelper.check_static_variable_run_status( structure_output=structured_output, variable=data @@ -108,6 +117,15 @@ def fetch_dynamic_variable_value(url: str, data: str) -> Any: # Future versions may include support for # authentication and other input formats. + # Validate URL before making the request + is_valid, error_message = URLValidator.validate_url(url) + if not is_valid: + # app.logger.error( + # f"Invalid or unsafe URL detected: {url} - {error_message}", + # exc_info=True, + # ) + raise BadRequest(f"Invalid or unsafe URL: {url} - {error_message}") + verb: HTTPMethod = HTTPMethod.POST headers = {"Content-Type": "text/plain"} response: Any = make_http_request(verb=verb, url=url, data=data, headers=headers) diff --git a/prompt-service/src/unstract/prompt_service/services/answer_prompt.py b/prompt-service/src/unstract/prompt_service/services/answer_prompt.py index bf92262daa..35e807e5bc 100644 --- a/prompt-service/src/unstract/prompt_service/services/answer_prompt.py +++ b/prompt-service/src/unstract/prompt_service/services/answer_prompt.py @@ -1,8 +1,5 @@ -import ipaddress -import socket from logging import Logger from typing import Any -from urllib.parse import urlparse from flask import current_app as app @@ -17,6 +14,7 @@ repair_json_with_best_structure, ) from unstract.prompt_service.utils.log import publish_log +from unstract.sdk.adapters.url_validator import URLValidator from unstract.sdk.constants import LogLevel from unstract.sdk.exceptions import RateLimitError as SdkRateLimitError from unstract.sdk.exceptions import SdkError @@ -26,58 +24,6 @@ from unstract.sdk.llm import LLM -def _is_safe_public_url(url: str) -> bool: - """Validate webhook URL for SSRF protection. - - Only allows HTTPS and blocks private/loopback/internal addresses. - Resolves all DNS records (A/AAAA) to prevent DNS rebinding attacks. - """ - try: - p = urlparse(url) - if p.scheme not in ("https",): # Only allow HTTPS for security - return False - host = p.hostname or "" - # Block obvious local hosts - if host in ("localhost",): - return False - - addrs: set[str] = set() - # If literal IP, validate directly; else resolve all records (A/AAAA) - try: - ipaddress.ip_address(host) - addrs.add(host) - except ValueError: - try: - for family, _type, _proto, _canonname, sockaddr in socket.getaddrinfo( - host, None, type=socket.SOCK_STREAM - ): - addr = sockaddr[0] - addrs.add(addr) - except Exception: - return False - - if not addrs: - return False - - # Validate all resolved addresses - for addr in addrs: - try: - ip = ipaddress.ip_address(addr) - except ValueError: - return False - if ( - ip.is_private - or ip.is_loopback - or ip.is_link_local - or ip.is_reserved - or ip.is_multicast - ): - return False - return True - except Exception: - return False - - class AnswerPromptService: @staticmethod def extract_variable( @@ -342,23 +288,25 @@ def handle_json( app.logger.warning( "Postprocessing webhook enabled but URL missing; skipping." ) - elif not _is_safe_public_url(webhook_url): - app.logger.warning( - "Postprocessing webhook URL is not allowed; skipping." - ) else: - try: - processed_data, updated_highlight_data = postprocess_data( - parsed_data, - webhook_enabled=True, - webhook_url=webhook_url, - highlight_data=highlight_data, - timeout=60, - ) - except Exception as e: + is_valid, error_message = URLValidator.validate_url(webhook_url) + if not is_valid: app.logger.warning( - f"Postprocessing webhook failed: {e}. Using unprocessed data." + f"Postprocessing webhook URL validation failed: {error_message}; skipping." ) + else: + try: + processed_data, updated_highlight_data = postprocess_data( + parsed_data, + webhook_enabled=True, + webhook_url=webhook_url, + highlight_data=highlight_data, + timeout=60, + ) + except Exception as e: + app.logger.warning( + f"Postprocessing webhook failed: {e}. Using unprocessed data." + ) structured_output[prompt_key] = processed_data diff --git a/prompt-service/src/unstract/prompt_service/services/variable_replacement.py b/prompt-service/src/unstract/prompt_service/services/variable_replacement.py index 6d6184a7ec..c76b0ae434 100644 --- a/prompt-service/src/unstract/prompt_service/services/variable_replacement.py +++ b/prompt-service/src/unstract/prompt_service/services/variable_replacement.py @@ -4,6 +4,7 @@ from unstract.prompt_service.constants import PromptServiceConstants as PSKeys from unstract.prompt_service.constants import RunLevel, VariableType +from unstract.prompt_service.exceptions import BadRequest from unstract.prompt_service.helpers.variable_replacement import ( VariableReplacementHelper, ) @@ -74,6 +75,8 @@ def replace_variables_in_prompt( prompt_text = VariableReplacementService._execute_variable_replacement( prompt_text=prompt_text, variable_map=structured_output ) + except BadRequest: + raise finally: app.logger.info( f"[{tool_id}] Prompt after variable replacement: {prompt_text}"