Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "v0.77.1"
__version__ = "v0.78.0"


def get_sdk_version() -> str:
Expand Down
34 changes: 33 additions & 1 deletion src/unstract/sdk/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from abc import ABC, abstractmethod

from unstract.sdk.adapters.enums import AdapterTypes
from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.url_validator import URLValidator

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -32,7 +34,7 @@ def get_icon() -> str:

@classmethod
def get_json_schema(cls) -> str:
schema_path = getattr(cls, 'SCHEMA_PATH', None)
schema_path = getattr(cls, "SCHEMA_PATH", None)
if schema_path is None:
raise ValueError(f"SCHEMA_PATH not defined for {cls.__name__}")
with open(schema_path) as f:
Expand All @@ -43,6 +45,36 @@ def get_json_schema(cls) -> str:
def get_adapter_type() -> AdapterTypes:
return ""

@abstractmethod
def get_configured_urls(self) -> list[str]:
    """List every URL this adapter connects to.

    Implementations report each external endpoint the adapter uses so
    that the URLs can be security-checked before a connection is
    attempted.

    Returns:
        list[str]: URLs accessed by this adapter. This base
        implementation yields an empty list.
    """
    return []

def _validate_urls(self) -> None:
    """Validate all configured URLs against security rules.

    Every non-empty URL returned by ``get_configured_urls()`` is passed
    to ``URLValidator.validate_url``. The first URL that is rejected is
    logged and aborts the check.

    Raises:
        AdapterError: If any configured URL fails validation.
    """
    for url in self.get_configured_urls():
        if not url:  # Skip empty/None URLs
            continue

        is_valid, error_message = URLValidator.validate_url(url)
        if is_valid:
            continue

        # Use class name as fallback when self.name isn't set yet
        adapter_name = getattr(self, "name", self.__class__.__name__)
        # Lazy %-style arguments let the logging framework do the
        # formatting instead of building the string eagerly.
        logger.error(
            "URL validation failed for adapter '%s': %s",
            adapter_name,
            error_message,
        )
        raise AdapterError(f"URL validation failed: {error_message}")

@abstractmethod
def test_connection(self) -> bool:
"""Override to test connection for a adapter.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ class Constants:


class AzureOpenAI(EmbeddingAdapter):
def __init__(self, settings: dict[str, Any]):
def __init__(self, settings: dict[str, Any], validate_urls: bool = False):
    """Set up the adapter and optionally vet its configured URLs.

    Args:
        settings: Adapter configuration; kept as ``self.config``.
        validate_urls: When true, ``self._validate_urls()`` runs right
            after the configuration has been stored.
    """
    super().__init__("AzureOpenAIEmbedding")
    self.config = settings

    # URL checks have to happen before any network operation.
    if validate_urls:
        self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -48,6 +52,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/AzureopenAI.png"

def get_configured_urls(self) -> list[str]:
    """List the endpoints this adapter may contact.

    Returns:
        list[str]: A one-element list holding the ``azure_endpoint``
        setting, or an empty list when that setting is missing or falsy.
    """
    configured = self.config.get("azure_endpoint")
    if not configured:
        return []
    return [configured]

def get_embedding_instance(self) -> BaseEmbedding:
try:
embedding_batch_size = EmbeddingHelper.get_embedding_batch_size(
Expand Down
11 changes: 10 additions & 1 deletion src/unstract/sdk/adapters/embedding/ollama/src/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@ class Constants:


class Ollama(EmbeddingAdapter):
def __init__(self, settings: dict[str, Any]):
def __init__(self, settings: dict[str, Any], validate_urls: bool = False):
    """Set up the adapter and optionally vet its configured URLs.

    Args:
        settings: Adapter configuration; kept as ``self.config``.
        validate_urls: When true, ``self._validate_urls()`` runs right
            after the configuration has been stored.
    """
    super().__init__("Ollama")
    self.config = settings

    # URL checks have to happen before any network operation.
    if validate_urls:
        self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -41,6 +45,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/ollama.png"

def get_configured_urls(self) -> list[str]:
    """List the endpoints this adapter may contact.

    Returns:
        list[str]: A one-element list holding the ``base_url`` setting,
        or an empty list when that setting is missing or falsy.
    """
    configured = self.config.get("base_url")
    if not configured:
        return []
    return [configured]

def get_embedding_instance(self) -> BaseEmbedding:
try:
embedding_batch_size = EmbeddingHelper.get_embedding_batch_size(
Expand Down
11 changes: 10 additions & 1 deletion src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ class Constants:


class OpenAI(EmbeddingAdapter):
def __init__(self, settings: dict[str, Any]):
def __init__(self, settings: dict[str, Any], validate_urls: bool = False):
    """Set up the adapter and optionally vet its configured URLs.

    Args:
        settings: Adapter configuration; kept as ``self.config``.
        validate_urls: When true, ``self._validate_urls()`` runs right
            after the configuration has been stored.
    """
    super().__init__("OpenAI")
    self.config = settings

    # URL checks have to happen before any network operation.
    if validate_urls:
        self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -47,6 +51,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/OpenAI.png"

def get_configured_urls(self) -> list[str]:
    """List the endpoints this adapter may contact.

    Returns:
        list[str]: A one-element list holding the ``api_base`` setting,
        or an empty list when that setting is missing or falsy.
    """
    configured = self.config.get("api_base")
    if not configured:
        return []
    return [configured]

def get_embedding_instance(self) -> BaseEmbedding:
try:
timeout = int(self.config.get(Constants.TIMEOUT, Constants.DEFAULT_TIMEOUT))
Expand Down
12 changes: 10 additions & 2 deletions src/unstract/sdk/adapters/llm/any_scale/src/anyscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from llama_index.core.constants import DEFAULT_NUM_OUTPUTS
from llama_index.core.llms import LLM
from llama_index.llms.anyscale import Anyscale

from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.llm.constants import LLMKeys
from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
Expand All @@ -20,10 +19,14 @@ class Constants:


class AnyScaleLLM(LLMAdapter):
def __init__(self, settings: dict[str, Any]):
def __init__(self, settings: dict[str, Any], validate_urls: bool = False):
    """Set up the adapter and optionally vet its configured URLs.

    Args:
        settings: Adapter configuration; kept as ``self.config``.
        validate_urls: When true, ``self._validate_urls()`` runs right
            after the configuration has been stored.
    """
    super().__init__("AnyScale")
    self.config = settings

    # URL checks have to happen before any network operation.
    if validate_urls:
        self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -46,6 +49,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/anyscale.png"

def get_configured_urls(self) -> list[str]:
    """List the endpoints this adapter may contact.

    Returns:
        list[str]: A one-element list holding the setting stored under
        ``Constants.API_BASE``, or an empty list when that setting is
        missing or falsy.
    """
    configured = self.config.get(Constants.API_BASE)
    if not configured:
        return []
    return [configured]

def get_llm_instance(self) -> LLM:
try:
max_tokens = int(self.config.get(Constants.MAX_TOKENS, DEFAULT_NUM_OUTPUTS))
Expand Down
16 changes: 11 additions & 5 deletions src/unstract/sdk/adapters/llm/azure_open_ai/src/azure_open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from llama_index.core.llms import LLM
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.llms.openai.utils import O1_MODELS

from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.llm.constants import LLMKeys
from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
Expand All @@ -26,10 +25,14 @@ class Constants:


class AzureOpenAILLM(LLMAdapter):
def __init__(self, settings: dict[str, Any]):
def __init__(self, settings: dict[str, Any], validate_urls: bool = False):
    """Set up the adapter and optionally vet its configured URLs.

    Args:
        settings: Adapter configuration; kept as ``self.config``.
        validate_urls: When true, ``self._validate_urls()`` runs right
            after the configuration has been stored.
    """
    super().__init__("AzureOpenAI")
    self.config = settings

    # URL checks have to happen before any network operation.
    if validate_urls:
        self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -52,6 +55,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/AzureopenAI.png"

def get_configured_urls(self) -> list[str]:
    """List the endpoints this adapter may contact.

    Returns:
        list[str]: A one-element list holding the ``azure_endpoint``
        setting, or an empty list when that setting is missing or falsy.
    """
    configured = self.config.get("azure_endpoint")
    if not configured:
        return []
    return [configured]

def get_llm_instance(self) -> LLM:
max_retries = int(
self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES)
Expand All @@ -74,9 +82,7 @@ def get_llm_instance(self) -> LLM:
}

if enable_reasoning:
llm_kwargs["reasoning_effort"] = self.config.get(
Constants.REASONING_EFFORT
)
llm_kwargs["reasoning_effort"] = self.config.get(Constants.REASONING_EFFORT)

if model not in O1_MODELS:
llm_kwargs["max_completion_tokens"] = max_tokens
Expand Down
11 changes: 9 additions & 2 deletions src/unstract/sdk/adapters/llm/ollama/src/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from httpx import ConnectError, HTTPStatusError
from llama_index.core.llms import LLM
from llama_index.llms.ollama import Ollama

from unstract.sdk.adapters.exceptions import AdapterError
from unstract.sdk.adapters.llm.constants import LLMKeys
from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter
Expand All @@ -25,10 +24,13 @@ class Constants:


class OllamaLLM(LLMAdapter):
def __init__(self, settings: dict[str, Any]):
def __init__(self, settings: dict[str, Any], validate_urls: bool = False):
    """Set up the adapter and optionally vet its configured URLs.

    Args:
        settings: Adapter configuration; kept as ``self.config``.
        validate_urls: When true, ``self._validate_urls()`` runs right
            after the configuration has been stored.
    """
    super().__init__("Ollama")
    self.config = settings

    if validate_urls:
        self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -51,6 +53,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/ollama.png"

def get_configured_urls(self) -> list[str]:
    """List the endpoints this adapter may contact.

    Returns:
        list[str]: A one-element list holding the setting stored under
        ``Constants.BASE_URL``, or an empty list when that setting is
        missing or falsy.
    """
    configured = self.config.get(Constants.BASE_URL)
    if not configured:
        return []
    return [configured]

def get_llm_instance(self) -> LLM:
try:
llm: LLM = Ollama(
Expand Down
10 changes: 9 additions & 1 deletion src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@ class Constants:


class OpenAILLM(LLMAdapter):
def __init__(self, settings: dict[str, Any]):
def __init__(self, settings: dict[str, Any], validate_urls: bool = False):
    """Set up the adapter and optionally vet its configured URLs.

    Args:
        settings: Adapter configuration; kept as ``self.config``.
        validate_urls: When true, ``self._validate_urls()`` runs right
            after the configuration has been stored.
    """
    super().__init__("OpenAI")
    self.config = settings

    if validate_urls:
        self._validate_urls()

SCHEMA_PATH = f"{os.path.dirname(__file__)}/static/json_schema.json"

@staticmethod
Expand All @@ -51,6 +54,11 @@ def get_provider() -> str:
def get_icon() -> str:
return "/icons/adapter-icons/OpenAI.png"

def get_configured_urls(self) -> list[str]:
    """List the endpoints this adapter may contact.

    Returns:
        list[str]: A one-element list holding the ``api_base`` setting,
        or an empty list when that setting is missing or falsy.
    """
    configured = self.config.get("api_base")
    if not configured:
        return []
    return [configured]

def get_llm_instance(self) -> LLM:
try:
max_tokens = self.config.get(Constants.MAX_TOKENS)
Expand Down
Loading