diff --git a/docs/index.rst b/docs/index.rst index 0e9c481d0..8b59fdf99 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1057,7 +1057,7 @@ addition to the properties dictated by the underlying librdkafka C library: where ``expiry_time`` is the time in seconds since the epoch as a floating point number. This callback is useful only when ``sasl.mechanisms=OAUTHBEARER`` is set and is served to get the initial token before a successful broker connection can be made. - The callback can be triggered by calling ``client.poll()`` or ``producer.flush()``. + The callback is asynchronously triggered by the background thread to maintain token validity. * ``on_delivery(kafka.KafkaError, kafka.Message)`` (**Producer**): value is a Python function reference that is called once for each produced message to indicate the final diff --git a/src/confluent_kafka/src/Admin.c b/src/confluent_kafka/src/Admin.c index 06953e867..c3af72413 100644 --- a/src/confluent_kafka/src/Admin.c +++ b/src/confluent_kafka/src/Admin.c @@ -5526,10 +5526,22 @@ static int Admin_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) { return -1; } + /* Enable SASL callbacks on background thread for AdminClient since + * applications typically don't call poll() regularly on AdminClient. 
*/ + if (self->oauth_cb) { + rd_kafka_sasl_background_callbacks_enable(self->rk); + } + /* Forward log messages to poll queue */ if (self->logger) rd_kafka_set_log_queue(self->rk, NULL); + + /* Wait for the background thread to set the token */ + if (self->oauth_cb) { + return wait_for_oauth_token_set(self); + } + return 0; } diff --git a/src/confluent_kafka/src/Consumer.c b/src/confluent_kafka/src/Consumer.c index b0ee9e9cf..2a6ec2433 100644 --- a/src/confluent_kafka/src/Consumer.c +++ b/src/confluent_kafka/src/Consumer.c @@ -1664,6 +1664,12 @@ static int Consumer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) { return -1; } + /* Enable Token Refresh to be handled by background thread if OAuth + * callback is provided */ + if (self->oauth_cb) { + rd_kafka_sasl_background_callbacks_enable(self->rk); + } + /* Forward log messages to main queue which is then forwarded * to the consumer queue */ if (self->logger) @@ -1674,6 +1680,12 @@ static int Consumer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) { self->u.Consumer.rkqu = rd_kafka_queue_get_consumer(self->rk); assert(self->u.Consumer.rkqu); + + /* Wait for the background thread to set the token */ + if (self->oauth_cb) { + return wait_for_oauth_token_set(self); + } + return 0; } diff --git a/src/confluent_kafka/src/Producer.c b/src/confluent_kafka/src/Producer.c index cb36a51cf..b09bad47a 100644 --- a/src/confluent_kafka/src/Producer.c +++ b/src/confluent_kafka/src/Producer.c @@ -1331,10 +1331,21 @@ static int Producer_init(PyObject *selfobj, PyObject *args, PyObject *kwargs) { return -1; } + /* Enable Token Refresh to be handled by background thread if OAuth + * callback is provided */ + if (self->oauth_cb) { + rd_kafka_sasl_background_callbacks_enable(self->rk); + } + /* Forward log messages to poll queue */ if (self->logger) rd_kafka_set_log_queue(self->rk, NULL); + /* Wait for the background thread to set the token */ + if (self->oauth_cb) { + return wait_for_oauth_token_set(self); + } + 
return 0; } diff --git a/src/confluent_kafka/src/confluent_kafka.c b/src/confluent_kafka/src/confluent_kafka.c index 6a1924cf7..1505ddd0b 100644 --- a/src/confluent_kafka/src/confluent_kafka.c +++ b/src/confluent_kafka/src/confluent_kafka.c @@ -2037,11 +2037,59 @@ static int py_extensions_to_c(char **extensions, return 1; } + +/** + * @brief Waits for OAuth callback to set the token + * + * Useful during client init as we want to ensure we have the token before we + * return + * + * Returns 0 if token was set within the timeout period, -1 otherwise. + */ +int wait_for_oauth_token_set(Handle *h) { + + if (!h->oauth_cb) + return 0; + + int max_wait_sec = 10; + int retry_interval_sec = 1; /* Check every 1 sec */ + int elapsed_sec = 0; + while (!h->oauth_token_set && elapsed_sec < max_wait_sec) { + CallState cs; + CallState_begin(h, &cs); + sleep(retry_interval_sec); + CallState_end(h, &cs); + elapsed_sec += retry_interval_sec; + } + + if (!h->oauth_token_set) { + /* Token not set within timeout */ + cfl_PyErr_Format( + RD_KAFKA_RESP_ERR_SASL_AUTHENTICATION_FAILED, + "OAuth token not set within %d seconds timeout", + max_wait_sec); + CallState cs; + CallState_begin(h, &cs); + rd_kafka_destroy(h->rk); + h->rk = NULL; + CallState_end(h, &cs); + return -1; + } + return 0; +} + +/** + * @brief Callback invoked when an OAuth token needs to be refreshed. + * + * Note that this callback will be invoked by the background thread as + * all client types have been configured to use background threads for sasl + * events. 
+ */ static void oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) { Handle *h = opaque; PyObject *eo, *result; - CallState *cs; + PyGILState_STATE gstate; const char *token; double expiry; const char *principal = ""; @@ -2051,7 +2099,7 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) { char err_msg[2048]; rd_kafka_resp_err_t err_code; - cs = CallState_get(h); + gstate = PyGILState_Ensure(); eo = Py_BuildValue("s", oauthbearer_config); result = PyObject_CallFunctionObjArgs(h->oauth_cb, eo, NULL); @@ -2103,6 +2151,7 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) { PyErr_Format(PyExc_ValueError, "%s", err_msg); goto fail; } + h->oauth_token_set = 1; goto done; fail: @@ -2116,10 +2165,10 @@ oauth_cb(rd_kafka_t *rk, const char *oauthbearer_config, void *opaque) { PyErr_Clear(); goto done; err: - CallState_crash(cs); + PyGILState_Release(gstate); rd_kafka_yield(h->rk); done: - CallState_resume(cs); + PyGILState_Release(gstate); } /**************************************************************************** @@ -2649,8 +2698,10 @@ rd_kafka_conf_t *common_conf_setup(rd_kafka_type_t ktype, rd_kafka_conf_set_log_cb(conf, log_cb); } - if (h->oauth_cb) + if (h->oauth_cb) { rd_kafka_conf_set_oauthbearer_token_refresh_cb(conf, oauth_cb); + rd_kafka_conf_enable_sasl_queue(conf, 1); + } rd_kafka_conf_set_opaque(conf, h); diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index 3b2fc6b7f..32414c1d8 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -239,6 +239,7 @@ typedef struct { PyObject *logger; PyObject *oauth_cb; + int oauth_token_set; union { /** @@ -465,6 +466,7 @@ PyObject *c_topic_partition_result_to_py_dict( PyObject *list_topics(Handle *self, PyObject *args, PyObject *kwargs); PyObject *list_groups(Handle *self, PyObject *args, PyObject *kwargs); PyObject *set_sasl_credentials(Handle *self, 
PyObject *args, PyObject *kwargs); +int wait_for_oauth_token_set(Handle *self); extern const char list_topics_doc[]; diff --git a/tests/test_misc.py b/tests/test_misc.py index 1192d61da..9deece0e8 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -8,7 +8,7 @@ import pytest import confluent_kafka -from confluent_kafka import Consumer, Producer +from confluent_kafka import Consumer, KafkaException, Producer from confluent_kafka.admin import AdminClient from tests.common import TestConsumer @@ -140,9 +140,7 @@ def oauth_cb(oauth_config): } kc = TestConsumer(conf) - - while not seen_oauth_cb: - kc.poll(timeout=0.1) + assert seen_oauth_cb # callback is expected to happen during client init kc.close() @@ -160,20 +158,77 @@ def oauth_cb(oauth_config): 'group.id': 'test', 'security.protocol': 'sasl_plaintext', 'sasl.mechanisms': 'OAUTHBEARER', - 'session.timeout.ms': 100, # Avoid close() blocking too long + 'session.timeout.ms': 100, 'sasl.oauthbearer.config': 'oauth_cb', 'oauth_cb': oauth_cb, } kc = TestConsumer(conf) - - while not seen_oauth_cb: - kc.poll(timeout=0.1) + assert seen_oauth_cb # callback is expected to happen during client init kc.close() def test_oauth_cb_failure(): - """Tests oauth_cb.""" + """ + Tests oauth_cb for a case when it fails to return a token. 
+ We expect the client init to fail + """ + + def oauth_cb(oauth_config): + raise Exception + + conf = { + 'group.id': 'test', + 'security.protocol': 'sasl_plaintext', + 'sasl.mechanisms': 'OAUTHBEARER', + 'session.timeout.ms': 1000, + 'sasl.oauthbearer.config': 'oauth_cb', + 'oauth_cb': oauth_cb, + } + + with pytest.raises(KafkaException): + TestConsumer(conf) + + +def test_oauth_cb_token_refresh_success(): + """ + Tests whether oauth callback gets called multiple times by the background thread + """ + oauth_cb_count = 0 + + def oauth_cb(oauth_config): + nonlocal oauth_cb_count + oauth_cb_count += 1 + assert oauth_config == 'oauth_cb' + return 'token', time.time() + 3 # token is returned with an expiry of 3 seconds + + conf = { + 'group.id': 'test', + 'security.protocol': 'sasl_plaintext', + 'sasl.mechanisms': 'OAUTHBEARER', + 'session.timeout.ms': 1000, + 'sasl.oauthbearer.config': 'oauth_cb', + 'oauth_cb': oauth_cb, + } + + kc = TestConsumer(conf) # callback is expected to happen during client init + assert oauth_cb_count == 1 + + # Check every 1 second for up to 5 seconds for callback count to increase + max_wait_sec = 5 + elapsed_sec = 0 + while oauth_cb_count == 1 and elapsed_sec < max_wait_sec: + time.sleep(1) + elapsed_sec += 1 + + kc.close() + assert oauth_cb_count > 1 + + +def test_oauth_cb_token_refresh_failure(): + """ + Tests whether oauth callback gets called again if token refresh failed in one of the calls after init + """ oauth_cb_count = 0 def oauth_cb(oauth_config): @@ -181,8 +236,8 @@ def oauth_cb(oauth_config): oauth_cb_count += 1 assert oauth_config == 'oauth_cb' if oauth_cb_count == 2: - return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"} - raise Exception + raise Exception + return 'token', time.time() + 3 # token is returned with an expiry of 3 seconds conf = { 'group.id': 'test', @@ -193,11 +248,19 @@ def oauth_cb(oauth_config): 'oauth_cb': oauth_cb, } - kc = TestConsumer(conf) + kc = TestConsumer(conf) # 
callback is expected to happen during client init + assert oauth_cb_count == 1 + + # Check every 1 second for up to 15 seconds for callback count to increase + # Callback failure causes a refresh attempt after 10 secs, so ideally 2 more callbacks should happen within 15 secs + max_wait_sec = 15 + elapsed_sec = 0 + while oauth_cb_count <= 2 and elapsed_sec < max_wait_sec: + time.sleep(1) + elapsed_sec += 1 - while oauth_cb_count < 2: - kc.poll(timeout=0.1) kc.close() + assert oauth_cb_count > 2 def skip_interceptors():