7 changes: 7 additions & 0 deletions singlestoredb/config.py
@@ -444,6 +444,13 @@
     environ=['SINGLESTOREDB_EXT_FUNC_TIMEOUT'],
 )
 
+register_option(
+    'external_function.concurrency_limit', 'int', check_int, 1,
+    'Specifies the maximum number of subsets of a batch of rows '
+    'to process simultaneously.',
+    environ=['SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT'],
+)
+
 #
 # Debugging options
 #
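For anyone trying the new setting, it can be supplied through the environment variable registered above or through the package's options API. A minimal sketch (`set_option` is assumed to follow the package's existing options interface; only the option name and environment variable are taken from this diff):

```python
import os

# Route 1: the environment variable registered in this diff
os.environ['SINGLESTOREDB_EXT_FUNC_CONCURRENCY_LIMIT'] = '8'

# Route 2: the package-level options helper (assumed API)
import singlestoredb as s2
s2.set_option('external_function.concurrency_limit', 8)
```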
8 changes: 8 additions & 0 deletions singlestoredb/functions/decorator.py
@@ -103,6 +103,7 @@ def _func(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
+    concurrency_limit: Optional[int] = None,
 ) -> UDFType:
     """Generic wrapper for UDF and TVF decorators."""
 
@@ -112,6 +113,7 @@ def _func(
             args=expand_types(args),
             returns=expand_types(returns),
             timeout=timeout,
+            concurrency_limit=concurrency_limit,
         ).items() if v is not None
     }

@@ -155,6 +157,7 @@ def udf(
     args: Optional[ParameterType] = None,
     returns: Optional[ReturnType] = None,
     timeout: Optional[int] = None,
+    concurrency_limit: Optional[int] = None,
 ) -> UDFType:
     """
     Define a user-defined function (UDF).
@@ -185,6 +188,10 @@
     timeout : int, optional
         The timeout in seconds for the UDF execution. If not specified,
         the global default timeout is used.
+    concurrency_limit : int, optional
+        The maximum number of concurrent subsets of rows that will be
+        processed simultaneously by the UDF. If not specified,
+        the global default concurrency limit is used.
 
     Returns
     -------
@@ -197,4 +204,5 @@
         args=args,
         returns=returns,
         timeout=timeout,
+        concurrency_limit=concurrency_limit,
     )
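A usage sketch for the new parameter (the `udf` import path is assumed from this repo's layout; `scale` is a made-up example function):

```python
from singlestoredb.functions import udf

@udf(timeout=60, concurrency_limit=4)
def scale(x: float) -> float:
    # Each batch of rows sent to this UDF is split into at most 4
    # subsets processed simultaneously; omitting the argument falls
    # back to external_function.concurrency_limit (default 1).
    return x * 2.0
```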
139 changes: 104 additions & 35 deletions singlestoredb/functions/ext/asgi.py
@@ -53,6 +53,7 @@
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Iterator
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -90,18 +91,6 @@
 logger = utils.get_logger('singlestoredb.functions.ext.asgi')
 
 
-# If a number of processes is specified, create a pool of workers
-num_processes = max(0, int(os.environ.get('SINGLESTOREDB_EXT_NUM_PROCESSES', 0)))
-if num_processes > 1:
-    try:
-        from ray.util.multiprocessing import Pool
-    except ImportError:
-        from multiprocessing import Pool
-    func_map = Pool(num_processes).starmap
-else:
-    func_map = itertools.starmap
-
-
 async def to_thread(
     func: Any, /, *args: Any, **kwargs: Dict[str, Any],
 ) -> Any:
@@ -293,6 +282,102 @@ def cancel_on_event(
 )
 
 
+def identity(x: Any) -> Any:
+    """Identity function."""
+    return x


+def chunked(seq: Sequence[Any], max_chunks: int) -> Iterator[Sequence[Any]]:
+    """Yield up to max_chunks chunks from seq, splitting as evenly as possible."""
+    n = len(seq)
+    if n == 0:
+        return  # nothing to yield; also avoids a zero division below
+    if max_chunks <= 0 or max_chunks > n:
+        max_chunks = n
+    chunk_size = (n + max_chunks - 1) // max_chunks  # ceil division
+    for i in range(0, n, chunk_size):
+        yield seq[i:i + chunk_size]

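A quick check of the splitting behavior (plain Python, not part of the diff):

```python
# 10 items into at most 3 chunks: ceil(10 / 3) = 4 items per chunk
list(chunked(list(range(10)), 3))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# max_chunks larger than the sequence degrades to one item per chunk
list(chunked([1, 2], 5))           # [[1], [2]]
```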

+async def run_in_parallel(
+    func: Callable[..., Any],
+    params_list: Sequence[Sequence[Any]],
+    cancel_event: threading.Event,
+    transformer: Callable[[Any], Any] = identity,
+) -> List[Any]:
+    """
+    Run a function in parallel with a limit on the number of concurrent tasks.
+
+    Parameters
+    ----------
+    func : Callable
+        The function to call in parallel
+    params_list : Sequence[Sequence[Any]]
+        The parameters to pass to the function
+    cancel_event : threading.Event
+        The event to check for cancellation
+    transformer : Callable[[Any], Any]
+        A function to transform the results
+
+    Returns
+    -------
+    List[Any]
+        The results of the function calls
+
+    """
+    limit = get_concurrency_limit(func)
+    is_async = asyncio.iscoroutinefunction(func)
+
+    async def call_sync(batch: Sequence[Any]) -> Any:
+        """Loop over one batch of parameter sets and call the sync function."""
+        res = []
+        for params in batch:
+            cancel_on_event(cancel_event)
+            res.append(transformer(func(*params)))
+        return res
+
+    async def call_async(batch: Sequence[Any]) -> Any:
+        """Loop over one batch of parameter sets and call the async function."""
+        res = []
+        for params in batch:
+            cancel_on_event(cancel_event)
+            res.append(transformer(await func(*params)))
+        return res
+
+    async def thread_call(batch: Sequence[Any]) -> Any:
+        if is_async:
+            return await call_async(batch)
+        return await to_thread(lambda: asyncio.run(call_sync(batch)))
+
+    # One task per chunk, so at most `limit` batches run at once
+    tasks = [thread_call(batch) for batch in chunked(params_list, limit)]

+    results = await asyncio.gather(*tasks)
+
+    return list(itertools.chain.from_iterable(results))


+def get_concurrency_limit(func: Callable[..., Any]) -> int:
+    """
+    Get the concurrency limit for a function.
+
+    Parameters
+    ----------
+    func : Callable
+        The function to get the concurrency limit for
+
+    Returns
+    -------
+    int
+        The concurrency limit for the function
+
+    """
+    return max(
+        1, func._singlestoredb_attrs.get(  # type: ignore
+            'concurrency_limit',
+            get_option('external_function.concurrency_limit'),
+        ),
+    )

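A self-contained sketch of how the two helpers combine (it assumes the functions above are in scope, e.g. by importing this module; setting `_singlestoredb_attrs` by hand stands in for what the `@udf` decorator normally does):

```python
import asyncio
import threading

def square(x: int) -> int:
    return x * x

# Normally attached by the @udf decorator; set by hand for this demo
square._singlestoredb_attrs = {'concurrency_limit': 2}

async def demo() -> None:
    # Four single-parameter rows, chunked into 2 batches of 2; each
    # batch runs in its own thread because square() is synchronous
    results = await run_in_parallel(
        square,
        [(1,), (2,), (3,), (4,)],
        threading.Event(),  # never set, so nothing is cancelled
    )
    print(results)  # [1, 4, 9, 16]

asyncio.run(demo())
```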

 def build_udf_endpoint(
     func: Callable[..., Any],
     returns_data_format: str,
@@ -315,23 +400,15 @@ def build_udf_endpoint(
     """
     if returns_data_format in ['scalar', 'list']:
 
-        is_async = asyncio.iscoroutinefunction(func)
-
         async def do_func(
             cancel_event: threading.Event,
             timer: Timer,
             row_ids: Sequence[int],
             rows: Sequence[Sequence[Any]],
         ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
             '''Call function on given rows of data.'''
-            out = []
             async with timer('call_function'):
-                for row in rows:
-                    cancel_on_event(cancel_event)
-                    if is_async:
-                        out.append(await func(*row))
-                    else:
-                        out.append(func(*row))
+                out = await run_in_parallel(func, rows, cancel_event)
             return row_ids, list(zip(out))
 
         return do_func
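The `list(zip(out))` on the return line wraps each scalar result in a 1-tuple, matching the `List[Tuple[Any, ...]]` annotation. In plain Python:

```python
out = [10, 20, 30]  # scalar results from run_in_parallel
list(zip(out))      # [(10,), (20,), (30,)] -- one 1-tuple per output row
```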
@@ -426,28 +503,20 @@ def build_tvf_endpoint(
     """
     if returns_data_format in ['scalar', 'list']:
 
-        is_async = asyncio.iscoroutinefunction(func)
-
         async def do_func(
             cancel_event: threading.Event,
             timer: Timer,
             row_ids: Sequence[int],
             rows: Sequence[Sequence[Any]],
         ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
             '''Call function on given rows of data.'''
-            out_ids: List[int] = []
-            out = []
-            # Call function on each row of data
             async with timer('call_function'):
-                for i, row in zip(row_ids, rows):
-                    cancel_on_event(cancel_event)
-                    if is_async:
-                        res = await func(*row)
-                    else:
-                        res = func(*row)
-                    out.extend(as_list_of_tuples(res))
-                    out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
-            return out_ids, out
+                items = await run_in_parallel(
+                    func, rows, cancel_event,
+                    transformer=as_list_of_tuples,
+                )
+                out = list(itertools.chain.from_iterable(items))
+            return [row_ids[0]] * len(out), out
 
         return do_func

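In the TVF path each input row may produce several output tuples, so the per-row results come back as lists that are then flattened. A plain-Python illustration (the shape of the `as_list_of_tuples` output is inferred from the original loop):

```python
import itertools

# Two input rows: the first yields two output tuples, the second one
items = [[(1,), (2,)], [(3,)]]
out = list(itertools.chain.from_iterable(items))  # [(1,), (2,), (3,)]
```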