Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion src/unstract/sdk/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from abc import ABC, abstractmethod

from unstract.sdk.adapters.enums import AdapterTypes
from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.url_validator import URLValidator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -32,7 +34,7 @@ def get_icon() -> str:

@classmethod
def get_json_schema(cls) -> str:
schema_path = getattr(cls, 'SCHEMA_PATH', None)
schema_path = getattr(cls, "SCHEMA_PATH", None)
if schema_path is None:
raise ValueError(f"SCHEMA_PATH not defined for {cls.__name__}")
with open(schema_path) as f:
Expand All @@ -43,6 +45,35 @@ def get_json_schema(cls) -> str:
def get_adapter_type() -> AdapterTypes:
return ""

def get_configured_urls(self) -> list[str]:
"""Return all URLs that this adapter will connect to.

This method should return a list of all URLs that the adapter
uses for external connections. These URLs will be validated
for security before allowing connection attempts.

Returns:
list[str]: List of URLs that will be accessed by this adapter
"""
return []

def _validate_urls(self) -> None:
"""Validate all configured URLs against security rules."""
urls = self.get_configured_urls()

for url in urls:
if not url: # Skip empty/None URLs
continue

is_valid, error_message = URLValidator.validate_url(url)
if not is_valid:
# Use class name as fallback when self.name isn't set yet
adapter_name = getattr(self, "name", self.__class__.__name__)
logger.error(
f"URL validation failed for adapter '{adapter_name}': {error_message}"
)
raise AdapterError(f"URL validation failed: {error_message}")

@abstractmethod
def test_connection(self) -> bool:
"""Override to test connection for a adapter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def __init__(self, settings: dict[str, Any]):
super().__init__("AzureOpenAIEmbedding")
self.config = settings

# Validate URLs BEFORE any network operations
self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -48,6 +51,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/AzureopenAI.png"

def get_configured_urls(self) -> list[str]:
"""Return all URLs this adapter will connect to."""
endpoint = self.config.get("azure_endpoint")
return [endpoint] if endpoint else []

def get_embedding_instance(self) -> BaseEmbedding:
try:
embedding_batch_size = EmbeddingHelper.get_embedding_batch_size(
Expand Down
8 changes: 8 additions & 0 deletions src/unstract/sdk/adapters/embedding/ollama/src/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def __init__(self, settings: dict[str, Any]):
super().__init__("Ollama")
self.config = settings

# Validate URLs BEFORE any network operations
self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -41,6 +44,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/ollama.png"

def get_configured_urls(self) -> list[str]:
"""Return all URLs this adapter will connect to."""
base_url = self.config.get("base_url")
return [base_url] if base_url else []

def get_embedding_instance(self) -> BaseEmbedding:
try:
embedding_batch_size = EmbeddingHelper.get_embedding_batch_size(
Expand Down
8 changes: 8 additions & 0 deletions src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def __init__(self, settings: dict[str, Any]):
super().__init__("OpenAI")
self.config = settings

# Validate URLs BEFORE any network operations
self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -47,6 +50,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/OpenAI.png"

def get_configured_urls(self) -> list[str]:
"""Return all URLs this adapter will connect to."""
api_base = self.config.get("api_base")
return [api_base] if api_base else []

def get_embedding_instance(self) -> BaseEmbedding:
try:
timeout = int(self.config.get(Constants.TIMEOUT, Constants.DEFAULT_TIMEOUT))
Expand Down
9 changes: 8 additions & 1 deletion src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from llama_index.core.constants import DEFAULT_NUM_OUTPUTS
from llama_index.core.llms import LLM
from llama_index.llms.anyscale import Anyscale

from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.llm.constants import LLMKeys
from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
Expand All @@ -24,6 +23,9 @@ def __init__(self, settings: dict[str, Any]):
super().__init__("AnyScale")
self.config = settings

# Validate URLs BEFORE any network operations
self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -46,6 +48,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/anyscale.png"

def get_configured_urls(self) -> list[str]:
"""Return all URLs this adapter will connect to."""
api_base = self.config.get(Constants.API_BASE)
return [api_base] if api_base else []

def get_llm_instance(self) -> LLM:
try:
max_tokens = int(self.config.get(Constants.MAX_TOKENS, DEFAULT_NUM_OUTPUTS))
Expand Down
13 changes: 9 additions & 4 deletions src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from llama_index.core.llms import LLM
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.openai.utils import O1_MODELS

from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.llm.constants import LLMKeys
from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
Expand All @@ -30,6 +29,9 @@ def __init__(self, settings: dict[str, Any]):
super().__init__("AzureOpenAI")
self.config = settings

# Validate URLs BEFORE any network operations
self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -52,6 +54,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/AzureopenAI.png"

def get_configured_urls(self) -> list[str]:
"""Return all URLs this adapter will connect to."""
endpoint = self.config.get("azure_endpoint")
return [endpoint] if endpoint else []

def get_llm_instance(self) -> LLM:
max_retries = int(
self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES)
Expand All @@ -74,9 +81,7 @@ def get_llm_instance(self) -> LLM:
}

if enable_reasoning:
llm_kwargs["reasoning_effort"] = self.config.get(
Constants.REASONING_EFFORT
)
llm_kwargs["reasoning_effort"] = self.config.get(Constants.REASONING_EFFORT)

if model not in O1_MODELS:
llm_kwargs["max_completion_tokens"] = max_tokens
Expand Down
14 changes: 12 additions & 2 deletions src/unstract/sdk/adapters/llm/ollama/src/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from httpx import ConnectError, HTTPStatusError
from llama_index.core.llms import LLM
from llama_index.llms.ollama import Ollama

from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.llm.constants import LLMKeys
from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
Expand All @@ -29,6 +28,9 @@ def __init__(self, settings: dict[str, Any]):
super().__init__("Ollama")
self.config = settings

# Validate URLs BEFORE any network operations
self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -51,6 +53,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/ollama.png"

def get_configured_urls(self) -> list[str]:
"""Return all URLs this adapter will connect to."""
base_url = self.config.get(Constants.BASE_URL)
return [base_url] if base_url else []

def get_llm_instance(self) -> LLM:
try:
llm: LLM = Ollama(
Expand All @@ -60,7 +67,9 @@ def get_llm_instance(self) -> LLM:
self.config.get(Constants.TIMEOUT, LLMKeys.DEFAULT_TIMEOUT)
),
json_mode=False,
context_window=int(self.config.get(Constants.CONTEXT_WINDOW, 3900)),
context_window=int(
self.config.get(Constants.CONTEXT_WINDOW, 3900)
),
temperature=0.01,
)
return llm
Expand All @@ -77,6 +86,7 @@ def get_llm_instance(self) -> LLM:
raise AdapterError(str(exc))

def test_connection(self) -> bool:

try:
llm = self.get_llm_instance()
if not llm:
Expand Down
8 changes: 8 additions & 0 deletions src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def __init__(self, settings: dict[str, Any]):
super().__init__("OpenAI")
self.config = settings

# Validate URLs BEFORE any network operations
self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -51,6 +54,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/OpenAI.png"

def get_configured_urls(self) -> list[str]:
"""Return all URLs this adapter will connect to."""
api_base = self.config.get("api_base")
return [api_base] if api_base else []

def get_llm_instance(self) -> LLM:
try:
max_tokens = self.config.get(Constants.MAX_TOKENS)
Expand Down
Loading