From 9b987d48b198184b37e3c5957003f99efbbced18 Mon Sep 17 00:00:00 2001 From: birydrad <> Date: Mon, 8 Sep 2025 14:15:05 +0200 Subject: [PATCH] coroutines support in actors --- CMakeLists.txt | 16 +- assembly/native/build-macos-portable.sh | 2 +- assembly/native/build-macos-shared.sh | 2 +- assembly/native/build-ubuntu-appimages.sh | 2 +- assembly/native/build-ubuntu-portable.sh | 2 +- assembly/native/build-ubuntu-shared.sh | 2 +- assembly/native/build-windows-2019.bat | 2 +- assembly/native/build-windows-2022.bat | 2 +- assembly/native/build-windows.bat | 2 +- tdactor/CMakeLists.txt | 5 + tdactor/benchmark/CMakeLists.txt | 65 ++ tdactor/benchmark/benchmark-coro.cpp | 126 +++ tdactor/benchmark/benchmark.cpp | 4 + tdactor/benchmark/gbench-coro-folly.cpp | 105 ++ tdactor/benchmark/gbench-coro-yaclib.cpp | 189 ++++ tdactor/benchmark/gbench-coro.cpp | 1001 ++++++++++++++++++ tdactor/example/actor-example-coroutines.cpp | 182 ++++ tdactor/td/actor/PromiseFuture.h | 29 + tdactor/td/actor/actor.h | 132 ++- tdactor/td/actor/common.h | 105 +- tdactor/td/actor/core/ActorExecuteContext.h | 3 + tdactor/td/actor/core/CpuWorker.cpp | 50 +- tdactor/td/actor/core/CpuWorker.h | 16 +- tdactor/td/actor/core/Scheduler.cpp | 56 +- tdactor/td/actor/core/Scheduler.h | 7 +- tdactor/td/actor/core/SchedulerContext.h | 5 + tdactor/td/actor/coro.h | 11 + tdactor/td/actor/coro_awaitables.h | 234 ++++ tdactor/td/actor/coro_executor.h | 237 +++++ tdactor/td/actor/coro_task.h | 469 ++++++++ tdactor/td/actor/coro_types.h | 114 ++ tdactor/td/actor/coro_utils.h | 268 +++++ tdactor/test/CMakeLists.txt | 3 + tdactor/test/actors_core.cpp | 12 +- tdactor/test/test-coro.cpp | 877 +++++++++++++++ tdutils/td/utils/Closure.h | 4 +- tdutils/td/utils/Status.h | 29 + 37 files changed, 4206 insertions(+), 164 deletions(-) create mode 100644 tdactor/benchmark/benchmark-coro.cpp create mode 100644 tdactor/benchmark/gbench-coro-folly.cpp create mode 100644 tdactor/benchmark/gbench-coro-yaclib.cpp create mode 100644 tdactor/benchmark/gbench-coro.cpp create mode 100644 tdactor/example/actor-example-coroutines.cpp create mode 100644 tdactor/td/actor/coro.h create mode 100644 tdactor/td/actor/coro_awaitables.h create mode 100644 tdactor/td/actor/coro_executor.h create mode 100644 tdactor/td/actor/coro_task.h create mode 100644 tdactor/td/actor/coro_types.h create mode 100644 tdactor/td/actor/coro_utils.h create mode 100644 tdactor/test/CMakeLists.txt create mode 100644 tdactor/test/test-coro.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index c37f351d8..f042368c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -56,6 +56,7 @@ option(TONLIB_ENABLE_JNI "Use \"ON\" to enable JNI-compatible TonLib API.") option(TON_USE_ASAN "Use \"ON\" to enable AddressSanitizer." OFF) option(TON_USE_TSAN "Use \"ON\" to enable ThreadSanitizer." OFF) option(TON_USE_UBSAN "Use \"ON\" to enable UndefinedBehaviorSanitizer." OFF) +option(TON_USE_COVERAGE "Use \"ON\" to enable code coverage with gcov." 
OFF) set(TON_ARCH "native" CACHE STRING "Architecture, will be passed to -march=") option(TON_PRINT_BACKTRACE_ON_CRASH "Attempt to print a backtrace when a fatal signal is caught" ON) @@ -251,7 +252,7 @@ elseif (CLANG OR GCC) set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--gc-sections -Wl,--exclude-libs,ALL") endif() set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,--gc-sections") - if (NOT TON_USE_ASAN AND NOT TON_USE_TSAN AND NOT MEMPROF) + if (NOT TON_USE_ASAN AND NOT TON_USE_TSAN AND NOT TON_USE_COVERAGE AND NOT MEMPROF) if (NOT USE_EMSCRIPTEN) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,--exclude-libs,ALL") endif() @@ -337,6 +338,18 @@ endif() if (TON_USE_UBSAN) add_cxx_compiler_flag("-fsanitize=undefined") endif() +if (TON_USE_COVERAGE) + add_cxx_compiler_flag("-fprofile-arcs") + add_cxx_compiler_flag("-ftest-coverage") + add_cxx_compiler_flag("--coverage") + add_cxx_compiler_flag("-O0") + add_cxx_compiler_flag("-g") + add_cxx_compiler_flag("-fno-inline") + add_cxx_compiler_flag("-fno-inline-small-functions") + add_cxx_compiler_flag("-fno-default-inline") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -fprofile-arcs -ftest-coverage --coverage") + set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -fprofile-arcs -ftest-coverage --coverage") +endif() #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=thread") #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address") #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=undefined") @@ -554,6 +567,7 @@ add_test(test-cells test-cells ${TEST_OPTIONS}) add_test(test-smartcont test-smartcont) add_test(test-net test-net) add_test(test-actors test-tdactor) +add_test(test-actors-coro tdactor/test/test-coro) add_test(test-emulator test-emulator) #BEGIN tonlib diff --git a/assembly/native/build-macos-portable.sh b/assembly/native/build-macos-portable.sh index feb5bf72f..87955fe3f 100644 --- a/assembly/native/build-macos-portable.sh +++ b/assembly/native/build-macos-portable.sh @@ -153,7 +153,7 @@ if [ "$with_tests" = true ]; then lite-client validator-engine-console generate-random-id json2tlo dht-server dht-ping-servers dht-resolve \ http-proxy rldp-http-proxy adnl-proxy create-state create-hardfork tlbc emulator \ test-ed25519 test-bigint test-vm test-fift test-cells test-smartcont \ - test-net test-tdactor test-tdutils test-tonlib-offline test-adnl test-dht test-rldp \ + test-net test-tdactor test-coro test-tdutils test-tonlib-offline test-adnl test-dht test-rldp \ test-rldp2 test-catchain test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver test $? -eq 0 || { echo "Can't compile ton"; exit 1; } else diff --git a/assembly/native/build-macos-shared.sh b/assembly/native/build-macos-shared.sh index f5aa15b4a..592e5ee77 100644 --- a/assembly/native/build-macos-shared.sh +++ b/assembly/native/build-macos-shared.sh @@ -91,7 +91,7 @@ if [ "$with_tests" = true ]; then lite-client validator-engine-console generate-random-id json2tlo dht-server dht-ping-servers dht-resolve \ http-proxy rldp-http-proxy adnl-proxy create-state create-hardfork tlbc emulator \ test-ed25519 test-bigint test-vm test-fift test-cells test-smartcont \ - test-net test-tdactor test-tdutils test-tonlib-offline test-adnl test-dht test-rldp \ + test-net test-tdactor test-coro test-tdutils test-tonlib-offline test-adnl test-dht test-rldp \ test-rldp2 test-catchain test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver test $? 
-eq 0 || { echo "Can't compile ton"; exit 1; } else diff --git a/assembly/native/build-ubuntu-appimages.sh b/assembly/native/build-ubuntu-appimages.sh index 2cdc6c652..038471468 100644 --- a/assembly/native/build-ubuntu-appimages.sh +++ b/assembly/native/build-ubuntu-appimages.sh @@ -63,7 +63,7 @@ ninja storage-daemon storage-daemon-cli fift func tolk tonlib tonlibjson tonlib- validator-engine lite-client validator-engine-console blockchain-explorer \ generate-random-id json2tlo dht-server http-proxy rldp-http-proxy dht-ping-servers dht-resolve \ adnl-proxy create-state emulator test-ed25519 test-bigint \ - test-vm test-fift test-cells test-smartcont test-net test-tdactor test-tdutils \ + test-vm test-fift test-cells test-smartcont test-net test-tdactor test-coro test-tdutils \ test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain \ test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver test $? -eq 0 || { echo "Can't compile ton"; exit 1; } diff --git a/assembly/native/build-ubuntu-portable.sh b/assembly/native/build-ubuntu-portable.sh index a4f0bb233..c384c9c42 100644 --- a/assembly/native/build-ubuntu-portable.sh +++ b/assembly/native/build-ubuntu-portable.sh @@ -137,7 +137,7 @@ ninja storage-daemon storage-daemon-cli fift func tolk tonlib tonlibjson tonlib- validator-engine lite-client validator-engine-console blockchain-explorer \ generate-random-id json2tlo dht-server http-proxy rldp-http-proxy dht-ping-servers dht-resolve \ adnl-proxy create-state emulator test-ed25519 test-bigint \ - test-vm test-fift test-cells test-smartcont test-net test-tdactor test-tdutils \ + test-vm test-fift test-cells test-smartcont test-net test-tdactor test-coro test-tdutils \ test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain \ test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver test $? -eq 0 || { echo "Can't compile ton"; exit 1; } diff --git a/assembly/native/build-ubuntu-shared.sh b/assembly/native/build-ubuntu-shared.sh index 6ddbaf885..c866168cc 100644 --- a/assembly/native/build-ubuntu-shared.sh +++ b/assembly/native/build-ubuntu-shared.sh @@ -66,7 +66,7 @@ ninja storage-daemon storage-daemon-cli fift func tolk tonlib tonlibjson tonlib- validator-engine lite-client validator-engine-console blockchain-explorer \ generate-random-id json2tlo dht-server http-proxy rldp-http-proxy dht-ping-servers dht-resolve \ adnl-proxy create-state emulator test-ed25519 test-bigint \ - test-vm test-fift test-cells test-smartcont test-net test-tdactor test-tdutils \ + test-vm test-fift test-cells test-smartcont test-net test-tdactor test-coro test-tdutils \ test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain \ test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver test $? 
-eq 0 || { echo "Can't compile ton"; exit 1; } diff --git a/assembly/native/build-windows-2019.bat b/assembly/native/build-windows-2019.bat index e671d277c..1bc8bc68d 100644 --- a/assembly/native/build-windows-2019.bat +++ b/assembly/native/build-windows-2019.bat @@ -148,7 +148,7 @@ ninja storage-daemon storage-daemon-cli blockchain-explorer fift func tolk tonli tonlib-cli validator-engine lite-client validator-engine-console generate-random-id ^ json2tlo dht-server http-proxy rldp-http-proxy adnl-proxy create-state create-hardfork emulator ^ test-ed25519 test-bigint test-vm test-fift test-cells test-smartcont test-net ^ -test-tdactor test-tdutils test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain ^ +test-tdactor test-coro test-tdutils test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain ^ test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver dht-ping-servers dht-resolve IF %errorlevel% NEQ 0 ( echo Can't compile TON diff --git a/assembly/native/build-windows-2022.bat b/assembly/native/build-windows-2022.bat index 6ff7166af..88236c778 100644 --- a/assembly/native/build-windows-2022.bat +++ b/assembly/native/build-windows-2022.bat @@ -148,7 +148,7 @@ ninja storage-daemon storage-daemon-cli blockchain-explorer fift func tolk tonli tonlib-cli validator-engine lite-client validator-engine-console generate-random-id ^ json2tlo dht-server http-proxy rldp-http-proxy adnl-proxy create-state create-hardfork emulator ^ test-ed25519 test-bigint test-vm test-fift test-cells test-smartcont test-net ^ -test-tdactor test-tdutils test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain ^ +test-tdactor test-coro test-tdutils test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain ^ test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver IF %errorlevel% NEQ 0 ( echo Can't compile TON diff --git a/assembly/native/build-windows.bat b/assembly/native/build-windows.bat index f0a65670b..0f5db3019 100644 --- a/assembly/native/build-windows.bat +++ b/assembly/native/build-windows.bat @@ -148,7 +148,7 @@ ninja storage-daemon storage-daemon-cli blockchain-explorer fift func tolk tonli tonlib-cli validator-engine lite-client validator-engine-console generate-random-id ^ json2tlo dht-server http-proxy rldp-http-proxy adnl-proxy create-state create-hardfork emulator ^ test-ed25519 test-bigint test-vm test-fift test-cells test-smartcont test-net ^ -test-tdactor test-tdutils test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain ^ +test-tdactor test-coro test-tdutils test-tonlib-offline test-adnl test-dht test-rldp test-rldp2 test-catchain ^ test-fec test-tddb test-db test-validator-session-state test-emulator proxy-liteserver dht-ping-servers dht-resolve IF %errorlevel% NEQ 0 ( echo Can't compile TON diff --git a/tdactor/CMakeLists.txt b/tdactor/CMakeLists.txt index 44580dd59..8328cb1bc 100644 --- a/tdactor/CMakeLists.txt +++ b/tdactor/CMakeLists.txt @@ -15,6 +15,7 @@ set(TDACTOR_SOURCE td/actor/ActorShared.h td/actor/ActorStats.h td/actor/common.h + td/actor/coro.h td/actor/PromiseFuture.h td/actor/MultiPromise.h @@ -52,8 +53,12 @@ add_library(tdactor STATIC ${TDACTOR_SOURCE}) target_include_directories(tdactor PUBLIC $) target_link_libraries(tdactor PUBLIC tdutils) +add_executable(tdactor-example-coroutines example/actor-example-coroutines.cpp) +target_link_libraries(tdactor-example-coroutines PUBLIC tdactor) + # BEGIN-INTERNAL add_subdirectory(benchmark) 
+add_subdirectory(test) # END-INTERNAL install(TARGETS tdactor EXPORT TdTargets diff --git a/tdactor/benchmark/CMakeLists.txt b/tdactor/benchmark/CMakeLists.txt index 662ae20c4..10360d4e4 100644 --- a/tdactor/benchmark/CMakeLists.txt +++ b/tdactor/benchmark/CMakeLists.txt @@ -15,3 +15,68 @@ if (MSVC) set_property(SOURCE benchmark.cpp APPEND_STRING PROPERTY COMPILE_FLAGS " /wd4457 /wd4316") endif() +add_executable(benchmark-coro benchmark-coro.cpp) +target_link_libraries(benchmark-coro tdactor tdutils) + +# Google Benchmark based coroutine benchmarks +find_package(PkgConfig) +if(PkgConfig_FOUND) + pkg_check_modules(GBENCH benchmark) + if(GBENCH_FOUND) + add_executable(gbench-coro gbench-coro.cpp) + target_link_libraries(gbench-coro tdactor tdutils) + # Use full library paths to avoid conflicts with our "benchmark" executable + target_link_directories(gbench-coro PRIVATE ${GBENCH_LIBRARY_DIRS}) + target_link_libraries(gbench-coro -lbenchmark) + target_include_directories(gbench-coro PRIVATE ${GBENCH_INCLUDE_DIRS}) + target_compile_options(gbench-coro PRIVATE ${GBENCH_CFLAGS_OTHER}) + + message(STATUS "Google Benchmark found via pkg-config - gbench-coro target enabled") + + # Check for Folly + find_package(folly CONFIG REQUIRED) + if(folly_FOUND) + add_executable(gbench-coro-folly gbench-coro-folly.cpp) + target_link_libraries(gbench-coro-folly -lbenchmark) + target_link_directories(gbench-coro-folly PRIVATE ${GBENCH_LIBRARY_DIRS}) + target_include_directories(gbench-coro-folly PRIVATE ${GBENCH_INCLUDE_DIRS}) + target_compile_options(gbench-coro-folly PRIVATE ${GBENCH_CFLAGS_OTHER}) + + find_package(gflags REQUIRED) + target_link_libraries(gbench-coro-folly Folly::folly gflags) + + message(STATUS "Folly found - enabling Folly benchmarks in gbench-coro-folly") + else() + message(STATUS "Folly not found - Folly benchmarks disabled") + endif() + + add_executable(gbench-coro-yaclib gbench-coro-yaclib.cpp) + target_link_libraries(gbench-coro-yaclib -lbenchmark) + target_link_directories(gbench-coro-yaclib PRIVATE ${GBENCH_LIBRARY_DIRS}) + target_include_directories(gbench-coro-yaclib PRIVATE ${GBENCH_INCLUDE_DIRS}) + target_compile_options(gbench-coro-yaclib PRIVATE ${GBENCH_CFLAGS_OTHER}) + + # Check for YACLib + find_package(yaclib CONFIG QUIET) + if(yaclib_FOUND) + target_link_libraries(gbench-coro-yaclib yaclib::yaclib) + message(STATUS "YACLib found - enabling YACLib benchmarks in gbench-coro-yaclib") + else() + message(STATUS "YACLib not found - fetching via FetchContent for gbench-coro-yaclib") + include(FetchContent) + set(YACLIB_FLAGS "CORO" CACHE STRING "" FORCE) + FetchContent_Declare(yaclib + GIT_REPOSITORY https://github.com/YACLib/YACLib.git + GIT_TAG main + ) + FetchContent_MakeAvailable(yaclib) + target_link_libraries(gbench-coro-yaclib yaclib) + endif() + target_compile_definitions(gbench-coro-yaclib PRIVATE YACLIB_CORO=2) + + else() + message(STATUS "Google Benchmark not found via pkg-config - gbench-coro target disabled") + endif() +else() + message(STATUS "PkgConfig not found - gbench-coro target disabled") +endif() \ No newline at end of file diff --git a/tdactor/benchmark/benchmark-coro.cpp b/tdactor/benchmark/benchmark-coro.cpp new file mode 100644 index 000000000..6b06c6a38 --- /dev/null +++ b/tdactor/benchmark/benchmark-coro.cpp @@ -0,0 +1,126 @@ +#include "td/actor/coro.h" +#include "td/actor/actor.h" +#include "td/utils/benchmark.h" + +#include + +using namespace td::actor; + +class BenchmarkDatabase final : public td::actor::Actor { + public: + Task square(size_t x) { + 
co_return x * x; + } +}; + +class CoroBenchmark final : public td::actor::Actor { + public: + void start_up() override { + db_ = td::actor::create_actor("BenchmarkDatabase").release(); + flow().start_immediate().detach(); + } + + Task run_benchmarks() { + LOG(INFO) << "=== Running benchmarks ==="; + std::vector thread_counts = {1, 10}; + size_t total_ops = 100000; + + for (size_t thread_count : thread_counts) { + size_t ops_per_thread = std::max(1, total_ops / thread_count); + std::vector> tasks; + + auto run_benchmark = [&](const char* name, bool immediate) -> Task { + td::Timer timer; + + for (size_t t = 0; t < thread_count; t++) { + auto task_name = PSTRING() << name << "_" << t; + tasks.push_back(spawn_actor(std::move(task_name), [](auto db, auto ops_per_thread, auto immediate) -> Task { + for (size_t i = 0; i < ops_per_thread; i++) { + if (immediate) { + (void)co_await ask_immediate(db, &BenchmarkDatabase::square, i); + } else { + (void)co_await ask(db, &BenchmarkDatabase::square, i); + } + } + co_return td::Unit(); + }(db_, ops_per_thread, immediate))); + } + + for (auto& task : tasks) { + co_await std::move(task); + } + tasks.clear(); + + auto elapsed = timer.elapsed(); + auto ops_per_sec = total_ops / elapsed; + LOG(INFO) << name << " " << ops_per_thread << " ops: " << elapsed << "s (threads=" << thread_count + << ", " << static_cast(ops_per_sec) << " ops/sec)"; + co_return td::Unit(); + }; + + co_await run_benchmark("Immediate", true); + co_await run_benchmark("Delayed", false); + } + + LOG(INFO) << "=== Optimized benchmarks (direct) ==="; + constexpr size_t single_thread = 1; + size_t ops_count = total_ops / single_thread; + + // Warm up + for (size_t i = 0; i < 1000; i++) { + (void)co_await ask(db_, &BenchmarkDatabase::square, i); + } + + td::Timer timer; + for (size_t i = 0; i < ops_count; i++) { + (void)co_await ask(db_, &BenchmarkDatabase::square, i); + } + auto elapsed = timer.elapsed(); + auto ops_per_sec = ops_count / elapsed; + LOG(INFO) << "Direct delayed " << ops_count << " ops: " << elapsed << "s (" + << static_cast(ops_per_sec) << " ops/sec)"; + + timer = {}; + for (size_t i = 0; i < ops_count; i++) { + (void)co_await ask_immediate(db_, &BenchmarkDatabase::square, i); + } + elapsed = timer.elapsed(); + ops_per_sec = ops_count / elapsed; + LOG(INFO) << "Direct immediate " << ops_count << " ops: " << elapsed << "s (" + << static_cast(ops_per_sec) << " ops/sec)"; + + timer = {}; + for (size_t i = 0; i < ops_count; i++) { + auto local_square = [](size_t x) -> Task { co_return x * x; }; + (void)co_await local_square(i); + } + elapsed = timer.elapsed(); + ops_per_sec = ops_count / elapsed; + LOG(INFO) << "Local coroutine " << ops_count << " ops: " << elapsed << "s (" + << static_cast(ops_per_sec) << " ops/sec)"; + + co_return td::Unit(); + } + + Task flow() { + LOG(INFO) << "Starting benchmarks"; + (void)co_await run_benchmarks(); + LOG(INFO) << "Benchmarks completed"; + td::actor::SchedulerContext::get()->stop(); + stop(); + co_return td::Unit(); + } + + private: + td::actor::ActorId db_; +}; + +int main() { + SET_VERBOSITY_LEVEL(VERBOSITY_NAME(INFO)); + td::actor::Scheduler scheduler({4}); + + scheduler.run_in_context([&] { td::actor::create_actor("CoroBenchmark").release(); }); + + scheduler.run(); + return 0; +} diff --git a/tdactor/benchmark/benchmark.cpp b/tdactor/benchmark/benchmark.cpp index deddcb22c..a7d30373e 100644 --- a/tdactor/benchmark/benchmark.cpp +++ b/tdactor/benchmark/benchmark.cpp @@ -313,6 +313,10 @@ class ActorExecutorBenchmark : public td::Benchmark 
{ //queue.push_back(std::move(ptr)); q.push(ptr, 0); } + void add_token_to_cpu_queue(SchedulerToken token, SchedulerId scheduler_id) override { + SchedulerMessage::Raw *raw = reinterpret_cast(token); + q.push(SchedulerMessage(SchedulerMessage::acquire_t{}, raw), 0); + } void set_alarm_timestamp(const ActorInfoPtr &actor_info_ptr) override { UNREACHABLE(); } diff --git a/tdactor/benchmark/gbench-coro-folly.cpp b/tdactor/benchmark/gbench-coro-folly.cpp new file mode 100644 index 000000000..1b5c80891 --- /dev/null +++ b/tdactor/benchmark/gbench-coro-folly.cpp @@ -0,0 +1,105 @@ +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + + +namespace { + +using folly::coro::Task; +struct DatabaseService { + folly::Executor::KeepAlive<> ex; + + Task query_user(int user_id) { + // Same "latency": cheap CPU work + int sum = 0; + for (int i = 0; i < 100; ++i) { + sum += user_id * i; + } + co_return "user_" + std::to_string(user_id) + "_data_" + std::to_string(sum); + } + + Task check_auth(int user_id) { + co_return (user_id % 2) == 0; // Even IDs are authorized + } +}; + +struct RequestHandler { + DatabaseService& db; + folly::Executor::KeepAlive<> ex; + + + Task handle_request(int request_id) { + const int user_id = request_id % 1000; + + const bool authorized = co_await co_withExecutor(db.ex, db.check_auth(user_id)); + if (!authorized) { + co_return "401 Unauthorized"; + } + + std::string user_data = co_await co_withExecutor(db.ex, db.query_user(user_id)); + co_return "200 OK: " + user_data; + } +}; + +} // namespace + +// Real-world benchmark: HTTP-like request handler with database lookups +// Simulates: Request -> Auth Check -> DB Query -> Response +static void BM_HttpRequestHandler_Folly(benchmark::State& state) { + const int concurrent_requests = static_cast(state.range(0)); + state.SetLabel("HttpHandler_Folly_" + std::to_string(concurrent_requests) + "_requests"); + + // Executor sized to available HW, capped by requested concurrency (avoid oversubscribe) + const int hw = std::max(1u, std::thread::hardware_concurrency()); + const int threads = 4; + folly::CPUThreadPoolExecutor exec(threads); + auto ex = folly::getKeepAliveToken(&exec); + + DatabaseService db{folly::SerialExecutor::create(ex)}; + RequestHandler handler{db, ex}; + + int request_counter = 0; + + while (state.KeepRunningBatch(concurrent_requests)) { + std::vector> requests; + requests.reserve(concurrent_requests); + + // Launch concurrent requests on the executor + for (int i = 0; i < concurrent_requests; ++i) { + requests.emplace_back(co_withExecutor(ex, handler.handle_request(request_counter++))); + } + + // Await all responses concurrently + auto responses = + folly::coro::blockingWait(folly::coro::collectAllRange(std::move(requests))); + + for (auto& r : responses) { + benchmark::DoNotOptimize(r); + } + } +} + +// Register HTTP handler benchmark with various concurrency levels +BENCHMARK(BM_HttpRequestHandler_Folly)->Arg(1)->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_HttpRequestHandler_Folly)->Arg(10)->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_HttpRequestHandler_Folly)->Arg(100)->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_HttpRequestHandler_Folly)->Arg(1000)->UseRealTime()->MeasureProcessCPUTime(); + +int main(int argc, char** argv) { + benchmark::Initialize(&argc, argv); + if (benchmark::ReportUnrecognizedArguments(argc, argv)) + return 1; + benchmark::SetBenchmarkFilter("Http"); + benchmark::RunSpecifiedBenchmarks(); + return 0; +} 
diff --git a/tdactor/benchmark/gbench-coro-yaclib.cpp b/tdactor/benchmark/gbench-coro-yaclib.cpp new file mode 100644 index 000000000..c32e417bc --- /dev/null +++ b/tdactor/benchmark/gbench-coro-yaclib.cpp @@ -0,0 +1,189 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +struct DatabaseService { + DatabaseService(yaclib::IExecutorPtr executor) : strand_(executor) { + } + yaclib::Strand strand_; + + yaclib::Task check_auth(int user_id) { + co_await On(strand_); + // Simulate auth check + co_return (user_id % 2) == 0; // Even IDs are authorized + } +}; + +static void BM_HttpRequestHandler_Yaclib(benchmark::State& state) { + const int concurrent_requests = static_cast(state.range(0)); + state.SetLabel("HttpHandler_" + std::to_string(concurrent_requests) + "_requests"); + + // Thread pool backing the "DB"; strand enforces single-coro entry to DB section + auto db_pool = yaclib::MakeFairThreadPool(10); + auto db_strand = yaclib::MakeStrand(db_pool); + + auto query_user = [&](int user_id) -> yaclib::Task { + co_await On(*db_strand); + // Simulate DB latency with a small computation + int sum = 0; + for (int i = 0; i < 100; ++i) { + sum += user_id * i; + } + std::stringstream ss; + ss << "user_" << user_id << "_data_" << sum; + co_return ss.str(); + }; + auto check_auth = [&](int user_id) -> yaclib::Task { + co_await On(*db_strand); + // Simulate auth check + co_return (user_id % 2) == 0; // Even IDs are authorized + }; + + std::vector executors; + for (size_t i = 0; i < concurrent_requests; ++i) { + executors.emplace_back(yaclib::MakeStrand(db_pool)); + } + // "RequestHandler" + auto handle_request = [&](int request_id) -> yaclib::Task { + co_await On(*executors[request_id % executors.size()]); + const int user_id = request_id % 1000; + + const bool authorized = co_await check_auth(user_id); + if (!authorized) + co_return std::string{"401 Unauthorized"}; + + const std::string user_data = co_await query_user(user_id); + co_return std::string{"200 OK: "} + user_data; + }; + + auto run_test = [&]() -> yaclib::Task { + int request_counter = 0; + while (state.KeepRunningBatch(concurrent_requests)) { + std::vector> batch; + batch.reserve(concurrent_requests); + + for (int i = 0; i < concurrent_requests; ++i) { + // Convert lazy Task -> started Future + batch.emplace_back(handle_request(request_counter++).ToFuture()); + } + + // Wait for all; keeps concurrency + auto all = co_await yaclib::WhenAll(batch.begin(), batch.end()); + benchmark::DoNotOptimize(all); + } + co_return 0; + }; + + benchmark::DoNotOptimize(run_test().Get()); + + db_pool->Stop(); + db_pool->Wait(); +} + +// Register with various concurrency levels +BENCHMARK(BM_HttpRequestHandler_Yaclib)->Arg(1)->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_HttpRequestHandler_Yaclib)->Arg(10)->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_HttpRequestHandler_Yaclib)->Arg(100)->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_HttpRequestHandler_Yaclib)->Arg(1000)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_AskYaclib(benchmark::State& state) { + auto pool = yaclib::MakeFairThreadPool(10); + const int num_tasks = static_cast(state.range(0)); + std::vector executors; + for (int i = 0; i < num_tasks; ++i) { + executors.emplace_back(yaclib::MakeStrand(pool)); + } + + auto get = [&](size_t i) -> yaclib::Task { + co_await 
On(*executors[i % executors.size()]); + co_return 42; + }; + + // Set a clean label for this benchmark + auto run_test = [&]() -> yaclib::Task { + int request_counter = 0; + while (state.KeepRunningBatch(num_tasks)) { + if (num_tasks == 1) { + int result = co_await get(0); + benchmark::DoNotOptimize(result); + } else { + std::vector> tasks; + tasks.reserve(num_tasks); + for (int i = 0; i < num_tasks; ++i) { + auto task = get(i); + tasks.push_back(std::move(task)); + } + for (auto& task : tasks) { + co_await std::move(task); + } + } + } + co_return 0; + }; + + benchmark::DoNotOptimize(run_test().Get()); + + pool->Stop(); + pool->Wait(); +} + +// Register all combinations - the label is set inside the benchmark +// Single task +BENCHMARK(BM_AskYaclib)->Args({1})->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_AskYaclib)->Args({10})->UseRealTime()->MeasureProcessCPUTime(); +BENCHMARK(BM_AskYaclib)->Args({100})->UseRealTime()->MeasureProcessCPUTime(); + +yaclib::Task simple_task() { + co_return 42; +} +static void BM_TaskAwait(benchmark::State& state) { + auto db_pool = yaclib::MakeFairThreadPool(1); + auto run_test = [&]() -> yaclib::Task { + int sum = 0; + while (state.KeepRunning()) { + sum += co_await simple_task(); + } + benchmark::DoNotOptimize(sum); + co_return 0; + }; + + benchmark::DoNotOptimize(run_test().Get()); + + db_pool->Stop(); + db_pool->Wait(); +} + +// Register with various concurrency levels +BENCHMARK(BM_TaskAwait)->UseRealTime()->MeasureProcessCPUTime(); + +int main(int argc, char** argv) { + benchmark::Initialize(&argc, argv); + if (benchmark::ReportUnrecognizedArguments(argc, argv)) { + return 1; + } + benchmark::RunSpecifiedBenchmarks(); + return 0; +} diff --git a/tdactor/benchmark/gbench-coro.cpp b/tdactor/benchmark/gbench-coro.cpp new file mode 100644 index 000000000..bc029046d --- /dev/null +++ b/tdactor/benchmark/gbench-coro.cpp @@ -0,0 +1,1001 @@ +#include +#include "td/actor/coro.h" +#include "td/actor/actor.h" + +#include +#include +#include +#include + +#if defined(__APPLE__) || defined(__linux__) +#include +#elif defined(_WIN32) +#include +#else +#include +#endif + +using namespace td::actor; + +struct SchedulerGuard { + Scheduler sched{std::vector{Scheduler::NodeInfo{10}}}; + + template + void run_until_done(td::Slice name, Task task) { + auto done = std::make_shared>(false); + td::actor::ActorOwn<> actor_id; + + sched.run_in_context([&] { + actor_id = td::actor::create_actor(name); + task.set_executor(Executor::on_actor(actor_id.get())); + + auto wrapped_task = [](auto done, Task task) -> Task { + auto r = co_await std::move(task).wrap(); + r.ensure(); + done->store(true, std::memory_order_release); + co_return td::Unit{}; + }(done, std::move(task)); + + (void)std::move(wrapped_task).start(); + }); + + while (!done->load(std::memory_order_acquire)) { + sched.run(0.001); + } + } +}; + +static SchedulerGuard& guard() { + static SchedulerGuard g; + return g; +} + +static inline double process_cpu_seconds() { +#if defined(_WIN32) + FILETIME creation, exit, kernel, user; + if (GetProcessTimes(GetCurrentProcess(), &creation, &exit, &kernel, &user)) { + auto to_seconds = [](const FILETIME& ft) { + ULARGE_INTEGER uli{.LowPart = ft.dwLowDateTime, .HighPart = ft.dwHighDateTime}; + return uli.QuadPart * 1e-7; // 100ns → seconds + }; + return to_seconds(kernel) + to_seconds(user); + } + return 0.0; +#elif defined(__APPLE__) || defined(__linux__) + rusage ru{}; + if (getrusage(RUSAGE_SELF, &ru) == 0) { + auto to_seconds = [](const timeval& tv) { return tv.tv_sec 
+ tv.tv_usec * 1e-6; }; + return to_seconds(ru.ru_utime) + to_seconds(ru.ru_stime); + } + return 0.0; +#else + return static_cast(std::clock()) / CLOCKS_PER_SEC; +#endif +} + +template +void coro_benchmark(benchmark::State& state, F&& benchmark_code) { + const auto real_start = std::chrono::steady_clock::now(); + const double cpu_start = process_cpu_seconds(); + + auto task = benchmark_code(state); + guard().run_until_done(state.name(), std::move(task)); + + const double cpu_elapsed = process_cpu_seconds() - cpu_start; + const double real_elapsed = std::chrono::duration(std::chrono::steady_clock::now() - real_start).count(); + const double iterations = static_cast(state.iterations()); + + if (iterations > 0 && cpu_elapsed > 0 && real_elapsed > 0) { + state.counters["cpu_time"] = benchmark::Counter(cpu_elapsed / iterations); + state.counters["cpu_speed"] = benchmark::Counter(iterations / cpu_elapsed, benchmark::Counter::kIsRate); + state.counters["real_time"] = benchmark::Counter(real_elapsed / iterations); + state.counters["real_speed"] = benchmark::Counter(iterations / real_elapsed, benchmark::Counter::kIsRate); + state.counters.erase("items_per_second"); + } +} + +class BenchActor : public td::actor::core::Actor { + public: + Task compute_task(int x) { + co_return x * 2; + } + + td::Result compute_sync(int x) { + return x * 3; + } + + void compute_promise(int x, td::Promise promise) { + promise.set_value(x * 4); + } +}; + +Task simple_task() { + co_return 42; +} + +struct TestAwaitable { + int value; + bool is_ready{false}; + + bool await_ready() noexcept { + return is_ready; + } + std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept { + return h; + } + td::Result await_resume() noexcept { + return value; + } +}; + +static void BM_RawTaskAwait(benchmark::State& state) { + coro_benchmark(state, [&](auto& state) -> Task { + co_await detach_from_actor(); + td::int64 sum = 0; + while (state.KeepRunning()) { + sum += (co_await SkipAwaitTransform{TestAwaitable{42}}).move_as_ok(); + } + benchmark::DoNotOptimize(sum); + co_return td::Unit(); + }); +} +BENCHMARK(BM_RawTaskAwait)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_DelayedTaskAwait(benchmark::State& state) { + coro_benchmark(state, [&](auto& state) -> Task { + co_await detach_from_actor(); + td::int64 sum = 0; + while (state.KeepRunning()) { + sum += (co_await SkipAwaitTransform{simple_task()}).move_as_ok(); + } + benchmark::DoNotOptimize(sum); + co_return td::Unit(); + }); +} +BENCHMARK(BM_DelayedTaskAwait)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_StartedTaskAwait(benchmark::State& state) { + coro_benchmark(state, [&](auto& state) -> Task { + co_await detach_from_actor(); + td::int64 sum = 0; + while (state.KeepRunning()) { + sum += (co_await SkipAwaitTransform{simple_task().start_immediate()}).move_as_ok(); + } + benchmark::DoNotOptimize(sum); + + co_return td::Unit(); + }); +} +BENCHMARK(BM_StartedTaskAwait)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_ScheduledTaskAwait(benchmark::State& state) { + coro_benchmark(state, [&](auto& state) -> Task { + co_await detach_from_actor(); + td::int64 sum = 0; + while (state.KeepRunning()) { + sum += (co_await SkipAwaitTransform{simple_task().start()}).move_as_ok(); + } + benchmark::DoNotOptimize(sum); + + co_return td::Unit(); + }); +} +BENCHMARK(BM_ScheduledTaskAwait)->UseRealTime()->MeasureProcessCPUTime(); + +enum class ResumeMethod { Raw, Pass, Try }; +enum class ResumeLocation { Actor, Scheduler, Any }; +enum class 
AwaitableState { Ready, Suspended }; + +static void BM_AwaitThenResume(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + const auto method = static_cast(state.range(0)); + const auto location = static_cast(state.range(1)); + const auto awaitable = static_cast(state.range(2)); + + // Set a clean label for this benchmark + const char* awaitable_name[] = {"Ready", "Suspended"}; + const char* method_names[] = {"Raw", "Pass", "Try"}; + const char* location_names[] = {"Actor", "Scheduler", "Any"}; + auto label = PSTRING() << method_names[static_cast(method)] << "_" + << location_names[static_cast(location)] << "_" + << awaitable_name[static_cast(awaitable)]; + state.SetLabel(label); + // Create an actor to run on if needed + td::actor::ActorOwn actor; + if (location == ResumeLocation::Actor) { + actor = td::actor::create_actor("bench_actor"); + } + + TestAwaitable aw{42, awaitable == AwaitableState::Ready}; + Executor executor; + if (location == ResumeLocation::Actor) { + executor = Executor::on_actor(actor.get()); + } else if (location == ResumeLocation::Scheduler) { + executor = Executor::on_scheduler(); + } else { + executor = Executor::on_any(); + } + co_await resume_on(executor); + + td::int64 sum = 0; + td::int32 total_iterations = 0; + + while (state.KeepRunning()) { + int value = 0; + + switch (method) { + case ResumeMethod::Raw: { + // Raw awaitable without any resume_on wrapper + value = (co_await SkipAwaitTransform{aw}).move_as_ok(); + break; + } + case ResumeMethod::Pass: { + value = (co_await SkipAwaitTransform{wrap_and_resume_on_current(aw)}).move_as_ok(); + break; + } + case ResumeMethod::Try: { + value = (co_await SkipAwaitTransform{unwrap_and_resume_on_current(aw)}); + break; + } + } + + sum += value; + total_iterations++; + } + + CHECK(state.iterations() == total_iterations); + + benchmark::DoNotOptimize(sum); + co_return td::Unit(); + }); +} + +// Register all combinations +static void ApplyAwaitThenResumeArgs(benchmark::internal::Benchmark* b) { + for (auto location : {ResumeLocation::Actor, ResumeLocation::Scheduler, ResumeLocation::Any}) { + for (auto awaitable : {AwaitableState::Ready, AwaitableState::Suspended}) { + for (auto method : {ResumeMethod::Pass, ResumeMethod::Try, ResumeMethod::Raw}) { + b->Args({static_cast(method), static_cast(location), static_cast(awaitable)}); + } + } + } +} +BENCHMARK(BM_AwaitThenResume)->Apply(ApplyAwaitThenResumeArgs)->UseRealTime()->MeasureProcessCPUTime(); + +// Simple benchmarks +static void BM_TaskCreation(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + while (state.KeepRunning()) { + auto task = []() -> Task { co_return 42; }(); + benchmark::DoNotOptimize(task); + // Task is not started - just measuring creation overhead + } + co_return td::Unit(); + }); +} +BENCHMARK(BM_TaskCreation)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_SimpleCompute(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + while (state.KeepRunning()) { + auto task = []() -> Task { + int sum = 0; + for (int j = 0; j < 100; j++) { + sum += j; + } + co_return sum; + }(); + // Task will auto-start when awaited + auto result = co_await std::move(task); + benchmark::DoNotOptimize(result); + } + co_return td::Unit(); + }); +} +BENCHMARK(BM_SimpleCompute)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_TaskChain(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + while (state.KeepRunning()) { + auto task1 = []() -> Task { co_return 
10; }(); + auto task2 = [](Task task1) -> Task { + auto v = co_await std::move(task1); + co_return v * 2; + }(std::move(task1)); + auto result = co_await std::move(task2); + benchmark::DoNotOptimize(result); + } + co_return td::Unit(); + }); +} +BENCHMARK(BM_TaskChain)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_ErrorHandling(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + while (state.KeepRunning()) { + auto error_task = []() -> Task { co_return td::Status::Error("test error"); }(); + auto result = co_await std::move(error_task).wrap(); + int value = result.is_error() ? 0 : result.ok(); + benchmark::DoNotOptimize(value); + } + co_return td::Unit(); + }); +} +BENCHMARK(BM_ErrorHandling)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_SpawnCoroutineOld(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + while (state.KeepRunning()) { + auto result = co_await spawn_actor("test", []() -> Task { co_return 42; }()); + benchmark::DoNotOptimize(result); + } + + co_return td::Unit(); + }); +} +BENCHMARK(BM_SpawnCoroutineOld)->UseRealTime()->MeasureProcessCPUTime(); + +// Benchmark with multiple operations per iteration +static void BM_BatchTaskCreation(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + while (state.KeepRunning()) { + auto task = []() -> Task { co_return 42; }(); + benchmark::DoNotOptimize(task); + } + co_return td::Unit(); + }); +} +BENCHMARK(BM_BatchTaskCreation)->UseRealTime()->MeasureProcessCPUTime(); + +// Concurrent tasks +static void BM_ConcurrentTasks(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + const int num_tasks_int = static_cast(state.range(0)); + while (state.KeepRunningBatch(num_tasks_int)) { + std::vector> tasks; + tasks.reserve(num_tasks_int); + for (int i = 0; i < num_tasks_int; ++i) { + tasks.emplace_back([](int i) -> Task { co_return i * 2; }(i).start()); + } + int total = 0; + for (auto& task : tasks) { + auto result = co_await std::move(task); + total += result; + } + benchmark::DoNotOptimize(total); + } + co_return td::Unit(); + }); +} +BENCHMARK(BM_ConcurrentTasks)->RangeMultiplier(2)->Range(1, 64)->UseRealTime()->MeasureProcessCPUTime(); + +// Memory pattern +static void BM_MemoryPattern(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + const int num_tasks = static_cast(state.range(0)); + while (state.KeepRunningBatch(num_tasks)) { + std::vector> tasks; + tasks.reserve(num_tasks); + for (int i = 0; i < num_tasks; ++i) { + tasks.emplace_back([i]() -> Task { co_return i; }()); + } + int sum = 0; + for (auto& task : tasks) { + auto result = co_await std::move(task); + sum += result; + } + benchmark::DoNotOptimize(sum); + } + co_return td::Unit(); + }); +} +BENCHMARK(BM_MemoryPattern)->RangeMultiplier(4)->Range(1, 256)->UseRealTime()->MeasureProcessCPUTime(); + +// Unified benchmark for ask operations +enum class AskMethod { Task, TaskWrap, Promise, Sync, Call, TaskNew }; +enum class AskMode { Scheduled, Immediate }; + +static void BM_Ask(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + + const auto method = static_cast(state.range(0)); + const auto mode = static_cast(state.range(1)); + const int num_tasks = static_cast(state.range(2)); + + // Set a clean label for this benchmark + const char* method_names[] = {"Task", "TaskTry", "Promise", "Sync", "Call", "New"}; + const char* mode_names[] = {"Scheduled", "Immediate"}; + auto label = PSTRING() << 
method_names[static_cast(method)] << "_" << mode_names[static_cast(mode)] << "_" + << num_tasks; + state.SetLabel(label); + td::int32 total_tasks = 0; + std::vector> actors; + for (int i = 0; i < num_tasks; ++i) { + actors.push_back(td::actor::create_actor("bench_actor")); + + // just ensure that the following ask_immediate will be immediate + while (true) { + auto task = ask_immediate(actors.back(), &BenchActor::compute_task, 42); + if (task.await_ready()) { + break; + } + } + } + co_await Yield{}; + + while (state.KeepRunningBatch(num_tasks)) { + if (num_tasks == 1) { + // Single task - direct execution + int result = 0; + + if (mode == AskMode::Immediate) { + switch (method) { + case AskMethod::Task: { + result = co_await ask_immediate(actors[0], &BenchActor::compute_task, 42); + break; + } + case AskMethod::TaskWrap: { + result = (co_await ask_immediate(actors[0], &BenchActor::compute_task, 42).wrap()).move_as_ok(); + break; + } + case AskMethod::TaskNew: { + result = co_await ask_new_immediate(actors[0], &BenchActor::compute_task, 42); + break; + } + case AskMethod::Promise: { + result = co_await ask_immediate(actors[0], &BenchActor::compute_promise, 42); + break; + } + case AskMethod::Sync: { + result = co_await ask_immediate(actors[0], &BenchActor::compute_sync, 42); + break; + } + case AskMethod::Call: { + result = co_await actors[0].get_actor_unsafe().compute_task(42); + break; + } + } + } else { + switch (method) { + case AskMethod::Task: { + result = co_await ask(actors[0], &BenchActor::compute_task, 42); + break; + } + case AskMethod::TaskWrap: { + result = (co_await ask(actors[0], &BenchActor::compute_task, 42).wrap()).ok(); + break; + } + case AskMethod::TaskNew: { + result = co_await ask_new(actors[0], &BenchActor::compute_task, 42); + break; + } + case AskMethod::Promise: { + result = co_await ask(actors[0], &BenchActor::compute_promise, 42); + break; + } + case AskMethod::Sync: { + result = co_await ask(actors[0], &BenchActor::compute_sync, 42); + break; + } + case AskMethod::Call: { + UNREACHABLE(); + } + } + } + + total_tasks++; + benchmark::DoNotOptimize(result); + } else { + // Multiple tasks - launch concurrently + std::vector> tasks; + tasks.reserve(num_tasks); + + // Launch all tasks + for (int i = 0; i < num_tasks; ++i) { + if (mode == AskMode::Immediate) { + switch (method) { + case AskMethod::Task: + case AskMethod::TaskWrap: + tasks.emplace_back(ask_immediate(actors[i], &BenchActor::compute_task, 42 + i)); + break; + case AskMethod::TaskNew: + tasks.emplace_back(ask_new_immediate(actors[i], &BenchActor::compute_task, 42 + i)); + break; + case AskMethod::Promise: + tasks.emplace_back(ask_immediate(actors[i], &BenchActor::compute_promise, 42 + i)); + break; + case AskMethod::Sync: + tasks.emplace_back(ask_immediate(actors[i], &BenchActor::compute_sync, 42 + i)); + break; + case AskMethod::Call: { + UNREACHABLE(); + } + } + CHECK(tasks.back().await_ready()); + } else if (mode == AskMode::Scheduled) { + switch (method) { + case AskMethod::Task: + case AskMethod::TaskWrap: + tasks.emplace_back(ask(actors[i], &BenchActor::compute_task, 42 + i)); + break; + case AskMethod::TaskNew: + tasks.emplace_back(ask_new(actors[i], &BenchActor::compute_task, 42 + i)); + break; + case AskMethod::Promise: + tasks.emplace_back(ask(actors[i], &BenchActor::compute_promise, 42 + i)); + break; + case AskMethod::Sync: + tasks.emplace_back(ask(actors[i], &BenchActor::compute_sync, 42 + i)); + break; + case AskMethod::Call: { + UNREACHABLE(); + } + } + } else { + UNREACHABLE(); + } + } + + 
// Await all results + int total = 0; + for (auto& task : tasks) { + if (method == AskMethod::TaskWrap) { + total += (co_await std::move(task).wrap()).ok(); + } else { + auto result = co_await std::move(task); + total += result; + } + total_tasks++; + } + benchmark::DoNotOptimize(total); + } + } + + CHECK(state.iterations() == total_tasks); + co_return td::Unit(); + }); +} + +// Register all combinations - the label is set inside the benchmark +static void ApplyBMAskArgs(benchmark::internal::Benchmark* b) { + for (auto method : {AskMethod::TaskNew, AskMethod::TaskWrap, AskMethod::Task, AskMethod::Promise, AskMethod::Sync}) { + for (auto mode : {AskMode::Scheduled, AskMode::Immediate}) { + for (int n : {1, 10, 100}) { + b->Args({static_cast(method), static_cast(mode), n}); + } + } + } + // Call method is only meaningful for single-task immediate mode + b->Args({static_cast(AskMethod::Call), static_cast(AskMode::Immediate), 1}); +} +BENCHMARK(BM_Ask)->Apply(ApplyBMAskArgs)->UseRealTime()->MeasureProcessCPUTime()->Repetitions(10000); + +// Benchmark send_closure with promise callback using Worker pattern +static void BM_SendClosureWorker(benchmark::State& state) { + struct Worker : public td::actor::core::Actor { + public: + Worker(benchmark::State& state, td::Promise promise, bool immediate, int num_tasks) + : state_(state) + , promise_(std::move(promise)) + , immediate_(immediate) + , num_actors_(num_tasks) + , tasks_completed_(0) { + for (int i = 0; i < num_actors_; i++) { + childs_.push_back(td::actor::create_actor("bench_actor")); + } + } + + private: + benchmark::State& state_; + td::Promise promise_; + std::vector> childs_; + bool immediate_; + int num_actors_; + int tasks_completed_; + int total_ = 0; + + void loop() override { + if (state_.KeepRunningBatch(num_actors_)) { + tasks_completed_ = 0; + total_ = 0; + for (int i = 0; i < num_actors_; ++i) { + auto promise = td::promise_send_closure(actor_id(this), &Worker::done, i); + if (immediate_) { + send_closure_immediate(childs_[i], &BenchActor::compute_promise, 42 + i, std::move(promise)); + } else { + send_closure(childs_[i], &BenchActor::compute_promise, 42 + i, std::move(promise)); + } + } + } else { + promise_.set_value(7); + } + } + + void done(int task_id, td::Result result) { + if (result.is_ok()) { + total_ += result.ok(); + } + tasks_completed_++; + benchmark::DoNotOptimize(total_); + + if (tasks_completed_ == num_actors_) { + loop(); + } + } + }; + + coro_benchmark(state, [](auto& state) -> Task { + const bool immediate = static_cast(state.range(0)); + const int num_tasks = static_cast(state.range(1)); + + // Set a clean label for this benchmark + state.SetLabel(PSTRING() << "SendClosure_" << (immediate ? 
"Immediate" : "Scheduled") << "_" << num_tasks); + + auto [task, promise] = StartedTask::make_bridge(); + auto td_promise = + td::Promise([p = std::move(promise)](td::Result r) mutable { p.set_result(std::move(r)); }); + auto worker = td::actor::create_actor("worker", state, std::move(td_promise), immediate, num_tasks); + auto result = co_await std::move(task); + benchmark::DoNotOptimize(result); + + co_return td::Unit(); + }); +} + +// Register send_closure benchmarks - the label is set inside the benchmark +static void ApplySendClosureArgs(benchmark::internal::Benchmark* b) { + for (int immediate : {0, 1}) { + for (int n : {1, 10, 100}) { + b->Args({immediate, n}); + } + } +} +BENCHMARK(BM_SendClosureWorker)->Apply(ApplySendClosureArgs)->UseRealTime()->MeasureProcessCPUTime(); + +// Real-world benchmark: HTTP-like request handler with database lookups +// Simulates: Request -> Auth Check -> DB Query -> Response +static void BM_HttpRequestHandler(benchmark::State& state) { + // Simulated database service + struct DatabaseService : public td::actor::core::Actor { + Task query_user(int user_id) { + // Simulate DB latency with a small computation + int sum = 0; + for (int i = 0; i < 100; ++i) { + sum += user_id * i; + } + co_return PSTRING() << "user_" << user_id << "_data_" << sum; + } + + Task check_auth(int user_id) { + // Simulate auth check + co_return (user_id % 2) == 0; // Even IDs are authorized + } + }; + + // Request handler service + struct RequestHandler { + static Task handle_request(ActorId db, int request_id) { + int user_id = request_id % 1000; + + // Step 1: Check authorization + auto authorized = co_await ask_immediate(db, &DatabaseService::check_auth, user_id); + if (!authorized) { + co_return "401 Unauthorized"; + } + + // Step 2: Query database + auto user_data = co_await ask_immediate(db, &DatabaseService::query_user, user_id); + + // Step 3: Process and return response + co_return PSTRING() << "200 OK: " << user_data; + } + }; + + coro_benchmark(state, [](auto& state) -> Task { + const int concurrent_requests = static_cast(state.range(0)); + state.SetLabel(PSTRING() << "HttpHandler_" << concurrent_requests << "_requests"); + + auto db = td::actor::create_actor("database"); + + int request_counter = 0; + + while (state.KeepRunningBatch(concurrent_requests)) { + std::vector> requests; + requests.reserve(concurrent_requests); + + // Launch concurrent requests + for (int i = 0; i < concurrent_requests; ++i) { + // TODO or lazy coroutine? 
+ auto task = RequestHandler::handle_request(db.get(), request_counter).start(); + requests.emplace_back(std::move(task)); + } + + // Await all responses + for (auto& request : requests) { + auto response = co_await std::move(request); + benchmark::DoNotOptimize(response); + } + } + + co_return td::Unit(); + }); +} + +// Register HTTP handler benchmark with various concurrency levels +BENCHMARK(BM_HttpRequestHandler)->RangeMultiplier(10)->Range(1, 1000)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_HttpRequestHandlerOld(benchmark::State& state) { + // Simulated database service + struct DatabaseService : public td::actor::core::Actor { + td::Result query_user(int user_id) { + // Simulate DB latency with a small computation + int sum = 0; + for (int i = 0; i < 100; ++i) { + sum += user_id * i; + } + return PSTRING() << "user_" << user_id << "_data_" << sum; + } + + td::Result check_auth(int user_id) { + // Simulate auth check + return (user_id % 2) == 0; // Even IDs are authorized + } + }; + + // Request handler service + struct RequestHandlerOld : public td::actor::core::Actor { + RequestHandlerOld(td::actor::ActorId db, td::Promise promise, td::int32 request_id) + : db_(db), promise_(std::move(promise)), user_id_(request_id % 1000) { + } + + void start_up() override { + // Step 1: Check authorization + send_closure_immediate(db_, &DatabaseService::check_auth, user_id_, + td::promise_send_closure(actor_id(this), &RequestHandlerOld::on_authorized)); + } + + void on_authorized(td::Result r_authorized) { + if (r_authorized.is_error()) { + promise_.set_error(r_authorized.move_as_error()); + return stop(); + } + auto authorized = r_authorized.ok(); + if (!authorized) { + promise_.set_value("401 anauthorized"); + return stop(); + } + + send_closure_immediate(db_, &DatabaseService::query_user, user_id_, + td::promise_send_closure(actor_id(this), &RequestHandlerOld::on_user)); + } + + void on_user(td::Result r_user_data) { + if (r_user_data.is_error()) { + promise_.set_error(r_user_data.move_as_error()); + return stop(); + } + auto user_data = r_user_data.ok(); + promise_.set_value(PSTRING() << "200 OK: " << user_data); + return stop(); + } + + private: + td::actor::ActorId db_; + td::Promise promise_; + td::int32 user_id_{}; + }; + + coro_benchmark(state, [](auto& state) -> Task { + const int concurrent_requests = static_cast(state.range(0)); + state.SetLabel(PSTRING() << "HttpHandlerOld_" << concurrent_requests << "_requests"); + + auto db = td::actor::create_actor("database"); + + int request_counter = 0; + + while (state.KeepRunningBatch(concurrent_requests)) { + std::vector> requests; + requests.reserve(concurrent_requests); + + // Launch concurrent requests + for (int i = 0; i < concurrent_requests; ++i) { + auto [task, promise] = StartedTask::make_bridge(); + td::actor::create_actor("handler", db.get(), std::move(promise), ++request_counter) + .release(); + requests.emplace_back(std::move(task)); + } + + // Await all responses + for (auto& request : requests) { + auto response = co_await std::move(request); + benchmark::DoNotOptimize(response); + } + } + + co_return td::Unit(); + }); +} + +// Register HTTP handler benchmark with various concurrency levels +BENCHMARK(BM_HttpRequestHandlerOld)->RangeMultiplier(10)->Range(1, 1000)->UseRealTime()->MeasureProcessCPUTime(); + +// Concurrent Pub-Sub benchmark: Publishers -> Broker -> Subscribers +// Multiple Publisher actors produce messages concurrently, Broker fans out to all Subscribers +static void 
BM_PubSubConcurrent(benchmark::State& state) { + struct Message { + std::string payload; + }; + + struct Subscriber : public td::actor::core::Actor { + explicit Subscriber(int id) : id_(id) { + } + + void process(Message m) { + int sum = 0; + for (char c : m.payload) { + sum += static_cast(c); + } + total_++; + benchmark::DoNotOptimize(sum); + } + + td::Result get_delivered_count() { + return total_; + } + + private: + int id_; + td::int64 total_{0}; + }; + + struct Broker : public td::actor::core::Actor { + void subscribe(td::actor::ActorId sub) { + subscribers_.push_back(sub); + } + + Task publish(Message m) { + for (auto& sub : subscribers_) { + send_closure(sub, &Subscriber::process, m); + } + co_return static_cast(subscribers_.size()); + } + + private: + std::vector> subscribers_; + }; + + struct Publisher : public td::actor::core::Actor { + Publisher(td::actor::ActorId broker, int id) : broker_(broker), id_(id) { + } + + Task produce(int count) { + int delivered_total = 0; + for (int j = 0; j < count; ++j) { + Message m{PSTRING() << "msg_" << id_ << "_" << j}; + auto delivered = co_await ask(broker_, &Broker::publish, std::move(m)); + delivered_total += delivered; + } + co_return delivered_total; + } + + private: + td::actor::ActorId broker_; + int id_; + }; + + coro_benchmark(state, [](auto& state) -> Task { + const int num_publishers = static_cast(state.range(0)); + const int num_subscribers = static_cast(state.range(1)); + const int num_brokers = static_cast(state.range(2)); + constexpr int messages_per_publisher = 10; + state.SetLabel(PSTRING() << "PubSubConcurrent_P" << num_publishers << "_S" << num_subscribers << "_B" + << num_brokers); + + // Create brokers (shards) + std::vector> brokers; + brokers.reserve(num_brokers); + for (int b = 0; b < num_brokers; ++b) { + brokers.push_back(td::actor::create_actor("broker")); + } + + // Create subscribers and subscribe them round-robin across brokers + std::vector> subscribers; + subscribers.reserve(num_subscribers); + for (int i = 0; i < num_subscribers; ++i) { + auto sub = td::actor::create_actor("subscriber", i); + for (auto& broker : brokers) { + send_closure(broker, &Broker::subscribe, sub.get()); + } + subscribers.push_back(std::move(sub)); + } + + // Create publishers and assign each to a broker (round-robin) + std::vector> publishers; + publishers.reserve(num_publishers); + for (int p = 0; p < num_publishers; ++p) { + auto& broker = brokers[p % num_brokers]; + publishers.push_back(td::actor::create_actor("publisher", broker.get(), p)); + } + + const td::int64 total_messages = num_publishers * messages_per_publisher * num_subscribers; + td::int64 iteration_count = 0; + while (state.KeepRunningBatch(total_messages)) { + iteration_count++; + std::vector> tasks; + tasks.reserve(num_publishers); + for (auto& pub : publishers) { + tasks.emplace_back(ask(pub, &Publisher::produce, messages_per_publisher)); + } + + int delivered_sum = 0; + for (auto& task : tasks) { + delivered_sum += co_await std::move(task); + } + benchmark::DoNotOptimize(delivered_sum); + } + + td::int64 total_delivered = 0; + const td::int64 expected_per_subscriber = iteration_count * num_publishers * messages_per_publisher; + for (auto& subscriber : subscribers) { + while (true) { + auto delivered_count = co_await ask(subscriber, &Subscriber::get_delivered_count); + if (delivered_count != expected_per_subscriber) { + LOG(ERROR) << "Subscriber delivered " << delivered_count << " != expected " << expected_per_subscriber; + continue; + } + total_delivered += 
delivered_count; + break; + } + } + CHECK(state.iterations() == total_delivered); + + co_return td::Unit(); + }); +} + +static void ApplyPubSubArgs(benchmark::internal::Benchmark* b) { + constexpr int combos[][3] = {{1, 10, 1}, {10, 10, 1}, {10, 100, 1}, {100, 100, 1}, {10, 100, 4}, {100, 100, 4}}; + for (const auto& combo : combos) { + b->Args({combo[0], combo[1], combo[2]}); + } +} +BENCHMARK(BM_PubSubConcurrent)->Apply(ApplyPubSubArgs)->UseRealTime()->MeasureProcessCPUTime(); + +static void BM_ConcurrentAsks(benchmark::State& state) { + coro_benchmark(state, [](auto& state) -> Task { + const int num_actors = static_cast(state.range(0)); + std::vector> actors; + actors.reserve(num_actors); + for (int i = 0; i < num_actors; ++i) { + actors.push_back(td::actor::create_actor("bench_actor")); + } + + while (state.KeepRunningBatch(num_actors)) { + std::vector> tasks; + tasks.reserve(num_actors); + + for (auto& actor : actors) { + tasks.emplace_back(ask(actor, &BenchActor::compute_task, 42)); + } + + for (auto& task : tasks) { + auto result = co_await std::move(task); + benchmark::DoNotOptimize(result); + } + } + + co_return td::Unit(); + }); +} +BENCHMARK(BM_ConcurrentAsks)->RangeMultiplier(4)->Range(1, 64)->UseRealTime()->MeasureProcessCPUTime(); + +int main(int argc, char** argv) { + benchmark::Initialize(&argc, argv); + if (benchmark::ReportUnrecognizedArguments(argc, argv)) { + return 1; + } + benchmark::SetBenchmarkFilter("BM_Ask/5/1/10/"); + //benchmark::SetBenchmarkFilter("BM_Ask"); + benchmark::RunSpecifiedBenchmarks(); + return 0; +} \ No newline at end of file diff --git a/tdactor/example/actor-example-coroutines.cpp b/tdactor/example/actor-example-coroutines.cpp new file mode 100644 index 000000000..e41603e60 --- /dev/null +++ b/tdactor/example/actor-example-coroutines.cpp @@ -0,0 +1,182 @@ +#include "absl/status/status.h" +#include "td/actor/coro.h" +#include "td/actor/actor.h" +#include "td/utils/SharedSlice.h" +#include "td/utils/port/sleep.h" + +#include +#include +#include +#include +#include +#include + +using namespace td::actor; + +Task example_create() { + LOG(INFO) << "Detach"; + Task value = []() -> Task { + LOG(FATAL) << "This line will not be executed"; + co_return 17; + }(); + value.detach(); + + LOG(INFO) << "Simple co_await"; + // Task will be started after co await has been called + Task value2 = []() -> Task { co_return 17; }(); + CHECK(17 == co_await std::move(value2)); + + LOG(INFO) << "start_immediate than co_await"; + // Task is started immediately + StartedTask value3 = []() -> Task { + td::usleep_for(1000000); + co_return 17; + }() + .start_immediate(); + CHECK(value3.await_ready()); + CHECK(17 == (co_await std::move(value3))); + + LOG(INFO) << "start than co_await"; + // Task is started on scheduler + StartedTask value4 = []() -> Task { + td::usleep_for(1000000); + co_return 17; + }() + .start(); + CHECK(!value4.await_ready()); + CHECK(17 == (co_await std::move(value4))); + + StartedTask value5 = spawn_actor("worker", []() -> Task { + // This code will be run on some actor + // Main reason to use it is actor statistics + co_return 17; + }()); + CHECK(17 == (co_await std::move(value5))); + co_return td::Unit(); +} + +Task example_communicate() { + LOG(INFO) << "Communicate with actor"; + struct Worker : public Actor { + int square(int x) { + return x * x; + } + Task square_task(int x) { + co_return square(x); + } + void square_promise(int x, td::Promise promise) { + send_closure(actor_id(this), &Worker::square_task, x, std::move(promise)); + } + }; + auto 
worker = create_actor("worker"); + + StartedTask value6 = ask(worker, &Worker::square, 17); + CHECK(289 == (co_await std::move(value6))); + StartedTask value7 = ask(worker, &Worker::square_promise, 17); + CHECK(289 == (co_await std::move(value7))); + StartedTask value8 = ask(worker, &Worker::square_task, 17); + CHECK(289 == (co_await std::move(value8))); + co_return td::Unit(); +} + +Task task_error() { + co_return td::Status::Error("test error"); +} +td::Result result_error() { + return td::Status::Error("test error"); +} +Task pass_task_error() { + co_await task_error(); + co_return 17; +} +Task pass_result_error() { + co_await result_error(); + co_return 17; +} + +Task example_error_handling() { + // Error handling + (co_await pass_task_error().wrap()).ensure_error(); + (co_await pass_result_error().wrap()).ensure_error(); + (co_await result_error().wrap()).ensure_error(); + + co_return td::Unit(); +} + +Task example_actor() { + struct TaskActor : public Actor { + TaskActor(td::Promise promise) : promise_(std::move(promise)) { + + } + void start_up() override { + // it is usual actor all coroutines create FROM actor, will be executed ON actor + run().start().detach(); + } + Task run() { + state_ = 19; + finish(); + co_return td::Unit(); + } + private: + td::Promise promise_; + int state_ {17}; + + void finish() { + promise_.set_result(state_); + stop(); + } + }; + + auto [task, promise] = StartedTask::make_bridge(); + auto task_actor = create_actor("task_actor", std::move(promise)); + CHECK(19 == (co_await std::move(task))); + co_return td::Unit(); + +} + +Task example_all() { + std::vector> v; + int n = std::thread::hardware_concurrency(); + for (int i = 0; i < n; i++) { + v.push_back([](int i) -> Task { + td::usleep_for(1000000); + co_return i* i; + }(i) + .start()); + } + auto vv = co_await all(std::move(v)); + for (int i = 0; i < n; i++) { + CHECK(vv[i] == i * i); + } + co_return td::Unit(); +} + +Task run_all_examples() { + co_await example_create(); + co_await example_communicate(); + co_await example_error_handling(); + co_await example_actor(); + co_await example_all(); + co_return td::Unit(); +} + +Task example() { + LOG(INFO) << "Start example coroutine"; + + (co_await run_all_examples().wrap()).ensure(); + + LOG(INFO) << "Finish example coroutine and stop scheduler"; + td::actor::SchedulerContext::get()->stop(); + co_return td::Unit(); +} + +int main() { + SET_VERBOSITY_LEVEL(VERBOSITY_NAME(INFO)); + td::actor::Scheduler scheduler({std::thread::hardware_concurrency()}); + + scheduler.run_in_context([&] { (void)example().start(); }); + + scheduler.run(); + LOG(INFO) << "DONE"; + return 0; +} \ No newline at end of file diff --git a/tdactor/td/actor/PromiseFuture.h b/tdactor/td/actor/PromiseFuture.h index 2b8890d3f..91cb916cc 100644 --- a/tdactor/td/actor/PromiseFuture.h +++ b/tdactor/td/actor/PromiseFuture.h @@ -518,4 +518,33 @@ std::pair, Future> make_promise_future() { return std::make_pair(std::move(promise), std::move(future)); } +template +inline constexpr bool always_false = false; + +template +struct is_result : std::false_type {}; +template +struct is_result> : std::true_type {}; + +template +inline constexpr bool is_result_v = is_result::value; + +template +constexpr decltype(auto) connect(L &&l, R &&r) noexcept { + if constexpr (is_result_v>) { + if (r.is_error()) { + connect(std::forward(l), r.move_as_error()); + } else { + connect(std::forward(l), r.move_as_ok()); + } + } else if constexpr (requires { custom_connect(std::forward(l), std::forward(r)); }) { + // ADL 
will find overloads defined in the namespaces of L or R + return custom_connect(std::forward(l), std::forward(r)); + } else if constexpr (requires { std::forward(l)(std::forward(r)); }) { + return std::forward(l)(std::forward(r)); + } else { + static_assert(always_false, "no matching apply overload"); + } +} + } // namespace td diff --git a/tdactor/td/actor/actor.h b/tdactor/td/actor/actor.h index 1f4f6e994..12721fca3 100644 --- a/tdactor/td/actor/actor.h +++ b/tdactor/td/actor/actor.h @@ -26,84 +26,85 @@ namespace td { namespace actor { template -TD_WARN_UNUSED_RESULT ActorOwn create_actor(ActorOptions options, ArgsT &&... args) { +TD_WARN_UNUSED_RESULT ActorOwn create_actor(ActorOptions options, ArgsT &&...args) { return ActorOwn(ActorId::create(options, std::forward(args)...)); } template -TD_WARN_UNUSED_RESULT ActorOwn create_actor(Slice name, ArgsT &&... args) { +TD_WARN_UNUSED_RESULT ActorOwn create_actor(Slice name, ArgsT &&...args) { return ActorOwn(ActorId::create(ActorOptions().with_name(name), std::forward(args)...)); } -#define SEND_CLOSURE_LATER 1 -#ifndef SEND_CLOSURE_LATER +namespace internal { -template , - size_t argument_count = member_function_argument_count(), - std::enable_if_t with_promise = false> -void send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { +template > +void send_closure_dispatch(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { using ActorT = typename std::decay_t::ActorT; static_assert(std::is_base_of::value, "unsafe send_closure"); + constexpr size_t argument_count = member_function_argument_count(); + ActorIdT id = std::forward(actor_id); - detail::send_closure(id.as_actor_ref(), function, std::forward(args)...); + if constexpr (argument_count == sizeof...(ArgsT)) { + if constexpr (Later) { + detail::send_closure_later(id.as_actor_ref(), function, std::forward(args)...); + } else { + detail::send_closure(id.as_actor_ref(), function, std::forward(args)...); + } + } else { + auto closure = call_n_arguments( + [&function](auto &&...nargs) { + if constexpr (Later) { + return create_delayed_closure(function, std::forward(nargs)...); + } else { + return create_immediate_closure(function, std::forward(nargs)...); + } + }, + std::forward(args)...); + auto promise = get_last_argument(std::forward(args)...); + if constexpr (Later) { + detail::send_closure_with_promise_later(id.as_actor_ref(), std::move(closure), std::move(promise)); + } else { + detail::send_closure_with_promise(id.as_actor_ref(), std::move(closure), std::move(promise)); + } + } } +} // namespace internal + template , size_t argument_count = member_function_argument_count(), - std::enable_if_t with_promise = true> -void send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { - using ActorT = typename std::decay_t::ActorT; - static_assert(std::is_base_of::value, "unsafe send_closure"); - - ActorIdT id = std::forward(actor_id); - detail::send_closure_with_promise(id.as_actor_ref(), - call_n_arguments( - [&function](auto &&... 
nargs) { - return create_immediate_closure(function, - std::forward(nargs)...); - }, - std::forward(args)...), - get_last_argument(std::forward(args)...)); + std::enable_if_t with_promise = false> +void send_closure_immediate(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { + internal::send_closure_dispatch(std::forward(actor_id), function, std::forward(args)...); } -#else +template , + size_t argument_count = member_function_argument_count(), + std::enable_if_t with_promise = true> +void send_closure_immediate(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { + internal::send_closure_dispatch(std::forward(actor_id), function, std::forward(args)...); +} template , size_t argument_count = member_function_argument_count(), std::enable_if_t with_promise = false> -void send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { - using ActorT = typename std::decay_t::ActorT; - static_assert(std::is_base_of::value, "unsafe send_closure"); - - ActorIdT id = std::forward(actor_id); - detail::send_closure_later(id.as_actor_ref(), function, std::forward(args)...); +void send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { + internal::send_closure_dispatch(std::forward(actor_id), function, std::forward(args)...); } template , size_t argument_count = member_function_argument_count(), std::enable_if_t with_promise = true> -void send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { - using ActorT = typename std::decay_t::ActorT; - static_assert(std::is_base_of::value, "unsafe send_closure"); - - ActorIdT id = std::forward(actor_id); - detail::send_closure_with_promise_later(id.as_actor_ref(), - call_n_arguments( - [&function](auto &&... nargs) { - return create_delayed_closure(function, - std::forward(nargs)...); - }, - std::forward(args)...), - get_last_argument(std::forward(args)...)); +void send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { + internal::send_closure_dispatch(std::forward(actor_id), function, std::forward(args)...); } -#endif - template , size_t argument_count = member_function_argument_count(), std::enable_if_t with_promise = false> -auto future_send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { +auto future_send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { using R = ::td::detail::get_ret_t>; auto pf = make_promise_future(); send_closure(std::forward(actor_id), std::move(function), std::forward(args)..., @@ -115,7 +116,7 @@ template , size_t argument_count = member_function_argument_count(), std::enable_if_t with_promise = true> -Future future_send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { +Future future_send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { auto pf = make_promise_future(); send_closure(std::forward(actor_id), std::move(function), std::forward(args)..., std::move(pf.first)); @@ -123,7 +124,7 @@ Future future_send_closure(ActorIdT &&actor_id, FunctionT function, ArgsT &&. } template -bool send_closure_bool(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { +bool send_closure_bool(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { send_closure(std::forward(actor_id), function, std::forward(args)...); return true; } @@ -131,45 +132,30 @@ bool send_closure_bool(ActorIdT &&actor_id, FunctionT function, ArgsT &&... 
args template , size_t argument_count = member_function_argument_count(), std::enable_if_t with_promise = false> -void send_closure_later(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { - using ActorT = typename std::decay_t::ActorT; - static_assert(std::is_base_of::value, "unsafe send_closure"); - - ActorIdT id = std::forward(actor_id); - detail::send_closure_later(id.as_actor_ref(), function, std::forward(args)...); +void send_closure_later(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { + internal::send_closure_dispatch(std::forward(actor_id), function, std::forward(args)...); } template , size_t argument_count = member_function_argument_count(), std::enable_if_t with_promise = true> -void send_closure_later(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { - using ActorT = typename std::decay_t::ActorT; - static_assert(std::is_base_of::value, "unsafe send_closure"); - - ActorIdT id = std::forward(actor_id); - detail::send_closure_with_promise_later(id.as_actor_ref(), - call_n_arguments( - [&function](auto &&... nargs) { - return create_delayed_closure(function, - std::forward(nargs)...); - }, - std::forward(args)...), - get_last_argument(std::forward(args)...)); +void send_closure_later(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { + internal::send_closure_dispatch(std::forward(actor_id), function, std::forward(args)...); } template -bool send_closure_later_bool(ActorIdT &&actor_id, FunctionT function, ArgsT &&... args) { +bool send_closure_later_bool(ActorIdT &&actor_id, FunctionT function, ArgsT &&...args) { send_closure_later(std::forward(actor_id), function, std::forward(args)...); return true; } template -void send_lambda(ActorIdT &&actor_id, ArgsT &&... args) { +void send_lambda(ActorIdT &&actor_id, ArgsT &&...args) { ActorIdT id = std::forward(actor_id); detail::send_lambda(id.as_actor_ref(), std::forward(args)...); } template -void send_lambda_later(ActorIdT &&actor_id, ArgsT &&... args) { +void send_lambda_later(ActorIdT &&actor_id, ArgsT &&...args) { ActorIdT id = std::forward(actor_id); detail::send_lambda_later(id.as_actor_ref(), std::forward(args)...); } @@ -188,14 +174,14 @@ void send_signals_later(ActorIdT &&actor_id, ActorSignals signals) { class SendClosure { public: template - void operator()(ArgsT &&... args) const { + void operator()(ArgsT &&...args) const { td::actor::send_closure(std::forward(args)...); } }; template template -auto Promise::send_closure(ArgsT &&... args) { +auto Promise::send_closure(ArgsT &&...args) { return [promise = std::move(*this), t = std::make_tuple(std::forward(args)...)](auto &&r_res) mutable { TRY_RESULT_PROMISE(promise, res, std::move(r_res)); td::call_tuple(SendClosure(), std::tuple_cat(std::move(t), std::make_tuple(std::move(res), std::move(promise)))); @@ -203,7 +189,7 @@ auto Promise::send_closure(ArgsT &&... args) { } template -auto promise_send_closure(ArgsT &&... 
args) { +auto promise_send_closure(ArgsT &&...args) { return [t = std::make_tuple(std::forward(args)...)](auto &&res) mutable { td::call_tuple(SendClosure(), std::tuple_cat(std::move(t), std::make_tuple(std::move(res)))); }; diff --git a/tdactor/td/actor/common.h b/tdactor/td/actor/common.h index 222db7436..d81d67645 100644 --- a/tdactor/td/actor/common.h +++ b/tdactor/td/actor/common.h @@ -301,6 +301,12 @@ inline void send_message_later(ActorRef actor_ref, core::ActorMessage message) { send_message_later(actor_ref.actor_info, std::move(message)); } +template +auto run_on_current_actor(ClosureT &&closure) { + using ActorType = typename std::remove_cvref_t::ActorType; + return closure.run(¤t_actor()); +} + template void send_immediate(ActorRef actor_ref, ExecuteF &&execute, ToMessageF &&to_message) { auto scheduler_context_ptr = core::SchedulerContext::get(); @@ -320,7 +326,7 @@ void send_immediate(ActorRef actor_ref, ExecuteF &&execute, ToMessageF &&to_mess } template -void send_lambda(ActorRef actor_ref, F &&lambda) { +void send_lambda_immediate(ActorRef actor_ref, F &&lambda) { send_immediate(actor_ref, lambda, [&lambda]() mutable { return ActorMessageCreator::lambda(std::move(lambda)); }); } template @@ -328,14 +334,18 @@ void send_lambda_later(ActorRef actor_ref, F &&lambda) { send_message_later(actor_ref, ActorMessageCreator::lambda(std::move(lambda))); } +template +void send_lambda(ActorRef actor_ref, F &&lambda) { + send_lambda_immediate(actor_ref, std::forward(lambda)); +} + template void send_closure_impl(ActorRef actor_ref, ClosureT &&closure) { - using ActorType = typename ClosureT::ActorType; send_immediate( - actor_ref, [&closure]() mutable { closure.run(¤t_actor()); }, + actor_ref, [&closure]() mutable { run_on_current_actor(closure); }, [&closure]() mutable { return ActorMessageCreator::lambda( - [closure = to_delayed_closure(std::move(closure))]() mutable { closure.run(¤t_actor()); }); + [closure = to_delayed_closure(std::move(closure))]() mutable { run_on_current_actor(closure); }); }); } @@ -346,38 +356,77 @@ void send_closure(ActorRef actor_ref, ArgsT &&...args) { template void send_closure_later_impl(ActorRef actor_ref, ClosureT &&closure) { - using ActorType = typename ClosureT::ActorType; - send_message_later(actor_ref, - ActorMessageCreator::lambda([closure = to_delayed_closure(std::move(closure))]() mutable { - closure.run(¤t_actor()); - })); + send_message_later( + actor_ref, ActorMessageCreator::lambda( + [closure = to_delayed_closure(std::move(closure))]() mutable { run_on_current_actor(closure); })); } +// Helper to unwrap Result to T, otherwise keep the type as is +template +struct unwrap_result { + using type = T; +}; +template +struct unwrap_result> { + using type = T; +}; +template +using unwrap_result_t = typename unwrap_result::type; + template void send_closure_with_promise(ActorRef actor_ref, ClosureT &&closure, PromiseT &&promise) { - using ActorType = typename ClosureT::ActorType; - using ResultType = decltype(closure.run(¤t_actor())); - auto &&promise_i = promise_interface(std::forward(promise)); - send_immediate( - actor_ref, [&closure, &promise = promise_i]() mutable { promise(closure.run(¤t_actor())); }, - [&closure, &promise = promise_i]() mutable { - return ActorMessageCreator::lambda( - [closure = to_delayed_closure(std::move(closure)), promise = std::move(promise)]() mutable { - promise(closure.run(¤t_actor())); - }); - }); + using RawResultType = decltype(run_on_current_actor(closure)); + if constexpr (std::is_void_v) { + // Adapt void to 
td::Unit + auto &&promise_i = promise_interface(std::forward(promise)); + send_immediate( + actor_ref, + [&closure, &promise = promise_i]() mutable { + run_on_current_actor(closure); + promise(td::Unit()); + }, + [&closure, &promise = promise_i]() mutable { + return ActorMessageCreator::lambda( + [closure = to_delayed_closure(std::move(closure)), promise = std::move(promise)]() mutable { + run_on_current_actor(closure); + promise(td::Unit()); + }); + }); + } else { + using ResultType = unwrap_result_t; + auto &&promise_i = promise_interface(std::forward(promise)); + send_immediate( + actor_ref, + [&closure, &promise = promise_i]() mutable { connect(std::move(promise), run_on_current_actor(closure)); }, + [&closure, &promise = promise_i]() mutable { + return ActorMessageCreator::lambda( + [closure = to_delayed_closure(std::move(closure)), promise = std::move(promise)]() mutable { + connect(std::move(promise), run_on_current_actor(closure)); + }); + }); + } } template void send_closure_with_promise_later(ActorRef actor_ref, ClosureT &&closure, PromiseT &&promise) { - using ActorType = typename ClosureT::ActorType; - using ResultType = decltype(closure.run(¤t_actor())); - send_message_later( - actor_ref, - ActorMessageCreator::lambda([closure = to_delayed_closure(std::move(closure)), - promise = promise_interface(std::forward(promise))]() mutable { - promise(closure.run(¤t_actor())); - })); + using RawResultType = decltype(run_on_current_actor(closure)); + if constexpr (std::is_void_v) { + // Adapt void to td::Unit + send_message_later( + actor_ref, + ActorMessageCreator::lambda([closure = to_delayed_closure(std::move(closure)), + promise = promise_interface(std::forward(promise))]() mutable { + run_on_current_actor(closure); + promise(td::Unit()); + })); + } else { + using ResultType = unwrap_result_t; + send_message_later(actor_ref, ActorMessageCreator::lambda([closure = to_delayed_closure(std::move(closure)), + promise = promise_interface( + std::forward(promise))]() mutable { + connect(std::move(promise), run_on_current_actor(closure)); + })); + } } template diff --git a/tdactor/td/actor/core/ActorExecuteContext.h b/tdactor/td/actor/core/ActorExecuteContext.h index 37f7d856c..37d80b592 100644 --- a/tdactor/td/actor/core/ActorExecuteContext.h +++ b/tdactor/td/actor/core/ActorExecuteContext.h @@ -44,6 +44,9 @@ class ActorExecuteContext : public Context { CHECK(actor_); return *actor_; } + Actor *actor_ptr() const { + return actor_; + } bool has_flags() const { return flags_ != 0; } diff --git a/tdactor/td/actor/core/CpuWorker.cpp b/tdactor/td/actor/core/CpuWorker.cpp index d78660c6f..b8cf7ce41 100644 --- a/tdactor/td/actor/core/CpuWorker.cpp +++ b/tdactor/td/actor/core/CpuWorker.cpp @@ -22,6 +22,7 @@ #include "td/actor/core/SchedulerContext.h" #include "td/actor/core/Scheduler.h" // FIXME: afer LocalQueue is in a separate file +#include namespace td { namespace actor { @@ -34,55 +35,54 @@ void CpuWorker::run() { waiter_.init_slot(slot, thread_id); auto &debug = dispatcher.get_debug(); while (true) { - SchedulerMessage message; - if (try_pop(message, thread_id)) { + SchedulerToken token = nullptr; + if (try_pop(token, thread_id)) { waiter_.stop_wait(slot); - if (!message) { + if (!token) { return; } - auto lock = debug.start(message->get_name()); - ActorExecutor executor(*message, dispatcher, ActorExecutor::Options().with_from_queue()); + auto encoded = reinterpret_cast(token); + if ((encoded & 1u) == 0u) { + // Regular actor message + auto raw_message = reinterpret_cast(token); + 
SchedulerMessage message(SchedulerMessage::acquire_t{}, raw_message); + auto lock = debug.start(message->get_name()); + ActorExecutor executor(*message, dispatcher, ActorExecutor::Options().with_from_queue()); + } else { + // Coroutine continuation + auto h = std::coroutine_handle<>::from_address(reinterpret_cast(encoded & ~uintptr_t(1))); + auto lock = debug.start("coro"); + h.resume(); + } } else { waiter_.wait(slot); } } } -bool CpuWorker::try_pop_local(SchedulerMessage &message) { - SchedulerMessage::Raw *raw_message; - if (local_queues_[id_].try_pop(raw_message)) { - message = SchedulerMessage(SchedulerMessage::acquire_t{}, raw_message); - return true; - } - return false; +bool CpuWorker::try_pop_local(SchedulerToken &token) { + return local_queues_[id_].try_pop(token); } -bool CpuWorker::try_pop_global(SchedulerMessage &message, size_t thread_id) { - SchedulerMessage::Raw *raw_message; - if (queue_.try_pop(raw_message, thread_id)) { - message = SchedulerMessage(SchedulerMessage::acquire_t{}, raw_message); - return true; - } - return false; +bool CpuWorker::try_pop_global(SchedulerToken &token, size_t thread_id) { + return queue_.try_pop(token, thread_id); } -bool CpuWorker::try_pop(SchedulerMessage &message, size_t thread_id) { +bool CpuWorker::try_pop(SchedulerToken &token, size_t thread_id) { if (++cnt_ == 51) { cnt_ = 0; - if (try_pop_global(message, thread_id) || try_pop_local(message)) { + if (try_pop_global(token, thread_id) || try_pop_local(token)) { return true; } } else { - if (try_pop_local(message) || try_pop_global(message, thread_id)) { + if (try_pop_local(token) || try_pop_global(token, thread_id)) { return true; } } for (size_t i = 1; i < local_queues_.size(); i++) { size_t pos = (i + id_) % local_queues_.size(); - SchedulerMessage::Raw *raw_message; - if (local_queues_[id_].steal(raw_message, local_queues_[pos])) { - message = SchedulerMessage(SchedulerMessage::acquire_t{}, raw_message); + if (local_queues_[id_].steal(token, local_queues_[pos])) { return true; } } diff --git a/tdactor/td/actor/core/CpuWorker.h b/tdactor/td/actor/core/CpuWorker.h index d9f32513b..dd520241a 100644 --- a/tdactor/td/actor/core/CpuWorker.h +++ b/tdactor/td/actor/core/CpuWorker.h @@ -19,6 +19,7 @@ #pragma once #include "td/actor/core/SchedulerMessage.h" +#include "td/actor/core/SchedulerContext.h" #include "td/utils/MpmcQueue.h" #include "td/utils/MpmcWaiter.h" @@ -31,23 +32,22 @@ template struct LocalQueue; class CpuWorker { public: - CpuWorker(MpmcQueue &queue, MpmcWaiter &waiter, size_t id, - MutableSpan> local_queues) + CpuWorker(MpmcQueue &queue, MpmcWaiter &waiter, size_t id, + MutableSpan> local_queues) : queue_(queue), waiter_(waiter), id_(id), local_queues_(local_queues) { } void run(); private: - MpmcQueue &queue_; + MpmcQueue &queue_; MpmcWaiter &waiter_; size_t id_; - MutableSpan> local_queues_; + MutableSpan> local_queues_; size_t cnt_{0}; - bool try_pop(SchedulerMessage &message, size_t thread_id); - - bool try_pop_local(SchedulerMessage &message); - bool try_pop_global(SchedulerMessage &message, size_t thread_id); + bool try_pop(SchedulerToken &token, size_t thread_id); + bool try_pop_local(SchedulerToken &token); + bool try_pop_global(SchedulerToken &token, size_t thread_id); }; } // namespace core } // namespace actor diff --git a/tdactor/td/actor/core/Scheduler.cpp b/tdactor/td/actor/core/Scheduler.cpp index 40f1e72cd..dfb570638 100644 --- a/tdactor/td/actor/core/Scheduler.cpp +++ b/tdactor/td/actor/core/Scheduler.cpp @@ -21,6 +21,8 @@ #include "td/actor/core/CpuWorker.h" 
#include "td/actor/core/IoWorker.h" +#include + namespace td { namespace actor { namespace core { @@ -43,10 +45,10 @@ Scheduler::Scheduler(std::shared_ptr scheduler_group_info, S info_->id = id; if (cpu_threads_count != 0) { info_->cpu_threads_count = cpu_threads_count; - info_->cpu_queue = std::make_unique>(1024, max_thread_count()); + info_->cpu_queue = std::make_unique>(1024, max_thread_count()); info_->cpu_queue_waiter = std::make_unique(); - info_->cpu_local_queue = std::vector>(cpu_threads_count); + info_->cpu_local_queue = std::vector>(cpu_threads_count); } info_->io_queue = std::make_unique>(); info_->io_queue->init(); @@ -161,22 +163,38 @@ void Scheduler::ContextImpl::add_to_queue(ActorInfoPtr actor_info_ptr, Scheduler if (need_poll || !info.cpu_queue) { info.io_queue->writer_put(std::move(actor_info_ptr)); } else { + auto token = static_cast(actor_info_ptr.release()); if (scheduler_id == get_scheduler_id() && cpu_worker_id_.is_valid()) { - // may push local - CHECK(actor_info_ptr); - auto raw = actor_info_ptr.release(); auto should_notify = info.cpu_local_queue[cpu_worker_id_.value()].push( - raw, [&](auto value) { info.cpu_queue->push(value, get_thread_id()); }); + token, [&](auto value) { info.cpu_queue->push(value, get_thread_id()); }); if (should_notify) { info.cpu_queue_waiter->notify(); } return; } - info.cpu_queue->push(actor_info_ptr.release(), get_thread_id()); + info.cpu_queue->push(token, get_thread_id()); info.cpu_queue_waiter->notify(); } } +void Scheduler::ContextImpl::add_token_to_cpu_queue(SchedulerToken token, SchedulerId scheduler_id) { + if (!scheduler_id.is_valid()) { + scheduler_id = get_scheduler_id(); + } + auto &info = scheduler_group()->schedulers.at(scheduler_id.value()); + if (scheduler_id == get_scheduler_id() && cpu_worker_id_.is_valid()) { + auto should_notify = info.cpu_local_queue[cpu_worker_id_.value()].push(token, [&](auto value) { + info.cpu_queue->push(value, get_thread_id()); + }); + if (should_notify) { + info.cpu_queue_waiter->notify(); + } + return; + } + info.cpu_queue->push(token, get_thread_id()); + info.cpu_queue_waiter->notify(); +} + ActorInfoCreator &Scheduler::ContextImpl::get_actor_info_creator() { return *creator_; } @@ -245,7 +263,7 @@ void Scheduler::ContextImpl::stop() { for (auto &scheduler_info : group.schedulers) { scheduler_info.io_queue->writer_put({}); for (size_t i = 0; i < scheduler_info.cpu_threads_count; i++) { - scheduler_info.cpu_queue->push({}, get_thread_id()); + scheduler_info.cpu_queue->push(nullptr, get_thread_id()); scheduler_info.cpu_queue_waiter->notify(); } } @@ -294,11 +312,18 @@ void Scheduler::close_scheduler_group(SchedulerGroupInfo &group_info) { for (auto &q : scheduler_info.cpu_local_queue) { auto &cpu_queue = q; while (true) { - SchedulerMessage::Raw *raw_message; + SchedulerToken raw_message; if (!cpu_queue.try_pop(raw_message)) { break; } - SchedulerMessage(SchedulerMessage::acquire_t{}, raw_message); + auto encoded = reinterpret_cast(raw_message); + if ((encoded & 1u) == 0u) { + SchedulerMessage::Raw *raw = reinterpret_cast(raw_message); + SchedulerMessage(SchedulerMessage::acquire_t{}, raw); + } else { + auto h = std::coroutine_handle<>::from_address(reinterpret_cast(encoded & ~uintptr_t(1))); + h.destroy(); + } // message's destructor is called queues_are_empty = false; } @@ -306,11 +331,18 @@ void Scheduler::close_scheduler_group(SchedulerGroupInfo &group_info) { if (scheduler_info.cpu_queue) { auto &cpu_queue = *scheduler_info.cpu_queue; while (true) { - SchedulerMessage::Raw *raw_message; + 
SchedulerToken raw_message; if (!cpu_queue.try_pop(raw_message, get_thread_id())) { break; } - SchedulerMessage(SchedulerMessage::acquire_t{}, raw_message); + auto encoded = reinterpret_cast(raw_message); + if ((encoded & 1u) == 0u) { + SchedulerMessage::Raw *raw = reinterpret_cast(raw_message); + SchedulerMessage(SchedulerMessage::acquire_t{}, raw); + } else { + auto h = std::coroutine_handle<>::from_address(reinterpret_cast(encoded & ~uintptr_t(1))); + h.destroy(); + } // message's destructor is called queues_are_empty = false; } diff --git a/tdactor/td/actor/core/Scheduler.h b/tdactor/td/actor/core/Scheduler.h index e76b919e1..85458294f 100644 --- a/tdactor/td/actor/core/Scheduler.h +++ b/tdactor/td/actor/core/Scheduler.h @@ -55,10 +55,8 @@ #include #include -#include #include #include -#include #include namespace td { @@ -151,10 +149,10 @@ struct LocalQueue { struct SchedulerInfo { SchedulerId id; // will be read by all workers is any thread - std::unique_ptr> cpu_queue; + std::unique_ptr> cpu_queue; std::unique_ptr cpu_queue_waiter; - std::vector> cpu_local_queue; + std::vector> cpu_local_queue; //std::vector> cpu_stealing_queue; // only scheduler itself may read from io_queue_ @@ -244,6 +242,7 @@ class Scheduler { SchedulerId get_scheduler_id() const override; void add_to_queue(ActorInfoPtr actor_info_ptr, SchedulerId scheduler_id, bool need_poll) override; + void add_token_to_cpu_queue(SchedulerToken token, SchedulerId scheduler_id) override; ActorInfoCreator &get_actor_info_creator() override; diff --git a/tdactor/td/actor/core/SchedulerContext.h b/tdactor/td/actor/core/SchedulerContext.h index 99b1922f5..55dc8fa45 100644 --- a/tdactor/td/actor/core/SchedulerContext.h +++ b/tdactor/td/actor/core/SchedulerContext.h @@ -28,6 +28,10 @@ namespace td { namespace actor { namespace core { + +// Token type for CPU queue - encodes either ActorInfo* (bit 0 = 0) or coroutine handle (bit 0 = 1) +using SchedulerToken = void*; + class SchedulerDispatcher { public: virtual ~SchedulerDispatcher() = default; @@ -35,6 +39,7 @@ class SchedulerDispatcher { virtual SchedulerId get_scheduler_id() const = 0; virtual void add_to_queue(ActorInfoPtr actor_info_ptr, SchedulerId scheduler_id, bool need_poll) = 0; virtual void set_alarm_timestamp(const ActorInfoPtr &actor_info_ptr) = 0; + virtual void add_token_to_cpu_queue(SchedulerToken token, SchedulerId scheduler_id) = 0; }; struct Debug; diff --git a/tdactor/td/actor/coro.h b/tdactor/td/actor/coro.h new file mode 100644 index 000000000..3013b5734 --- /dev/null +++ b/tdactor/td/actor/coro.h @@ -0,0 +1,11 @@ +#pragma once + +#include "td/actor/coro_types.h" +#include "td/actor/coro_executor.h" +#include "td/actor/coro_awaitables.h" +#include "td/actor/coro_task.h" +#include "td/actor/coro_utils.h" + +namespace td::actor { + +} // namespace td::actor \ No newline at end of file diff --git a/tdactor/td/actor/coro_awaitables.h b/tdactor/td/actor/coro_awaitables.h new file mode 100644 index 000000000..2296d73ea --- /dev/null +++ b/tdactor/td/actor/coro_awaitables.h @@ -0,0 +1,234 @@ +#pragma once + +#include "td/actor/coro_types.h" +#include "td/actor/coro_executor.h" +#include "td/utils/Status.h" + +#include +#include +#include +#include + +namespace td::actor { + +namespace detail { + +template +struct WrappedCoroutine { + struct promise_type { + Body* body{}; + std::coroutine_handle outer{}; + + promise_type(Body* b, std::coroutine_handle o) noexcept : body(b), outer(o) { + } + + WrappedCoroutine get_return_object() noexcept { + return 
WrappedCoroutine{std::coroutine_handle::from_promise(*this)}; + } + std::suspend_always initial_suspend() noexcept { + return {}; + } + auto final_suspend() noexcept { + struct A { + promise_type* p; + bool await_ready() noexcept { + return false; + } + std::coroutine_handle<> await_suspend(std::coroutine_handle self) noexcept { + auto next = p->body->route_resume(p->outer); + self.destroy(); + return next; + } + void await_resume() noexcept { + } + }; + return A{this}; + } + void return_void() noexcept { + } + void unhandled_exception() noexcept { + std::terminate(); + } + }; + + using handle = std::coroutine_handle; + handle h{}; + explicit WrappedCoroutine(handle hh) : h(hh) { + } + ~WrappedCoroutine() { + if (h) + h.destroy(); + } + WrappedCoroutine& operator=(WrappedCoroutine&& o) = delete; + WrappedCoroutine(const WrappedCoroutine&) = delete; + WrappedCoroutine& operator=(const WrappedCoroutine&) = delete; +}; + +template +[[nodiscard]] WrappedCoroutine make_wrapped_coroutine( + BodyT* b, std::coroutine_handle o) noexcept { + co_return; +} + +template +[[nodiscard]] std::coroutine_handle<> wrap_coroutine(BodyT* body, std::coroutine_handle outer) noexcept { + auto tmp = make_wrapped_coroutine(body, outer); + return std::exchange(tmp.h, {}); +} + +template +inline constexpr bool has_peek = requires(Aw& a) { a.await_resume_peek(); }; + +template +struct TaskUnwrapAwaiter { + using A = std::remove_cvref_t; + using Res = decltype(std::declval().await_resume()); + using Ok = decltype(std::declval().move_as_ok()); + + [[no_unique_address]] A aw; + + using Cache = std::conditional_t, td::Unit, std::optional>; + [[no_unique_address]] Cache ok_; + + bool await_ready() noexcept { + // If aw.await_ready() is true and not error, we may continue execution right away + if constexpr (has_peek) { + if (aw.await_ready() && !aw.await_resume_peek().is_error()) { + return true; + } + } + return false; + } + + template + std::coroutine_handle<> route_resume(std::coroutine_handle h) noexcept { + if constexpr (has_peek) { + const Res& r = aw.await_resume_peek(); + if (r.is_error()) { + auto rr = aw.await_resume(); + return h.promise().route_finish(std::move(rr).move_as_error()); + } + return h.promise().route_resume(); + } else { + Res r = aw.await_resume(); + if (r.is_error()) { + return h.promise().route_finish(std::move(r).move_as_error()); + } + ok_.emplace(std::move(r).move_as_ok()); + return h.promise().route_resume(); + } + } + + template + std::coroutine_handle<> await_suspend(std::coroutine_handle h) noexcept { + if constexpr (requires { aw.await_ready(); }) { + if (aw.await_ready()) { + return route_resume(h); + } + } + if (h.promise().is_immediate_execution_always_allowed()) { + return await_suspend_to(aw, h); + } + auto r_handle = wrap_coroutine(this, h); + return await_suspend_to(aw, r_handle); + } + + Ok await_resume() noexcept { + if constexpr (!has_peek) { + if (ok_) { + return std::move(*ok_); + } + } + Res r = aw.await_resume(); + return std::move(r).move_as_ok(); + } +}; + +template +struct TaskWrapAwaiter { + using A = std::remove_cvref_t; + using Res = decltype(std::declval().await_resume()); + + [[no_unique_address]] A aw; + + bool await_ready() noexcept { + return aw.await_ready(); + } + + template + std::coroutine_handle<> route_resume(std::coroutine_handle h) noexcept { + return h.promise().route_resume(); + } + + template + std::coroutine_handle<> await_suspend(std::coroutine_handle h) noexcept { + if constexpr (requires { aw.await_ready(); }) { + if (aw.await_ready()) { + return 
route_resume(h); + } + } + if (h.promise().is_immediate_execution_always_allowed()) { + return await_suspend_to(aw, h); + } + auto r_handle = wrap_coroutine(this, h); + return await_suspend_to(aw, r_handle); + } + + Res await_resume() noexcept { + return aw.await_resume(); + } +}; + +template +struct ResultUnwrapAwaiter { + using Res = std::remove_cvref_t; + using Ok = decltype(std::declval().move_as_ok()); + + Res result; + + bool await_ready() noexcept { + return result.is_ok(); + } + + template + std::coroutine_handle<> await_suspend(std::coroutine_handle h) noexcept { + if (result.is_error()) { + // This code is not generic and currently only used in Task and SharedTask implementations + h.promise().return_value(std::move(result).move_as_error()); + return h.promise().final_suspend().await_suspend(h); + } + return h; + } + + Ok await_resume() noexcept { + return std::move(result).move_as_ok(); + } +}; + +} // namespace detail + +// These helpers are used via await_transform to: +// 1) Ensure the awaiting coroutine resumes on the current task's scheduler. +// 2) Optionally unwrap td::Result. If it’s an error, propagate it to the parent +// as if using co_return error; so `co_await get_error();` is equivalent to `co_return get_error();`. +template +[[nodiscard]] auto unwrap_and_resume_on_current(Aw&& aw_) noexcept { + return detail::TaskUnwrapAwaiter{std::forward(aw_), {}}; +} + +template +[[nodiscard]] auto wrap_and_resume_on_current(Aw&& aw_) noexcept { + return detail::TaskWrapAwaiter{std::forward(aw_)}; +} + +template +[[nodiscard]] auto result_awaiter_unwrap(Result&& r) noexcept { + return detail::ResultUnwrapAwaiter>(std::move(r)); +} + +template +[[nodiscard]] auto result_awaiter_wrap(Result&& r) noexcept { + return detail::ReadyAwaitable>(std::move(r)); +} + +} // namespace td::actor diff --git a/tdactor/td/actor/coro_executor.h b/tdactor/td/actor/coro_executor.h new file mode 100644 index 000000000..6871529e1 --- /dev/null +++ b/tdactor/td/actor/coro_executor.h @@ -0,0 +1,237 @@ +#pragma once + +#include "td/actor/actor.h" +#include "td/actor/core/SchedulerContext.h" +#include "td/actor/core/SchedulerId.h" +#include "td/actor/coro_types.h" + +#include +#include +#include + +namespace td::actor { + +namespace detail { + +inline ActorId<> get_current_actor_id() noexcept { + auto context = core::ActorExecuteContext::get(); + if (context && context->actor_ptr()) { + return actor_id(context->actor_ptr()); + } + return td::actor::ActorId<>{}; +} + +template +class ActorMessageCoroutineSafe : public core::ActorMessageImpl { + public: + explicit ActorMessageCoroutineSafe(std::coroutine_handle

continuation) : continuation_(continuation) {
+  }
+  void run() override {
+    continuation_.resume();
+    continuation_ = {};
+  }
+  ~ActorMessageCoroutineSafe() override;
+
+ private:
+  std::coroutine_handle

continuation_;
+};
+
+// Executor
+//  - schedule: enqueue the continuation to run later (sometime, NOT right now)
+//  - execute_or_schedule: it is OK to start executing immediately
+//      - for AnyExecutor: just execute
+//      - for SchedulerExecutor: just execute
+//      - for ActorExecutor: try to execute immediately and schedule if we can't
+//  - resume_or_schedule: if we are already executing we may continue inline, otherwise schedule
+
+struct ActorExecutor {
+  td::actor::ActorId<> actor;
+  bool is_immediate_execution_allowed() const noexcept {
+    return actor == get_current_actor_id();
+  }
+  bool is_immediate_execution_always_allowed() const noexcept {
+    return false;
+  }
+  template
+  [[nodiscard]] std::coroutine_handle<> resume_or_schedule(std::coroutine_handle

cont) noexcept {
+    if (is_immediate_execution_allowed()) {
+      return cont;
+    }
+    schedule(std::move(cont));
+    return std::noop_coroutine();
+  }
+  template
+  static auto to_message(std::coroutine_handle

cont) noexcept {
+    return td::actor::core::ActorMessage{std::make_unique>(std::move(cont))};
+  }
+  template
+  [[nodiscard]] std::coroutine_handle<> execute_or_schedule(std::coroutine_handle

cont) noexcept {
+    if (is_immediate_execution_allowed()) {
+      return cont;
+    }
+    td::actor::detail::send_immediate(
+        actor.as_actor_ref(), [&] { cont.resume(); }, [&]() { return to_message(std::move(cont)); });
+    return std::noop_coroutine();
+  }
+  template
+  void schedule(std::coroutine_handle

cont) noexcept { + td::actor::detail::send_message_later(actor.as_actor_ref(), to_message(std::move(cont))); + } +}; + +struct SchedulerExecutor { + bool is_immediate_execution_allowed() const noexcept { + return get_current_actor_id().empty(); + } + bool is_immediate_execution_always_allowed() const noexcept { + return false; + } + [[nodiscard]] std::coroutine_handle<> resume_or_schedule(std::coroutine_handle<> cont) noexcept { + return execute_or_schedule(std::move(cont)); + } + [[nodiscard]] std::coroutine_handle<> execute_or_schedule(std::coroutine_handle<> cont) noexcept { + if (is_immediate_execution_allowed()) { + return cont; + } + schedule(std::move(cont)); + return std::noop_coroutine(); + } + void schedule(std::coroutine_handle<> cont) noexcept { + auto token = reinterpret_cast(encode_continuation(cont)); + auto ctx = td::actor::core::SchedulerContext::get(); + CHECK(ctx); + ctx->add_token_to_cpu_queue(token, td::actor::core::SchedulerId{}); + } +}; + +struct AnyExecutor { + bool is_immediate_execution_allowed() const noexcept { + return true; + } + bool is_immediate_execution_always_allowed() const noexcept { + return true; + } + [[nodiscard]] std::coroutine_handle<> resume_or_schedule(std::coroutine_handle<> cont) noexcept { + return execute_or_schedule(std::move(cont)); + } + [[nodiscard]] std::coroutine_handle<> execute_or_schedule(std::coroutine_handle<> cont) noexcept { + return cont; + } + void schedule(std::coroutine_handle<> cont) noexcept { + LOG(ERROR) << "Schedule on any executor"; + SchedulerExecutor{}.schedule(cont); + } +}; + +struct Executor { + std::variant executor_{SchedulerExecutor{}}; + + static Executor on_actor(td::actor::ActorId<> actor) noexcept { + return {ActorExecutor{std::move(actor)}}; + } + template + static Executor on_actor(const td::actor::ActorOwn& actor) noexcept { + return {ActorExecutor{actor.get()}}; + } + static Executor on_scheduler() noexcept { + return {SchedulerExecutor{}}; + } + static Executor on_any() noexcept { + return {AnyExecutor{}}; + } + static Executor on_current_actor() noexcept { + return on_actor(get_current_actor_id()); + } + static Executor on_default() noexcept { + auto current_actor_id = get_current_actor_id(); + return current_actor_id.empty() ? on_scheduler() : on_actor(std::move(current_actor_id)); + } + + bool is_immediate_execution_allowed() const noexcept { + return visit([](auto& v) { return v.is_immediate_execution_allowed(); }, executor_); + //return executor_.visit([&](auto& v) { return v.is_immediate_execution_allowed(); }); + } + bool is_immediate_execution_always_allowed() const noexcept { + return visit([](auto& v) { return v.is_immediate_execution_always_allowed(); }, executor_); + } + template + [[nodiscard]] std::coroutine_handle<> resume_or_schedule(std::coroutine_handle

cont) noexcept {
+    return visit([&](auto& v) { return v.resume_or_schedule(std::move(cont)); }, executor_);
+  }
+  template
+  [[nodiscard]] std::coroutine_handle<> execute_or_schedule(std::coroutine_handle

cont) noexcept {
+    return visit([&](auto& v) { return v.execute_or_schedule(std::move(cont)); }, executor_);
+  }
+
+  template
+  void schedule(std::coroutine_handle

cont) noexcept {
+    return visit([&](auto& v) { return v.schedule(std::move(cont)); }, executor_);
+  }
+};
+
+template
+ActorMessageCoroutineSafe

::~ActorMessageCoroutineSafe() {
+  if (continuation_) {
+    SchedulerExecutor{}.schedule(continuation_.promise().route_finish(td::Status::Error("Actor destroyed")));
+  }
+}
+
+struct ResumeOn {
+  Executor executor;
+
+  bool await_ready() noexcept {
+    return executor.is_immediate_execution_allowed();
+  }
+
+  template
+  std::coroutine_handle<> await_suspend(std::coroutine_handle

cont) noexcept {
+    return executor.resume_or_schedule(std::move(cont));
+  }
+
+  void await_resume() noexcept {
+  }
+};
+
+struct YieldOn {
+  Executor executor;
+  bool await_ready() noexcept {
+    return false;
+  }
+  template
+  std::coroutine_handle<> await_suspend(std::coroutine_handle

self) noexcept { + executor.schedule(std::move(self)); + return std::noop_coroutine(); + } + void await_resume() noexcept { + } +}; + +} // namespace detail + +using Executor = detail::Executor; + +[[nodiscard]] inline auto resume_on(Executor executor) noexcept { + return detail::ResumeOn{std::move(executor)}; +} + +[[nodiscard]] inline auto yield_on(Executor executor) noexcept { + return detail::YieldOn{std::move(executor)}; +} + +inline auto attach_to_actor(td::actor::ActorId<> actor_id) noexcept { + return detail::ResumeOn{Executor::on_actor(actor_id)}; +} + +inline auto detach_from_actor() noexcept { + return detail::ResumeOn{Executor::on_scheduler()}; +} + +inline auto become_lightweight() noexcept { + return detail::ResumeOn{Executor::on_any()}; +} + +inline Yield yield_on_current() noexcept { + return Yield{}; +} + +} // namespace td::actor diff --git a/tdactor/td/actor/coro_task.h b/tdactor/td/actor/coro_task.h new file mode 100644 index 000000000..accf714ca --- /dev/null +++ b/tdactor/td/actor/coro_task.h @@ -0,0 +1,469 @@ +#pragma once + +#include "td/actor/coro_types.h" +#include "td/actor/coro_executor.h" +#include "td/actor/coro_awaitables.h" +#include "td/actor/PromiseFuture.h" +#include "td/utils/Status.h" + +#include +#include +#include +#include +#include + +namespace td::actor { + +namespace detail { + +struct TaskStateManagerData { + Executor executor{Executor::on_default()}; + std::atomic flags{0}; + std::coroutine_handle<> continuation{}; + + enum Flags : uint8_t { + READY_FLAG = 1, + STARTED_FLAG = 2, + SUSPEND_FLAG = 4, + DETACH_FLAG = 8, + }; + + uint8_t set_flag(uint8_t new_flag) noexcept { + auto old_flags = flags.fetch_or(new_flag, std::memory_order_acq_rel); + CHECK((old_flags & new_flag) == 0); + return old_flags; + } + + [[nodiscard]] std::coroutine_handle<> on_ready(std::coroutine_handle<> self_handle) { + auto old_flags = set_flag(READY_FLAG); + + if (!(old_flags & STARTED_FLAG)) { + return continuation; + } + + std::coroutine_handle<> next = std::noop_coroutine(); + if (old_flags & SUSPEND_FLAG) { + next = continuation; + } + if (old_flags & DETACH_FLAG) { + self_handle.destroy(); + } + return next; + } + + void set_executor(Executor new_executor) noexcept { + executor = std::move(new_executor); + } +}; + +template +struct TaskStateManager { + using Data = TaskStateManagerData; + using Coro = std::coroutine_handle<>; + using Self = std::coroutine_handle

; Data* data;
+
+  void set_executor(Executor executor) noexcept {
+    data->executor = std::move(executor);
+  }
+
+  void start(Self self) {
+    set_is_started();
+    data->executor.schedule(self);
+  }
+  void start_immediate(Self self) {
+    set_is_started();
+    data->executor.execute_or_schedule(self).resume();
+  }
+  void start_external() {
+    set_is_started();
+  }
+
+  bool is_ready() const noexcept {
+    return false;
+  }
+
+  [[nodiscard]] std::coroutine_handle<> on_suspend_and_start(Self self, Coro continuation) {
+    data->continuation = continuation;
+    return data->executor.execute_or_schedule(self);
+  }
+
+  void on_detach(Self self) {
+    self.destroy();
+  }
+
+ private:
+  void set_is_started() {
+    data->flags.fetch_or(Data::STARTED_FLAG, std::memory_order_relaxed);
+  }
+};
+
+template
+struct StartedTaskStateManager {
+  using Data = TaskStateManagerData;
+  using Coro = std::coroutine_handle<>;
+  using Self = std::coroutine_handle

; + Data* data; + + bool is_ready() const noexcept { + return data->flags.load(std::memory_order_acquire) & Data::READY_FLAG; + } + + [[nodiscard]] std::coroutine_handle<> on_suspend(Coro new_continuation) { + CHECK(!data->continuation); + data->continuation = new_continuation; + auto old_flags = data->set_flag(Data::SUSPEND_FLAG); + return (old_flags & Data::READY_FLAG) ? new_continuation : std::noop_coroutine(); + } + + void on_detach(Self self) { + auto old_flags = data->set_flag(Data::DETACH_FLAG); + if (old_flags & Data::READY_FLAG) { + self.destroy(); + } + } +}; + +} // namespace detail + +struct promise_common { + detail::TaskStateManagerData state_manager_data; +}; + +template +struct promise_value : promise_common { + [[no_unique_address]] ResultT result; + + template + void return_value(TT&& v) noexcept { + result = std::forward(v); + } + + struct ExternalResult {}; + void return_value(ExternalResult&&) noexcept { + } + + void unhandled_exception() noexcept { + result = td::Status::Error("unhandled exception in coroutine"); + } + + ResultT extract_result() noexcept { + return std::move(result); + } +}; + +template +struct Task; + +template +struct StartedTask; + +template +struct promise_type : promise_value> { + static_assert(!std::is_void_v, "Task is not supported; use Task instead"); + using Handle = std::coroutine_handle; + + auto self() noexcept { + return Handle::from_promise(*this); + } + + template + void external_return_value(TT&& v) noexcept + requires(!std::is_void_v) + { + route_finish(std::forward(v)).resume(); + } + + Task get_return_object() noexcept { + return Task{Handle::from_promise(*this)}; + } + + std::suspend_always initial_suspend() noexcept { + return {}; + } + + auto final_suspend() noexcept { + struct Final { + bool await_ready() noexcept { + return false; + } + std::coroutine_handle<> await_suspend(Handle self) noexcept { + return self.promise().state_manager_data.on_ready(self); + } + void await_resume() noexcept { + } + }; + return Final{}; + } + + auto await_transform(detail::YieldOn y) { + this->state_manager_data.set_executor(y.executor); + return y; + } + auto await_transform(detail::ResumeOn y) { + this->state_manager_data.set_executor(y.executor); + return y; + } + + template + auto await_transform(SkipAwaitTransform wrapped_aw) noexcept { + return std::move(wrapped_aw.awaitable); + } + + auto await_transform(Yield) noexcept { + return yield_on(this->state_manager_data.executor); + } + + auto await_transform(std::suspend_always) noexcept { + return std::suspend_always{}; + } + + auto await_transform(std::suspend_never) noexcept { + return std::suspend_always{}; + } + + template + auto await_transform(td::ResultUnwrap wrapped) noexcept { + return await_transform(std::move(wrapped.result)); + } + template + auto await_transform(td::ResultWrap wrapped) noexcept { + return await_transform(Wrapped>{std::move(wrapped.result)}); + } + + template + auto await_transform(Wrapped> wrapped) noexcept { + return result_awaiter_wrap(std::move(wrapped.value)); + } + + template + auto await_transform(td::Result&& result) noexcept { + return result_awaiter_unwrap(std::move(result)); + } + + template + auto await_transform(Task&& task) noexcept { + return unwrap_and_resume_on_current(std::move(task).start_immediate()); + } + template + auto await_transform(StartedTask&& task) noexcept { + return unwrap_and_resume_on_current(std::move(task)); + } + + template + auto await_transform(Wrapped>&& wrapped) noexcept { + return 
wrap_and_resume_on_current(std::move(wrapped.value)); + } + template + auto await_transform(Wrapped>&& wrapped) noexcept { + return wrap_and_resume_on_current(std::move(wrapped.value)); + } + + template + auto await_transform(Aw&& aw) noexcept { + return wrap_and_resume_on_current(std::forward(aw)); + } + + // API used by TaskWrapAwaiter and TaskUnwrapAwaiter + bool is_immediate_execution_always_allowed() const noexcept { + return this->state_manager_data.executor.is_immediate_execution_always_allowed(); + } + + std::coroutine_handle<> route_resume() { + return this->state_manager_data.executor.resume_or_schedule(self()); + } + + std::coroutine_handle<> route_finish(td::Result r) { + this->return_value(std::move(r)); + return final_suspend().await_suspend(self()); + } +}; + +template +struct [[nodiscard]] Task { + using value_type = T; + + using promise_type = promise_type; + using Handle = std::coroutine_handle; + Handle h{}; + + Task() = default; + explicit Task(Handle hh) : h(hh) { + } + Task(Task&& o) noexcept : h(std::exchange(o.h, {})) { + } + Task& operator=(Task&& o) = delete; + Task(const Task&) = delete; + Task& operator=(const Task&) = delete; + + auto sm() { + return detail::TaskStateManager{&h.promise().state_manager_data}; + } + + ~Task() noexcept { + detach(); + } + void detach() { + if (!h) { + return; + } + sm().on_detach(h); + h = {}; + } + + auto start() && { + sm().start(h); + return StartedTask{std::exchange(h, {})}; + } + auto start_immediate() && { + sm().start_immediate(h); + return StartedTask{std::exchange(h, {})}; + } + auto start_external() && { + sm().start_external(); + return StartedTask{std::exchange(h, {})}; + } + void set_executor(Executor new_executor) { + CHECK(h); + sm().set_executor(std::move(new_executor)); + } + + constexpr bool await_ready() noexcept { + return sm().is_ready(); + } + + std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) noexcept { + CHECK(h); + return sm().on_suspend_and_start(h, continuation); + } + td::Result await_resume() noexcept { + CHECK(h); + return h.promise().extract_result(); + } + + const td::Result& await_resume_peek() const noexcept { + CHECK(h); + return h.promise().result; + } + + auto wrap() && { + return Wrapped{std::move(*this)}; + } +}; + +template +struct [[nodiscard]] StartedTask { + using value_type = T; + + using promise_type = promise_type; + using Handle = std::coroutine_handle; + Handle h{}; + + auto sm() { + CHECK(h); + return detail::StartedTaskStateManager{&h.promise().state_manager_data}; + } + StartedTask() = default; + explicit StartedTask(Handle hh) : h(hh) { + CHECK(h); + } + StartedTask(StartedTask&& o) noexcept : h(std::exchange(o.h, {})) { + } + StartedTask& operator=(StartedTask&& o) = delete; + StartedTask(const StartedTask&) = delete; + StartedTask& operator=(const StartedTask&) = delete; + + ~StartedTask() noexcept { + detach(); + } + void detach() { + if (!h) { + return; + } + sm().on_detach(h); + h = {}; + } + bool await_ready() noexcept { + return sm().is_ready(); + } + + std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation) noexcept { + return sm().on_suspend(continuation); + } + td::Result await_resume() noexcept { + return h.promise().extract_result(); + } + + const td::Result& await_resume_peek() const noexcept { + CHECK(h); + return h.promise().result; + } + + auto wrap() && { + return Wrapped{std::move(*this)}; + } + + template + auto then(F&& f) && { + using Self = StartedTask; + using FDecayed = std::decay_t; + using Awaitable = 
decltype(detail::make_awaitable(std::declval()(std::declval()))); + using Ret = decltype(std::declval().await_resume()); + using U = std::conditional_t, decltype(std::declval().move_as_ok()), Ret>; + return [](Self task, FDecayed fn) mutable -> Task { + co_await become_lightweight(); + auto value = co_await std::move(task); + co_return co_await detail::make_awaitable(fn(std::move(value))); + }(std::move(*this), std::forward(f)); + } + + struct ExternalPromise : public PromiseInterface { + ExternalPromise() = default; + explicit ExternalPromise(promise_type* p) : promise(p) { + } + void set_value(T&& value) override { + promise.release()->external_return_value(std::move(value)); + } + void set_error(Status&& error) override { + promise.release()->external_return_value(std::move(error)); + } + + operator bool() const { + return bool(promise); + } + + struct Deleter { + void operator()(promise_type* p) { + p->external_return_value(td::Status::Error("promise destroyed")); + } + }; + std::unique_ptr promise{}; + }; + + static std::pair make_bridge() { + auto task = []() -> Task { co_return typename promise_type::ExternalResult{}; }(); + task.set_executor(Executor::on_scheduler()); + auto promise = ExternalPromise(&task.h.promise()); + auto started_task = std::move(task).start_external(); + return std::make_pair(std::move(started_task), std::move(promise)); + } +}; + +template +void custom_connect(P&& p, StartedTask&& mt) noexcept { + if (mt.await_ready()) { + connect(std::move(p), mt.await_resume()); + return; + } + [](auto promise, auto mt) mutable -> detail::FireAndForget { + auto result = co_await mt; + connect(std::move(promise), std::move(result)); + }(std::forward
<P>(p), std::move(mt)); +} + +template <class P, class T> +void custom_connect(P&& p, Task<T>&& t) noexcept { + connect(std::forward
<P>
(p), std::move(t).start_immediate()); +} + +} // namespace td::actor diff --git a/tdactor/td/actor/coro_types.h b/tdactor/td/actor/coro_types.h new file mode 100644 index 000000000..ae65dd70c --- /dev/null +++ b/tdactor/td/actor/coro_types.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace td::actor { + + +template +concept TDResultLike = requires(R r) { + { r.is_error() } -> std::convertible_to; + r.move_as_ok(); + r.move_as_error(); +}; + +template +concept IsAwaitable = requires(T& t, std::coroutine_handle<> h) { + t.await_suspend(h); + t.await_resume(); +}; + +template +concept CoroTask = requires(T t) { + typename T::promise_type; + typename T::value_type; + { t.await_suspend(std::coroutine_handle<>{}) }; + { t.await_resume() }; +}; + +namespace detail { +template +inline std::coroutine_handle<> await_suspend_to(Aw& aw, std::coroutine_handle<> cont) noexcept { + if constexpr (std::is_same_v) { + aw.await_suspend(cont); + return std::noop_coroutine(); + } else if constexpr (std::is_same_v) { + if (!aw.await_suspend(cont)) { + return cont; + } + return std::noop_coroutine(); + } else { + return aw.await_suspend(cont); + } +} + +struct FireAndForget { + struct promise_type { + FireAndForget get_return_object() noexcept { + return {}; + } + std::suspend_never initial_suspend() noexcept { + return {}; + } + std::suspend_never final_suspend() noexcept { + return {}; + } + void return_void() noexcept { + } + void unhandled_exception() noexcept { + } + }; +}; + +template +struct ReadyAwaitable { + [[no_unique_address]] T value; + constexpr bool await_ready() const noexcept { + return true; + } + constexpr void await_suspend(std::coroutine_handle<>) const noexcept { + } + T await_resume() noexcept { + return std::move(value); + } +}; + +template +auto make_awaitable(X&& x) { + using DX = std::decay_t; + if constexpr (IsAwaitable) { + return std::forward(x); + } else { + return ReadyAwaitable{std::forward(x)}; + } +} + +inline uintptr_t encode_continuation(std::coroutine_handle<> h) noexcept { + auto p = h.address(); + auto v = reinterpret_cast(p); + return v | 1u; +} + +inline std::coroutine_handle<> decode_continuation(uintptr_t token) noexcept { + return std::coroutine_handle<>::from_address(reinterpret_cast(token & ~uintptr_t(1))); +} + +} // namespace detail + +template +struct [[nodiscard]] SkipAwaitTransform { + [[no_unique_address]] T awaitable; +}; + +template +struct Wrapped { + [[no_unique_address]] T value; +}; + +struct [[nodiscard]] Yield {}; + +} // namespace td::actor diff --git a/tdactor/td/actor/coro_utils.h b/tdactor/td/actor/coro_utils.h new file mode 100644 index 000000000..0abb72967 --- /dev/null +++ b/tdactor/td/actor/coro_utils.h @@ -0,0 +1,268 @@ +#pragma once + +#include "td/actor/coro_types.h" +#include "td/actor/coro_executor.h" +#include "td/actor/coro_task.h" +#include "td/actor/actor.h" +#include "td/utils/Status.h" +#include "td/utils/Slice.h" + +#include +#include +#include +#include + +namespace td::actor { + +namespace detail { + +template +struct memfn_meta; + +template +struct remove_memfn_const { + using type = T; +}; +template +struct remove_memfn_const { + using type = R (C::*)(Args...); +}; +template +using remove_memfn_const_t = remove_memfn_const::type; + +template +struct memfn_meta : memfn_meta> {}; +template +struct memfn_meta { + using cls = C; + using ret = R; +}; + +template +struct unwrap_promise { + using type = void; +}; +template +struct unwrap_promise> { + using type = T; +}; + +template +using 
last_t = std::tuple_element_t>; + +template +struct is_task : std::false_type {}; +template +struct is_task> : std::true_type {}; + +} // namespace detail + +template +td::Result> collect(std::vector>&& results) { + for (auto& result : results) { + if (result.is_error()) { + return result.move_as_error(); + } + } + std::vector values; + values.reserve(results.size()); + for (auto& result : results) { + values.push_back(result.move_as_ok()); + } + return values; +} + +template +td::Result> collect(std::tuple...> results) { + return std::apply( + [](auto&&... results) -> td::Result> { + td::Status error; + bool has_error = false; + + (void)((results.is_error() ? (has_error = true, error = results.move_as_error(), false) : true) && ...); + + if (has_error) { + return std::move(error); + } + + return std::tuple{std::move(results).move_as_ok()...}; + }, + std::move(results)); +} + +template +using await_result_t = + decltype(std::declval::promise_type>().await_transform(std::declval()).await_resume()); + +template 1), int> = 0> +auto all(Awaitables&&... awaitables) -> Task...>> { + co_await become_lightweight(); + co_return std::tuple...>{co_await std::forward(awaitables)...}; +} + +template 1), int> = 0> +auto all_wrap(Awaitables&&... awaitables) -> Task>...>> { + co_await become_lightweight(); + co_return std::tuple>...>{ + co_await Wrapped{std::forward(awaitables)}...}; +} + +template +Task>> all(std::vector tasks) { + co_await become_lightweight(); + std::vector> results; + results.reserve(tasks.size()); + // TODO: auto start + for (auto& task : tasks) { + results.push_back(co_await std::move(task)); + } + co_return results; +} + +template +Task>>> all_wrap(std::vector tasks) { + co_await become_lightweight(); + std::vector>> results; + // TODO: auto start + results.reserve(tasks.size()); + for (auto& task : tasks) { + results.push_back(co_await Wrapped{std::move(task)}); + } + co_return results; +} + +enum class UnifiedKind : uint8_t { None, Void, TaskReturn, PromiseArgument, ReturnValue }; + +template +struct unified_result; + +template +struct unified_result : unified_result> {}; + +template +struct unified_result (C::*)(Args...)> { + using type = T; + static constexpr UnifiedKind kind = UnifiedKind::ReturnValue; +}; + +template +struct unified_result (C::*)(Args...)> { + using type = T; + static constexpr UnifiedKind kind = UnifiedKind::TaskReturn; +}; + +template +struct unified_result { + using type = T; + static constexpr UnifiedKind kind = UnifiedKind::ReturnValue; +}; + +template +struct unified_result { + using Last = std::remove_cvref_t>; + static constexpr bool is_promise = std::is_same_v::type>>; + using type = std::conditional_t::type, void>; + static constexpr UnifiedKind kind = is_promise ? UnifiedKind::PromiseArgument : UnifiedKind::Void; +}; + +template +auto ask_impl(TargetId&& to, MemFn mf, Args&&... 
args) { + using Meta = unified_result; + using T = Meta::type; + using Ret = detail::memfn_meta::ret; + + using TT = std::conditional_t, td::Unit, T>; + + static_assert(Meta::kind == UnifiedKind::TaskReturn || Meta::kind == UnifiedKind::PromiseArgument || + Meta::kind == UnifiedKind::ReturnValue || Meta::kind == UnifiedKind::Void, + "ask: method must return Task or take td::Promise as last parameter"); + + if constexpr (Meta::kind == UnifiedKind::TaskReturn) { + return ask_new_impl(std::forward(to), mf, std::forward(args)...); + } + + auto [task, promise] = StartedTask::make_bridge(); + td::actor::internal::send_closure_dispatch(std::forward(to), mf, std::forward(args)..., + std::move(promise)); + return std::move(task); +} + +template +auto ask_new_impl(TargetId&& to, MemFn mf, Args&&... args) { + using Meta = unified_result; + using T = Meta::type; + static_assert(Meta::kind == UnifiedKind::TaskReturn, "ask: method must return Task"); + if constexpr (Later) { + auto task = [](auto closure) -> Task { + co_return co_await detail::run_on_current_actor(closure); + }(create_delayed_closure(mf, std::forward(args)...)); + task.set_executor(Executor::on_actor(to)); + return std::move(task).start(); + } else { + std::optional> o_task; + td::actor::detail::send_immediate( + to.as_actor_ref(), + [&] { + o_task.emplace(detail::run_on_current_actor(create_immediate_closure(mf, std::forward(args)...)) + .start_immediate()); + }, + [&]() { + auto task = [](auto closure) -> Task { + co_return co_await detail::run_on_current_actor(closure); + }(create_delayed_closure(mf, std::forward(args)...)); + task.set_executor(Executor::on_actor(to)); + o_task.emplace(std::move(task).start_external()); + return detail::ActorExecutor::to_message(o_task->h); + }); + return std::move(*o_task); + } +} + +template +auto ask_new(Args&&... args) { + return ask_new_impl(std::forward(args)...); +} + +template +auto ask_new_immediate(Args&&... args) { + return ask_new_impl(std::forward(args)...); +} + +template +auto ask(Args&&... args) { + return ask_impl(std::forward(args)...); +} + +template +auto ask_immediate(Args&&... args) { + return ask_impl(std::forward(args)...); +} + +template +auto ask_promise(Args&&... 
args) { + return ask(std::forward(args)...); +} + +template +auto spawn_actor(td::Slice name, TaskType task) { + using StartedTaskType = StartedTask; + using PromiseType = StartedTaskType::ExternalPromise; + auto [result_task, result_promise] = StartedTaskType::make_bridge(); + + struct TaskAwaiter : public td::actor::Actor { + TaskAwaiter(TaskType task, PromiseType promise) : task_(std::move(task)), promise_(std::move(promise)) { + } + + private: + TaskType task_; + PromiseType promise_; + void start_up() { + task_.set_executor(Executor::on_current_actor()); + connect(std::move(promise_), std::move(task_)); + } + }; + td::actor::create_actor(name, std::move(task), std::move(result_promise)).release(); + return std::move(result_task); +} + +} // namespace td::actor diff --git a/tdactor/test/CMakeLists.txt b/tdactor/test/CMakeLists.txt new file mode 100644 index 000000000..17a71612c --- /dev/null +++ b/tdactor/test/CMakeLists.txt @@ -0,0 +1,3 @@ +add_executable(test-coro test-coro.cpp) +target_link_libraries(test-coro tdactor tdutils) +target_include_directories(test-coro PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../) diff --git a/tdactor/test/actors_core.cpp b/tdactor/test/actors_core.cpp index 1c56d1e5b..43eb67dc4 100644 --- a/tdactor/test/actors_core.cpp +++ b/tdactor/test/actors_core.cpp @@ -279,6 +279,9 @@ TEST(Actor2, executor_simple) { void add_to_queue(ActorInfoPtr ptr, SchedulerId scheduler_id, bool need_poll) override { queue.push_back(std::move(ptr)); } + void add_token_to_cpu_queue(SchedulerToken token, SchedulerId scheduler_id) override { + UNREACHABLE(); + } void set_alarm_timestamp(const ActorInfoPtr &actor_info_ptr) override { UNREACHABLE(); } @@ -326,7 +329,9 @@ TEST(Actor2, executor_simple) { LOG_CHECK(sb.as_cslice() == "") << sb.as_cslice(); } CHECK(dispatcher.queue.size() == 1); - { ActorExecutor executor(*actor, dispatcher, ActorExecutor::Options().with_from_queue()); } + { + ActorExecutor executor(*actor, dispatcher, ActorExecutor::Options().with_from_queue()); + } CHECK(dispatcher.queue.size() == 1); dispatcher.queue.clear(); LOG_CHECK(sb.as_cslice() == "bigB") << sb.as_cslice(); @@ -1142,7 +1147,8 @@ TEST(Actor2, test_stats) { td::actor::create_actor("queue_worker").release(); } void alarm() override { - td::actor::send_closure(stats_, &ActorStats::prepare_stats, td::promise_send_closure(actor_id(this), &Master::on_stats)); + td::actor::send_closure(stats_, &ActorStats::prepare_stats, + td::promise_send_closure(actor_id(this), &Master::on_stats)); alarm_timestamp() = td::Timestamp::in(5); } void on_stats(td::Result r_stats) { @@ -1155,7 +1161,7 @@ TEST(Actor2, test_stats) { private: std::shared_ptr watcher_; td::actor::ActorOwn stats_; - int cnt_={2}; + int cnt_ = {2}; }; td::actor::create_actor("Master", watcher).release(); }); diff --git a/tdactor/test/test-coro.cpp b/tdactor/test/test-coro.cpp new file mode 100644 index 000000000..27508d44c --- /dev/null +++ b/tdactor/test/test-coro.cpp @@ -0,0 +1,877 @@ +#include "absl/strings/str_format.h" +#include "td/actor/coro.h" +#include "td/actor/actor.h" +#include "td/utils/Random.h" +#include "td/utils/tests.h" +#include "td/utils/port/sleep.h" + +#include +#include +#include +#include +#include +#include +#include + +using namespace td::actor; + +template +inline void expect_ok(const td::Result& r, const char* msg) { + LOG_CHECK(r.is_ok()) << msg; +} + +template +inline void expect_eq(const T& a, const U& b, const char* msg) { + LOG_CHECK(a == b) << msg; +} + +inline void expect_true(bool cond, const char* msg) { + 
LOG_CHECK(cond) << msg; +} + +inline void small_sleep_ms(int ms) { + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); +} + +// Minimal custom awaitables used to validate await_transform branches +struct HandleReturningAwaitable { + std::coroutine_handle<> stored_handle; + int value; + bool ready; + + explicit HandleReturningAwaitable(int v = 42, bool r = false) : value(v), ready(r) { + } + + bool await_ready() const noexcept { + return ready; + } + std::coroutine_handle<> await_suspend(std::coroutine_handle<> h) noexcept { + stored_handle = h; + return h; // symmetric transfer form + } + int await_resume() noexcept { + return value; + } +}; + +struct BoolReturningAwaitable { + std::coroutine_handle<> stored_handle; + int value; + bool ready; + bool should_suspend; + + explicit BoolReturningAwaitable(int v = 43, bool r = false, bool suspend = true) + : value(v), ready(r), should_suspend(suspend) { + } + + bool await_ready() const noexcept { + return ready; + } + bool await_suspend(std::coroutine_handle<> h) noexcept { + stored_handle = h; + if (should_suspend) { + detail::SchedulerExecutor{}.schedule(h); + } + return should_suspend; + } + int await_resume() noexcept { + return value; + } +}; + +struct VoidReturningAwaitable { + std::coroutine_handle<> stored_handle; + int value; + bool ready; + + explicit VoidReturningAwaitable(int v = 44, bool r = false) : value(v), ready(r) { + } + + bool await_ready() const noexcept { + return ready; + } + void await_suspend(std::coroutine_handle<> h) noexcept { + stored_handle = h; + detail::SchedulerExecutor{}.schedule(h); + } + int await_resume() noexcept { + return value; + } +}; + +// Simple utility actors used by tests +class TestLogger final : public td::actor::Actor { + public: + Task log(std::string msg) { + small_sleep_ms(10); + LOG(INFO) << "[Logger] " << msg; + co_return td::Unit(); + } + + void log_promise(std::string msg, td::Promise promise) { + small_sleep_ms(10); + LOG(INFO) << "[Logger Promise] " << msg; + promise.set_value(td::Unit()); + } + + Task multiply2(int x) { + co_return x * 2; + } + + void multiply3(int x, td::Promise promise) { + promise.set_value(x * 3); + } +}; + +class TestDatabase final : public td::actor::Actor { + public: + explicit TestDatabase(td::actor::ActorId logger) : logger_(logger) { + } + + Task calcA(const std::string& key) { + if (key.empty()) + co_return td::Status::Error("empty key"); + small_sleep_ms(5); + co_return static_cast(key.size()) * 10; + } + + Task square(size_t x) { + co_return x* x; + } + + Task get(std::string key) { + int ai = co_await calcA(key); // Tasks propagate errors by default + (void)co_await ask(logger_, &TestLogger::log, std::string("DB get ") + key); + small_sleep_ms(5); + if (key == "user") + co_return std::string("Alice:") + std::to_string(ai); + co_return td::Status::Error("not found"); + } + + private: + td::actor::ActorId logger_; +}; + +// 4) Tests grouped by topic +class CoroSpec final : public td::actor::Actor { + public: + void start_up() override { + logger_ = td::actor::create_actor("TestLogger").release(); + db_ = td::actor::create_actor("TestDatabase", logger_).release(); + [](Task test) -> Task { + (co_await std::move(test).wrap()).ensure(); + co_await yield_on_current(); + td::actor::SchedulerContext::get()->stop(); + co_return td::Unit{}; + }(run_all()) + .start_immediate() + .detach(); + } + + Task unified_queries() { + LOG(INFO) << "=== unified queries ==="; + + using Value = std::unique_ptr; + static auto make_value = []() { return std::make_unique(7); 
}; + class Uni : public td::actor::Actor { + public: + Value get_value() { + return make_value(); + } + td::Result get_result() { + return make_value(); + } + td::Result get_result_err() { + return td::Status::Error("error"); + } + Task get_task() { + co_return get_value(); + } + Task get_task_err() { + co_return td::Status::Error("error"); + } + void get_via_promise(td::Promise promise) { + promise.set_value(get_value()); + } + void get_via_promise_err(td::Promise promise) { + promise.set_error(td::Status::Error("error")); + } + void get_void() { + } + }; + auto uni = td::actor::create_actor("UnifiedResult"); + + auto check = [](td::Result v) { + CHECK(v.ok()); + CHECK(*v.ok() == 7); + }; + auto check_value = [](Value v) { CHECK(*v == 7); }; + auto check_ok = [](td::Result v) { v.ensure(); }; + auto check_err = [](td::Result v) { v.ensure_error(); }; + + auto meta_ask = [&](auto&&... args) -> Task { + LOG(INFO) << "meta_ask: ask(args...)"; + check(co_await ask(args...).wrap()); + LOG(INFO) << "meta_ask: ask_immediate(args...)"; + check(co_await ask_immediate(args...).wrap()); + LOG(INFO) << "meta_ask: co_try(ask_immediate(args...))"; + check_value(co_await ask_immediate(args...)); + LOG(INFO) << "meta_ask: co_try(ask(args...))"; + check_value(co_await ask(args...)); + //check(co_await ask_new(args...)); + co_return td::Unit{}; + }; + + auto meta_ask_err = [&](auto&&... args) -> Task { + LOG(INFO) << "meta_ask_err: ask(args...)"; + check_err(co_await ask(args...).wrap()); + LOG(INFO) << "meta_ask_err: ask_immediate(args...)"; + check_err(co_await ask_immediate(args...).wrap()); + + check_err(co_await [](auto&&... args) -> Task { + LOG(INFO) << "meta_ask_err: co_try(ask_immediate(args...))"; + co_return co_await ask_immediate(args...); + }(args...) + .wrap()); + check_err(co_await [](auto&&... args) -> Task { + LOG(INFO) << "meta_ask_err: co_try(ask_immediate(args...))"; + co_return co_await ask_immediate(args...); + }(args...) + .wrap()); + check_err(co_await [](auto&&... args) -> Task { + LOG(INFO) << "meta_ask_err: co_try(ask_immediate(args...))"; + co_await ask(args...); + co_return std::make_unique(17); + }(args...) 
+ .wrap()); + + LOG(INFO) << "meta_ask_err: co_try(ask(args...))"; + //check(co_await ask_new(args...)); + co_return td::Unit{}; + }; + + // ask from coroutines + check_ok(co_await meta_ask(uni, &Uni::get_result)); + check_ok(co_await meta_ask(uni, &Uni::get_task)); + check_ok(co_await meta_ask(uni, &Uni::get_via_promise)); + check_ok(co_await meta_ask(uni, &Uni::get_value)); + co_await ask(uni, &Uni::get_void); + co_await ask_immediate(uni, &Uni::get_void); + + check_ok(co_await meta_ask_err(uni, &Uni::get_result_err)); + check_ok(co_await meta_ask_err(uni, &Uni::get_task_err)); + check_ok(co_await meta_ask_err(uni, &Uni::get_via_promise_err)); + + check(co_await ask_new(uni, &Uni::get_task)); + + static_assert(td::is_promise_interface::ExternalPromise>()); + + auto check_send_closure = [&](auto&& f) -> Task { + auto [task, task_promise] = StartedTask::make_bridge(); + td::Promise promise = [moved_task_promise = std::move(task_promise)](td::Result r) mutable { + moved_task_promise.set_result(std::move(r)); + }; + f(std::move(promise)); + send_closure(uni, &Uni::get_value, std::move(promise)); + check(co_await std::move(task)); + co_return td::Unit{}; + }; + + co_await check_send_closure([&](auto promise) { send_closure(uni, &Uni::get_result, std::move(promise)); }); + co_await check_send_closure([&](auto promise) { send_closure(uni, &Uni::get_value, std::move(promise)); }); + co_await check_send_closure([&](auto promise) { send_closure(uni, &Uni::get_via_promise, std::move(promise)); }); + co_await check_send_closure([&](auto promise) { send_closure(uni, &Uni::get_task, std::move(promise)); }); + + bool done = false; + LOG(INFO) << "Test send_closure_immediate"; + send_closure_immediate(uni, &Uni::get_void, [&](td::Result r) { + (void)r; + done = true; + }); + CHECK(done); + + co_return td::Unit(); + } + + // A. Awaitable branch coverage (handle/bool/void; ready/suspend) + Task awaitable_branches() { + LOG(INFO) << "=== Awaitable branches ==="; + + struct Case { + const char* name; + int expected; + std::function()> run; + }; + std::vector cases = { + {"handle:not_ready", 100, [&]() -> Task { co_return co_await HandleReturningAwaitable(100, false); }}, + {"handle:ready", 101, [&]() -> Task { co_return co_await HandleReturningAwaitable(101, true); }}, + {"bool:suspend", 200, [&]() -> Task { co_return co_await BoolReturningAwaitable(200, false, true); }}, + {"bool:immediate", 201, [&]() -> Task { co_return co_await BoolReturningAwaitable(201, false, false); }}, + {"bool:ready", 202, [&]() -> Task { co_return co_await BoolReturningAwaitable(202, true, true); }}, + {"void:not_ready", 300, [&]() -> Task { co_return co_await VoidReturningAwaitable(300, false); }}, + {"void:ready", 301, [&]() -> Task { co_return co_await VoidReturningAwaitable(301, true); }}, + }; + for (auto& c : cases) { + auto r = co_await c.run(); + expect_eq(r, c.expected, c.name); + } + + int sum = 0; + sum += co_await HandleReturningAwaitable(10, false); + sum += co_await BoolReturningAwaitable(20, false, true); + sum += co_await BoolReturningAwaitable(30, false, false); + sum += co_await VoidReturningAwaitable(40, false); + expect_eq(sum, 100, "mixed awaitables sum"); + + co_return td::Unit(); + } + + // B. 
Recursion via ask hop and direct recursion + Task rec_fast(int n) { + if (n == 0) + co_return 0; + int r = co_await rec_fast(n - 1); // Tasks propagate errors by default now + co_return r + 1; + } + Task rec_slow(int n) { + if (n == 0) + co_return 0; + int r = + co_await ask(actor_id(this), &CoroSpec::rec_slow, n - 1); // ask returns StartedTask which propagates errors + co_return r + 1; + } + Task recursion() { + LOG(INFO) << "=== Recursion ==="; + for (int depth : {5, 10}) { + int a = co_await rec_slow(depth); // Tasks propagate errors by default + expect_eq(a, depth, "recursion via ask"); + int b = co_await rec_fast(depth); // Tasks propagate errors by default + expect_eq(b, depth, "direct recursion"); + } + co_return td::Unit(); + } + + // C. ask()/ask_immediate and unified Task/Promise targets + Task asks() { + LOG(INFO) << "=== ask / ask_immediate ==="; + + auto delayed = ask(db_, &TestDatabase::square, 4); + expect_true(!delayed.await_ready(), "delayed ask is not ready"); + expect_eq(co_await std::move(delayed), static_cast(16), "delayed ask result"); + + co_await Yield{}; + for (int i = 0; i < 16; i++) { + auto immediate = ask_immediate(db_, &TestDatabase::square, 4); + expect_true(immediate.await_ready(), "immediate ask is ready"); + expect_eq(immediate.await_resume().ok(), static_cast(16), "immediate ask result"); + } + + auto user = co_await ask(db_, &TestDatabase::get, std::string("user")); + LOG(INFO) << "User: " << user; + + (void)co_await ask(logger_, &TestLogger::log, std::string("unified Task target")); + (void)co_await ask(logger_, &TestLogger::log_promise, std::string("unified Promise target")); + co_return td::Unit(); + } + + // C2. Modifiers: Yield, ChangeOwner attach/detach, yield_on + Task modifiers() { + LOG(INFO) << "=== Modifiers (Yield, ChangeOwner, yield_on) ==="; + + auto self = actor_id(this); + co_await attach_to_actor(self); // just in case + auto on_self = [self] { + if (self != td::actor::detail::get_current_actor_id()) { + return td::Status::Error("not on self"); + } + return td::Status::OK(); + }; + auto on_none = [] { + if (!td::actor::detail::get_current_actor_id().empty()) { + return td::Status::Error("not on none"); + } + return td::Status::OK(); + }; + + // Yield sequencing + { + td::Timer timer; + for (int i = 0; i < 1000; i++) { + co_await yield_on_current(); + on_self(); + } + LOG(INFO) << "yield_on_current (x100): " << timer.elapsed(); + timer = {}; + for (int i = 0; i < 1000; i++) { + co_await attach_to_actor(self); // noop + on_self(); + } + LOG(INFO) << "attach_to_actor (x100) : " << timer.elapsed(); + } + + // Attach to current actor and ensure suspended await resumes on same actor + { + co_await attach_to_actor(self); + int v = co_await BoolReturningAwaitable(123, false, true); + on_self(); + expect_eq(v, 123, "suspended await result"); + } + + // Detach (no specific owner), ensure we still resume and continue + { + co_await detach_from_actor(); + on_none(); + int v = co_await BoolReturningAwaitable(321, false, true); + expect_eq(v, 321, "detached suspended await result"); + co_await attach_to_actor(self); + on_self(); + } + + // Explicit yield_on to current actor + { + co_await yield_on_current(); + on_self(); + } + + co_return td::Unit(); + } + + // D. 
Concurrency and double-resumption surface + Task concurrency() { + LOG(INFO) << "=== Concurrency ==="; + auto self = actor_id(this); + co_await detach_from_actor(); + + for (int round = 0; round < 100; round++) { + co_await attach_to_actor(self); + co_await detach_from_actor(); + auto task = [](int value) -> Task { + td::usleep_for(td::Random::fast(0, 1000)); + co_return value * 2; + }(round) + .start(); + td::usleep_for(td::Random::fast(0, 1000)); + auto result = co_await std::move(task); + CHECK(result == round * 2); + } + + for (int round = 0; round < 100; round++) { + co_await attach_to_actor(self); + co_await detach_from_actor(); + auto task = [](int value) -> Task { + td::usleep_for(td::Random::fast(0, 1000)); + co_return value * 2; + }(round) + .start(); + td::usleep_for(td::Random::fast(0, 1000)); + task.detach(); + td::usleep_for(100); + } + + // Many parallel tasks + sum + std::vector> many; + size_t expect = 0; + for (size_t i = 0; i < 200; i++) { + auto t = [](size_t v) -> Task { co_return v; }(i).start(); + many.push_back(std::move(t)); + expect += i; + } + size_t got = 0; + for (auto& t : many) { + auto v = co_await std::move(t); // Tasks propagate errors by default + got += v; + } + expect_eq(got, expect, "many parallel sum"); + co_return td::Unit{}; + } + Task concurrency2() { + LOG(INFO) << "=== Concurrency 2 ==="; + // A few shapes with spawn_coroutine_old + std::vector> shapes; + for (int i = 0; i < 8000; i++) { + int m = i % 4; + if (m == 0) { + //shapes.push_back(spawn_actor("immediate", []() -> Task { co_return 1; })); + } else if (m == 1) { + shapes.push_back(spawn_actor("hop1", []() -> Task { + co_await spawn_actor("sub", []() -> Task { co_return td::Unit(); }()); + co_return 2; + }())); + } else if (m == 2) { + // intentionally left out heavy nested spawns variant + } else { + //shapes.push_back([]() -> Task { co_return 2; }().start_immediate()); + } + } + int s = 0; + for (auto& t : shapes) { + auto v = co_await std::move(t); // Tasks propagate errors by default + s += v; + } + LOG(INFO) << "shapes sum: " << s; + co_return td::Unit(); + } + + // F. 
Task lifecycle sanity (lazy start + await; explicit start) + Task lifecycle() { + LOG(INFO) << "=== Task lifecycle ==="; + auto make_task = []() -> Task { co_return 7; }; + + // Await without explicit start + { + auto v = co_await make_task(); // Tasks propagate errors by default + expect_eq(v, 7, "await without start"); + } + // Explicit start + { + auto t = make_task().start(); + auto v = co_await std::move(t); // Tasks propagate errors by default + expect_eq(v, 7, "await after start"); + } + co_return td::Unit(); + } + Task helpers() { + LOG(INFO) << "=== Task helper ==="; + CHECK(5 == co_await td::actor::detail::make_awaitable(5)); + auto get7 = []() -> Task { co_return 7; }; + CHECK(7 == co_await get7()); + auto square = [](size_t x) -> Task { co_return x* x; }; + auto res = co_await get7().start().then(square); + CHECK(res == 49); + co_return td::Unit(); + co_return td::Unit(); + } + + Task combinators() { + LOG(ERROR) << "Test combinators"; + + // Test all() with variadic arguments + { + auto make_task = [](int val, int delay_ms) -> Task { + small_sleep_ms(delay_ms); + co_return val; + }; + + auto tuple = co_await all(make_task(1, 10), make_task(2, 20), make_task(3, 30)); + auto a = std::move(std::get<0>(tuple)); + auto b = std::move(std::get<1>(tuple)); + auto c = std::move(std::get<2>(tuple)); + expect_eq(1, a, "all() first result"); + expect_eq(2, b, "all() second result"); + expect_eq(3, c, "all() third result"); + LOG(ERROR) << "all() variadic test passed"; + } + + // Test all() with vector + { + std::vector> tasks; + for (int i = 0; i < 5; ++i) { + tasks.push_back([](int val) -> Task { + small_sleep_ms(val * 10); + co_return val * 2; + }(i)); + } + + auto results = co_await all(std::move(tasks)); + expect_eq(5u, results.size(), "all() vector size"); + for (size_t i = 0; i < results.size(); ++i) { + expect_eq(static_cast(i * 2), results[i], "all() vector result"); + } + LOG(ERROR) << "all() vector test passed"; + } + + // Test all() with errors and collect_results + { + auto success_task = []() -> Task { co_return 42; }; + auto error_task = []() -> Task { co_return td::Status::Error("Test error"); }; + + auto tuple = co_await all(success_task().wrap(), error_task().wrap()); + + // Test that individual results can have errors + auto s = std::move(std::get<0>(tuple)); + auto e = std::move(std::get<1>(tuple)); + expect_eq(42, s.ok(), "all() with error - success task"); + expect_true(e.is_error(), "all() with error - error task"); + + // Test collect_results with tuple containing an error + auto tuple2 = co_await all(success_task().wrap(), error_task().wrap()); + auto collected = collect(std::move(tuple2)); + expect_true(collected.is_error(), "collect_results should return error if any task failed"); + LOG(ERROR) << "all() error handling test passed"; + } + + // Test collect_results with all successful tasks + { + auto task1 = []() -> Task { co_return 1; }; + auto task2 = []() -> Task { co_return 2; }; + auto task3 = []() -> Task { co_return 3; }; + + // Test with tuple + auto tuple = co_await all(task1().wrap(), task2().wrap(), task3().wrap()); + auto collected_tuple = collect(std::move(tuple)); + expect_ok(collected_tuple, "collect_results should succeed when all tasks succeed"); + auto [a, b, c] = collected_tuple.move_as_ok(); + expect_eq(1, a, "First value"); + expect_eq(2, b, "Second value"); + expect_eq(3, c, "Third value"); + + // Test with vector + std::vector> tasks; + for (int i = 0; i < 5; ++i) { + tasks.push_back([](int val) -> Task { co_return val; }(i)); + } + auto 
vec = co_await all_wrap(std::move(tasks)); + auto collected_vec = collect(std::move(vec)); + expect_ok(collected_vec, "collect_results should succeed for vector"); + auto& values = collected_vec.ok(); + expect_eq(5u, values.size(), "Vector size"); + for (size_t i = 0; i < values.size(); ++i) { + expect_eq(static_cast(i), values[i], "Vector element"); + } + LOG(ERROR) << "collect_results test passed"; + } + + co_return td::Unit{}; + } + + // G. co_try success and error propagation + Task try_awaitable() { + LOG(INFO) << "=== co_try ==="; + + // Success path: unwrap value + { + auto ok_task = []() -> Task { co_return 123; }; + int v = co_await ok_task(); + expect_eq(v, 123, "co_try unwraps ok value"); + } + + // Error path: propagate out of outer Task, so awaiting yields error + { + auto err_task = []() -> Task { co_return td::Status::Error("boom"); }; + auto r = co_await err_task().wrap(); // control: direct await to observe error pattern + expect_true(r.is_error(), "sanity: err_task returns error"); + + // Now check propagation via co_try inside an outer task + auto outer = [err_task]() -> Task { + int x = co_await err_task(); + co_return x + 1; // should never reach + }(); + auto rr = co_await std::move(outer).wrap(); + expect_true(rr.is_error(), "co_try propagates error to outer Task"); + } + + // Test try_unwrap() method on StartedTask + { + auto ok_task = []() -> Task { co_return 456; }; + auto started = ok_task().start_immediate(); + int v = co_await std::move(started); + expect_eq(v, 456, "try_unwrap() unwraps ok value from StartedTask"); + } + + // Test try_unwrap() error propagation + { + auto err_task = []() -> Task { co_return td::Status::Error("test error"); }; + auto outer = [err_task]() -> Task { + auto started = err_task().start_immediate(); + int x = co_await std::move(started); + co_return x + 1; // should never reach + }(); + auto result = co_await std::move(outer).wrap(); + expect_true(result.is_error(), "try_unwrap() propagates error from StartedTask"); + } + + // Test co_try() with Result values (non-awaitable) + { + auto outer = []() -> Task { + td::Result ok_result = 789; + int x = co_await std::move(ok_result); + co_return x + 1; + }(); + auto result = co_await std::move(outer).wrap(); // Use wrap() to get Result + expect_true(result.is_ok(), "co_try(Result) works with ok value"); + expect_eq(result.move_as_ok(), 790, "co_try(Result) returns correct value"); + } + + // Test co_try() error propagation with Result + { + auto outer = []() -> Task { + td::Result err_result = td::Status::Error("direct error"); + int x = co_await std::move(err_result); + co_return x + 1; // should never reach + }(); + auto result = co_await std::move(outer).wrap(); // Use wrap() to get Result + expect_true(result.is_error(), "co_try(Result) propagates error"); + } + + // Test co_try() with Result lvalue reference + { + auto outer = []() -> Task { + td::Result ok_result = 999; + // Test with lvalue reference (should work with the Storage type handling) + int x = co_await std::move(ok_result); + // Result should be moved out after co_try + co_return x + 2; + }(); + auto result = co_await std::move(outer).wrap(); // Use wrap() to get Result + expect_true(result.is_ok(), "co_try(Result&) works with lvalue reference"); + expect_eq(result.move_as_ok(), 1001, "co_try(Result&) returns correct value"); + } + + // Test default Result co_await (propagates errors) + { + auto outer = []() -> Task { + td::Result ok_result = 333; + int x = co_await std::move(ok_result); // Default: propagates error + 
co_return x * 2; + }(); + auto result = co_await std::move(outer).wrap(); // Use wrap() to get Result + expect_true(result.is_ok(), "Result default co_await works with ok value"); + expect_eq(result.move_as_ok(), 666, "Result default co_await returns correct value"); + } + + // Test default Result co_await error propagation + { + auto outer = []() -> Task { + td::Result err_result = td::Status::Error("unwrap error"); + int x = co_await std::move(err_result); // Default: propagates error + co_return x * 2; // should never reach + }(); + auto result = co_await std::move(outer).wrap(); // Use wrap() to get Result + expect_true(result.is_error(), "Result default co_await propagates error"); + } + + // Test Result::wrap() to prevent error propagation + { + auto outer = []() -> Task> { + td::Result err_result = td::Status::Error("wrapped error"); + auto full_result = co_await std::move(err_result).wrap(); // Explicit: no propagation + expect_true(full_result.is_error(), "wrap() preserves error in Result"); + co_return full_result; + }(); + auto result = co_await std::move(outer); + expect_true(result.is_error(), "wrap() preserved the error"); + } + + // Test Result::wrap() with ok value + { + auto outer = []() -> Task> { + td::Result ok_result = 555; + auto full_result = co_await std::move(ok_result).wrap(); // Explicit: no propagation + expect_true(full_result.is_ok(), "wrap() preserves ok value in Result"); + co_return full_result; + }(); + auto result = co_await std::move(outer).wrap(); // Need to wrap outer task too + expect_true(result.is_ok(), "Task completes successfully"); + auto inner_result = result.move_as_ok(); + expect_true(inner_result.is_ok(), "wrap() preserved the ok value"); + expect_eq(inner_result.move_as_ok(), 555, "wrap() preserved the correct value"); + } + + // Test Task default co_await (propagates errors) + { + auto inner = []() -> Task { co_return 888; }; + auto outer = [&inner]() -> Task { + int x = co_await inner(); // Default: propagates errors, returns T + co_return x + 1; + }(); + auto result = co_await std::move(outer).wrap(); // Wrap to get Result + expect_true(result.is_ok(), "Task default co_await works"); + expect_eq(result.move_as_ok(), 889, "Task default co_await returns correct value"); + } + + // Test Task default co_await error propagation + { + auto inner = []() -> Task { co_return td::Status::Error("task error"); }; + auto outer = [&inner]() -> Task { + int x = co_await inner(); // Default: propagates error + co_return x + 1; // should never reach + }(); + auto result = co_await std::move(outer).wrap(); // Wrap to get Result + expect_true(result.is_error(), "Task default co_await propagates error"); + } + + // Test Task::wrap() to prevent error propagation + { + auto inner = []() -> Task { co_return td::Status::Error("wrapped task error"); }; + auto outer = [&inner]() -> Task> { + auto full_result = co_await inner().wrap(); // Explicit: no propagation + expect_true(full_result.is_error(), "Task::wrap() preserves error"); + co_return full_result; + }(); + auto result = co_await std::move(outer).wrap(); // Wrap outer too + expect_true(result.is_ok(), "Outer task completes successfully"); + auto inner_result = result.move_as_ok(); + expect_true(inner_result.is_error(), "Task::wrap() preserved the error"); + } + + co_return td::Unit{}; + } + + static Task slow_task() { + td::usleep_for(2000000); + co_return td::Unit{}; + } + Task stop_actor() { + class StopActor : public td::actor::Actor { + public: + void start_up() override { + alarm_timestamp() = 
td::Timestamp::in(1); + } + void alarm() override { + LOG(INFO) << "alarm"; + stop(); + } + Task query() { + auto task = slow_task(); + task.set_executor(Executor::on_scheduler()); + // actor will stop before task is finished + co_await std::move(task); + // here we could access actor but we should v + LOG(FATAL) << "access stopped actor"; + co_return 1; + } + }; + auto a = create_actor("stop_actor"); + auto r = co_await ask(a, &StopActor::query).wrap(); + r.ensure_error(); + LOG(INFO) << "Got error from stopped actor " << r.error(); + co_return td::Unit{}; + } + + // Master runner + Task run_all() { + LOG(ERROR) << "Run tests"; + (void)co_await ask(logger_, &TestLogger::log, std::string("Starting coroutine tests")); + + co_await unified_queries(); + co_await concurrency(); + for (int i = 0; i < 10; i++) { + co_await concurrency2(); + } + co_await awaitable_branches(); + co_await recursion(); + co_await asks(); + co_await modifiers(); + co_await lifecycle(); + co_await helpers(); + co_await combinators(); + co_await try_awaitable(); + co_await stop_actor(); + + (void)co_await ask(logger_, &TestLogger::log, std::string("All tests passed")); + co_return td::Unit(); + } + + private: + td::actor::ActorId db_; + td::actor::ActorId logger_; +}; + +// 5) Runner +int main() { + SET_VERBOSITY_LEVEL(VERBOSITY_NAME(INFO)); + td::actor::Scheduler scheduler({4}); + scheduler.run_in_context([&] { td::actor::create_actor("CoroSpec").release(); }); + scheduler.run(); + return 0; +} diff --git a/tdutils/td/utils/Closure.h b/tdutils/td/utils/Closure.h index 3b2fc9de5..3407c7542 100644 --- a/tdutils/td/utils/Closure.h +++ b/tdutils/td/utils/Closure.h @@ -95,7 +95,7 @@ class ImmediateClosure { template ImmediateClosure create_immediate_closure( - ResultT (ActorT::*func)(DestArgsT...), SrcArgsT &&... args) { + ResultT (ActorT::*func)(DestArgsT...), SrcArgsT &&...args) { return ImmediateClosure(func, std::forward(args)...); } @@ -172,7 +172,7 @@ DelayedClosure to_delayed_closure(DelayedClosure &&other) { } template -auto create_delayed_closure(ResultT (ActorT::*func)(DestArgsT...), SrcArgsT &&... args) { +auto create_delayed_closure(ResultT (ActorT::*func)(DestArgsT...), SrcArgsT &&...args) { return DelayedClosure(func, std::forward(args)...); } diff --git a/tdutils/td/utils/Status.h b/tdutils/td/utils/Status.h index f75de466a..27e09af59 100644 --- a/tdutils/td/utils/Status.h +++ b/tdutils/td/utils/Status.h @@ -443,6 +443,12 @@ class Status { } }; +// Forward declarations for Result wrappers +template +struct ResultUnwrap; +template +struct ResultWrap; + template class Result { public: @@ -604,6 +610,16 @@ class Result { return f(move_as_ok()); } + // Returns a wrapper that can be co_awaited to propagate errors in coroutines + ResultUnwrap try_unwrap() && { + return ResultUnwrap(std::move(*this)); + } + + // Returns a wrapper that prevents error propagation when co_awaited + ResultWrap wrap() && { + return ResultWrap(std::move(*this)); + } + private: Status status_; union { @@ -611,6 +627,19 @@ class Result { }; }; + +// Wrapper to prevent error propagation when co_awaiting Result +template +struct ResultWrap { + Result result; +}; + +template +struct ResultUnwrap { + Result result; + +}; + template <> inline Result::Result(Status &&status) : status_(std::move(status)) { // no assert
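
For reference, a minimal usage sketch of the coroutine API introduced by this patch (Task, ask, .wrap()), modeled on the calls exercised in tdactor/test/test-coro.cpp above; the Adder actor and use_adder coroutine are illustrative names, not part of the patch.

#include "td/actor/coro.h"
#include "td/actor/actor.h"

using namespace td::actor;

// Illustrative actor exposing a coroutine method (hypothetical name).
class Adder final : public td::actor::Actor {
 public:
  Task<int> add(int a, int b) {
    co_return a + b;  // a failing path would instead co_return td::Status::Error(...)
  }
};

// By default, co_await on ask()/Task propagates errors into the enclosing Task;
// .wrap() yields a td::Result<T> so the caller can inspect the error itself.
Task<td::Unit> use_adder(ActorId<Adder> adder) {
  int sum = co_await ask(adder, &Adder::add, 2, 3);          // error-propagating await
  auto r = co_await ask(adder, &Adder::add, sum, 1).wrap();  // td::Result<int>
  if (r.is_ok()) {
    LOG(INFO) << "sum=" << r.move_as_ok();
  }
  co_return td::Unit{};
}

As in the tests, such a coroutine can be launched from an actor with .start_immediate().detach() or turned into an awaitable running on its own actor via spawn_actor().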