diff --git a/singlestoredb/config.py b/singlestoredb/config.py
index a5e4805de..4b687c55b 100644
--- a/singlestoredb/config.py
+++ b/singlestoredb/config.py
@@ -444,6 +444,13 @@
     environ=['SINGLESTOREDB_EXT_FUNC_TIMEOUT'],
 )
 
+register_option(
+    'external_function.concurrency_limit', 'int', check_int, 1,
+    'Specifies the maximum number of subsets of a batch of rows '
+    'to process simultaneously.',
+    environ=['SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT'],
+)
+
 #
 # Debugging options
 #
diff --git a/singlestoredb/functions/decorator.py b/singlestoredb/functions/decorator.py
index 687211368..0c5532c24 100644
--- a/singlestoredb/functions/decorator.py
+++ b/singlestoredb/functions/decorator.py
@@ -103,6 +103,7 @@ def _func(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
+    concurrency_limit: Optional[int] = None,
 ) -> UDFType:
     """Generic wrapper for UDF and TVF decorators."""
 
@@ -112,6 +113,7 @@ def _func(
             args=expand_types(args),
             returns=expand_types(returns),
             timeout=timeout,
+            concurrency_limit=concurrency_limit,
         ).items() if v is not None
     }
 
@@ -155,6 +157,7 @@ def udf(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
+    concurrency_limit: Optional[int] = None,
 ) -> UDFType:
     """
     Define a user-defined function (UDF).
@@ -185,6 +188,10 @@ def udf(
     timeout : int, optional
         The timeout in seconds for the UDF execution. If not specified,
         the global default timeout is used.
+    concurrency_limit : int, optional
+        The maximum number of subsets of a batch of rows that will be
+        processed simultaneously by the UDF. If not specified,
+        the global default concurrency limit is used.
 
     Returns
     -------
@@ -197,4 +204,5 @@ def udf(
         args=args,
         returns=returns,
         timeout=timeout,
+        concurrency_limit=concurrency_limit,
     )
diff --git a/singlestoredb/functions/ext/asgi.py b/singlestoredb/functions/ext/asgi.py
index 69b498bd4..61f84ce8e 100755
--- a/singlestoredb/functions/ext/asgi.py
+++ b/singlestoredb/functions/ext/asgi.py
@@ -53,6 +53,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Iterator
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -90,18 +91,6 @@
 logger = utils.get_logger('singlestoredb.functions.ext.asgi')
 
 
-# If a number of processes is specified, create a pool of workers
-num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
-if num_processes > 1:
-    try:
-        from ray.util.multiprocessing import Pool
-    except ImportError:
-        from multiprocessing import Pool
-    func_map = Pool(num_processes).starmap
-else:
-    func_map = itertools.starmap
-
-
 async def to_thread(
     func: Any, /, *args: Any, **kwargs: Dict[str, Any],
 ) -> Any:
@@ -293,6 +282,106 @@ def cancel_on_event(
 )
 
 
+def identity(x: Any) -> Any:
+    """Identity function."""
+    return x
+
+
+def chunked(seq: Sequence[Any], max_chunks: int) -> Iterator[Sequence[Any]]:
+    """Yield up to max_chunks contiguous chunks from seq, each with at
+    most ceil(len(seq) / max_chunks) items."""
+    n = len(seq)
+    if n == 0:
+        return  # nothing to yield; also avoids division by zero below
+    if max_chunks <= 0 or max_chunks > n:
+        max_chunks = n
+    chunk_size = (n + max_chunks - 1) // max_chunks  # ceil division
+    for i in range(0, n, chunk_size):
+        yield seq[i:i + chunk_size]
+
+
+async def run_in_parallel(
+    func: Callable[..., Any],
+    params_list: Sequence[Sequence[Any]],
+    cancel_event: threading.Event,
+    transformer: Callable[[Any], Any] = identity,
+) -> List[Any]:
+    """
+    Run a function in parallel with a limit on the number of concurrent tasks.
+
+    Parameters
+    ----------
+    func : Callable
+        The function to call in parallel
+    params_list : Sequence[Sequence[Any]]
+        A sequence of parameter sets, one per call to the function
+    cancel_event : threading.Event
+        The event to check for cancellation
+    transformer : Callable[[Any], Any]
+        A function applied to each result before it is collected
+
+    Returns
+    -------
+    List[Any]
+        The results of the function calls, in the same order as the inputs
+
+    """
+    limit = get_concurrency_limit(func)
+    is_async = asyncio.iscoroutinefunction(func)
+
+    def call_sync(batch: Sequence[Any]) -> Any:
+        """Call the sync function on each parameter set in the batch."""
+        res = []
+        for params in batch:
+            cancel_on_event(cancel_event)
+            res.append(transformer(func(*params)))
+        return res
+
+    async def call_async(batch: Sequence[Any]) -> Any:
+        """Call the async function on each parameter set in the batch."""
+        res = []
+        for params in batch:
+            cancel_on_event(cancel_event)
+            res.append(transformer(await func(*params)))
+        return res
+
+    async def thread_call(batch: Sequence[Any]) -> Any:
+        if is_async:
+            return await call_async(batch)
+        # Run the blocking loop in a worker thread so it can't stall the event loop
+        return await to_thread(call_sync, batch)
+
+    # Split the parameter sets into at most `limit` chunks to bound concurrency
+    tasks = [thread_call(batch) for batch in chunked(params_list, limit)]
+
+    results = await asyncio.gather(*tasks)
+
+    return list(itertools.chain.from_iterable(results))
+
+
+def get_concurrency_limit(func: Callable[..., Any]) -> int:
+    """
+    Get the concurrency limit for a function.
+
+    Parameters
+    ----------
+    func : Callable
+        The function to get the concurrency limit for
+
+    Returns
+    -------
+    int
+        The concurrency limit for the function
+
+    """
+    return max(
+        1, func._singlestoredb_attrs.get(  # type: ignore
+            'concurrency_limit',
+            get_option('external_function.concurrency_limit'),
+        ),
+    )
+
+
 def build_udf_endpoint(
     func: Callable[..., Any],
     returns_data_format: str,
@@ -315,8 +404,6 @@ def build_udf_endpoint(
     """
     if returns_data_format in ['scalar', 'list']:
 
-        is_async = asyncio.iscoroutinefunction(func)
-
         async def do_func(
                 cancel_event: threading.Event,
                 timer: Timer,
@@ -324,14 +411,8 @@ async def do_func(
                 rows: Sequence[Sequence[Any]],
         ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
             '''Call function on given rows of data.'''
-            out = []
             async with timer('call_function'):
-                for row in rows:
-                    cancel_on_event(cancel_event)
-                    if is_async:
-                        out.append(await func(*row))
-                    else:
-                        out.append(func(*row))
+                out = await run_in_parallel(func, rows, cancel_event)
             return row_ids, list(zip(out))
 
         return do_func
@@ -426,8 +507,6 @@ def build_tvf_endpoint(
     """
     if returns_data_format in ['scalar', 'list']:
 
-        is_async = asyncio.iscoroutinefunction(func)
-
         async def do_func(
                 cancel_event: threading.Event,
                 timer: Timer,
@@ -435,19 +514,18 @@ async def do_func(
                 rows: Sequence[Sequence[Any]],
        ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
             '''Call function on given rows of data.'''
-            out_ids: List[int] = []
-            out = []
-            # Call function on each row of data
             async with timer('call_function'):
-                for i, row in zip(row_ids, rows):
-                    cancel_on_event(cancel_event)
-                    if is_async:
-                        res = await func(*row)
-                    else:
-                        res = func(*row)
-                    out.extend(as_list_of_tuples(res))
-                    out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
-                return out_ids, out
+                items = await run_in_parallel(
+                    func, rows, cancel_event,
+                    transformer=as_list_of_tuples,
+                )
+                # Attribute each output row to the input row that produced it
+                out_ids: List[int] = []
+                out: List[Tuple[Any, ...]] = []
+                for row_id, item in zip(row_ids, items):
+                    out.extend(item)
+                    out_ids.extend([row_id] * len(item))
+                return out_ids, out
 
         return do_func
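
A minimal usage sketch of the new knob (an illustration, not part of the diff): it assumes the `udf` decorator is importable from `singlestoredb.functions` as defined in decorator.py above; the function name and values are hypothetical.

    from singlestoredb.functions import udf

    # Process each batch of rows in at most 4 concurrent chunks; omitting the
    # argument falls back to the global `external_function.concurrency_limit`
    # option (or the SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT environment
    # variable).
    @udf(concurrency_limit=4)
    def double_it(x: float) -> float:
        return x * 2.0

With a 10-row batch and a limit of 4, `chunked` yields chunks of sizes 3, 3, 3, and 1 (chunk_size = ceil(10 / 4) = 3), and `run_in_parallel` awaits all four chunks with `asyncio.gather`.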