Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ addition to the properties dictated by the underlying librdkafka C library:
where ``expiry_time`` is the time in seconds since the epoch as a floating point number.
This callback is useful only when ``sasl.mechanisms=OAUTHBEARER`` is set and
is served to get the initial token before a successful broker connection can be made.
The callback can be triggered by calling ``client.poll()`` or ``producer.flush()``.
The callback is asynchronously triggered by the background thread to maintain token validity.``.

* ``on_delivery(kafka.KafkaError, kafka.Message)`` (**Producer**): value is a Python function reference
that is called once for each produced message to indicate the final
Expand Down
12 changes: 12 additions & 0 deletions src/confluent_kafka/src/Admin.c
Original file line number Diff line number Diff line change
Expand Up @@ -5346,10 +5346,22 @@ static int Admin_init (PyObject *selfobj, PyObject *args, PyObject *kwargs) {
return -1;
}

/* Enable SASL callbacks on background thread for AdminClient since
* applications typically don't call poll() regularly on AdminClient. */
if (self->oauth_cb) {
rd_kafka_sasl_background_callbacks_enable(self->rk);
}

/* Forward log messages to poll queue */
if (self->logger)
rd_kafka_set_log_queue(self->rk, NULL);


/* Wait for the background thread to set the token */
if (self->oauth_cb) {
return wait_for_oauth_token_set(self);
}

return 0;
}

Expand Down
11 changes: 11 additions & 0 deletions src/confluent_kafka/src/Consumer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,11 @@ static int Consumer_init (PyObject *selfobj, PyObject *args, PyObject *kwargs) {
return -1;
}

/* Enable Token Refresh to be handled by background thread if OAuth callback is provided */
if (self->oauth_cb) {
rd_kafka_sasl_background_callbacks_enable(self->rk);
}

/* Forward log messages to main queue which is then forwarded
* to the consumer queue */
if (self->logger)
Expand All @@ -1705,6 +1710,12 @@ static int Consumer_init (PyObject *selfobj, PyObject *args, PyObject *kwargs) {
self->u.Consumer.rkqu = rd_kafka_queue_get_consumer(self->rk);
assert(self->u.Consumer.rkqu);


/* Wait for the background thread to set the token */
if (self->oauth_cb) {
return wait_for_oauth_token_set(self);
}

return 0;
}

Expand Down
10 changes: 10 additions & 0 deletions src/confluent_kafka/src/Producer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1316,10 +1316,20 @@ static int Producer_init (PyObject *selfobj, PyObject *args, PyObject *kwargs) {
return -1;
}

/* Enable Token Refresh to be handled by background thread if OAuth callback is provided */
if (self->oauth_cb) {
rd_kafka_sasl_background_callbacks_enable(self->rk);
}

/* Forward log messages to poll queue */
if (self->logger)
rd_kafka_set_log_queue(self->rk, NULL);

/* Wait for the background thread to set the token */
if (self->oauth_cb) {
return wait_for_oauth_token_set(self);
}

return 0;
}

Expand Down
58 changes: 53 additions & 5 deletions src/confluent_kafka/src/confluent_kafka.c
Original file line number Diff line number Diff line change
Expand Up @@ -2053,11 +2053,56 @@ static int py_extensions_to_c (char **extensions, Py_ssize_t idx,
return 1;
}


/**
* @brief Waits for OAuth callback to set the token
*
* Useful during client init as we want to ensure we have the token before we return back
*
* Returns 0 if token was set within the timeout period, -1 otherwise.
*/
int wait_for_oauth_token_set(Handle *h) {

if (!h->oauth_cb)
return 0;

int max_wait_sec = 5;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where did we pick this timeout? I'm not sure what the longest oauth wait time will be. Maybe this should be a bit longer in case the loopback is slow?

Copy link
Member Author

@ojasvajain ojasvajain Nov 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No particular reason to keep it 5 sec. It should be fine to increase it (say, to 10s) as this init is a one time thing and it will help reduce flakiness for clients with slower callback functions. For clients with faster callbacks, this shouldn't be a problem as we are breaking out of the loop if callback succeeds early.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Increased to 10s

int retry_interval_sec = 1; /* Check every 1 sec */
int elapsed_sec = 0;
while (!h->oauth_token_set && elapsed_sec < max_wait_sec) {
CallState cs;
CallState_begin(h, &cs);
sleep(retry_interval_sec);
CallState_end(h, &cs);
elapsed_sec += retry_interval_sec;
}

if (!h->oauth_token_set) {
/* Token not set within timeout */
cfl_PyErr_Format(RD_KAFKA_RESP_ERR_SASL_AUTHENTICATION_FAILED,
"OAuth token not set within %d seconds timeout",
max_wait_sec);
CallState cs;
CallState_begin(h, &cs);
rd_kafka_destroy(h->rk);
h->rk = NULL;
CallState_end(h, &cs);
return -1;
}
return 0;
}

/**
* @brief Callback invoked when a OAuth token needs to be refreshed.
*
* Note that this callback will be invoked by the background thread as
* all client types have been configured to use background threads for sasl events.
*/
static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
void *opaque) {
Handle *h = opaque;
PyObject *eo, *result;
CallState *cs;
PyGILState_STATE gstate;
const char *token;
double expiry;
const char *principal = "";
Expand All @@ -2067,7 +2112,7 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
char err_msg[2048];
rd_kafka_resp_err_t err_code;

cs = CallState_get(h);
gstate = PyGILState_Ensure();

eo = Py_BuildValue("s", oauthbearer_config);
result = PyObject_CallFunctionObjArgs(h->oauth_cb, eo, NULL);
Expand Down Expand Up @@ -2116,6 +2161,7 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
PyErr_Format(PyExc_ValueError, "%s", err_msg);
goto fail;
}
h->oauth_token_set = 1;
goto done;

fail:
Expand All @@ -2127,10 +2173,10 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
PyErr_Clear();
goto done;
err:
CallState_crash(cs);
PyGILState_Release(gstate);
rd_kafka_yield(h->rk);
done:
CallState_resume(cs);
PyGILState_Release(gstate);
}

/****************************************************************************
Expand Down Expand Up @@ -2650,8 +2696,10 @@ rd_kafka_conf_t *common_conf_setup (rd_kafka_type_t ktype,
rd_kafka_conf_set_log_cb(conf, log_cb);
}

if (h->oauth_cb)
if (h->oauth_cb) {
rd_kafka_conf_set_oauthbearer_token_refresh_cb(conf, oauth_cb);
rd_kafka_conf_enable_sasl_queue(conf, 1);
}

rd_kafka_conf_set_opaque(conf, h);

Expand Down
2 changes: 2 additions & 0 deletions src/confluent_kafka/src/confluent_kafka.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ typedef struct {

PyObject *logger;
PyObject *oauth_cb;
int oauth_token_set;

union {
/**
Expand Down Expand Up @@ -444,6 +445,7 @@ PyObject *c_topic_partition_result_to_py_dict(
PyObject *list_topics (Handle *self, PyObject *args, PyObject *kwargs);
PyObject *list_groups (Handle *self, PyObject *args, PyObject *kwargs);
PyObject *set_sasl_credentials(Handle *self, PyObject *args, PyObject *kwargs);
int wait_for_oauth_token_set(Handle *self);


extern const char list_topics_doc[];
Expand Down
89 changes: 75 additions & 14 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys

import confluent_kafka
from confluent_kafka import Consumer, Producer
from confluent_kafka import Consumer, Producer, KafkaException
from confluent_kafka.admin import AdminClient

from tests.common import TestConsumer
Expand Down Expand Up @@ -135,9 +135,7 @@ def oauth_cb(oauth_config):
}

kc = TestConsumer(conf)

while not seen_oauth_cb:
kc.poll(timeout=0.1)
assert seen_oauth_cb # callback is expected to happen during client init
kc.close()


Expand All @@ -154,29 +152,84 @@ def oauth_cb(oauth_config):
conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 100, # Avoid close() blocking too long
'session.timeout.ms': 100,
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
}

kc = TestConsumer(conf)

while not seen_oauth_cb:
kc.poll(timeout=0.1)
assert seen_oauth_cb # callback is expected to happen during client init
kc.close()


def test_oauth_cb_failure():
""" Tests oauth_cb. """
"""
Tests oauth_cb for a case when it fails to return a token.
We expect the client init to fail
"""

def oauth_cb(oauth_config):
raise Exception

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 1000,
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
}

with pytest.raises(KafkaException):
TestConsumer(conf)


def test_oauth_cb_token_refresh_success():
"""
Tests whether oauth callback gets called multiple times by the background thread
"""
oauth_cb_count = 0

def oauth_cb(oauth_config):
nonlocal oauth_cb_count
oauth_cb_count += 1
assert oauth_config == 'oauth_cb'
return 'token', time.time() + 3 # token is returned with an expiry of 3 seconds

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 1000,
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
}

kc = TestConsumer(conf) # callback is expected to happen during client init
assert oauth_cb_count == 1

# Check every 1 second for up to 5 seconds for callback count to increase
max_wait_sec = 5
elapsed_sec = 0
while oauth_cb_count == 1 and elapsed_sec < max_wait_sec:
time.sleep(1)
elapsed_sec += 1

kc.close()
assert oauth_cb_count > 1


def test_oauth_cb_token_refresh_failure():
"""
Tests whether oauth callback gets called again if token refresh failed in one of the calls after init
"""
oauth_cb_count = 0

def oauth_cb(oauth_config):
nonlocal oauth_cb_count
oauth_cb_count += 1
assert oauth_config == 'oauth_cb'
if oauth_cb_count == 2:
return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"}
raise Exception
raise Exception
return 'token', time.time() + 3 # token is returned with an expiry of 3 seconds

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
Expand All @@ -186,11 +239,19 @@ def oauth_cb(oauth_config):
'oauth_cb': oauth_cb
}

kc = TestConsumer(conf)
kc = TestConsumer(conf) # callback is expected to happen during client init
assert oauth_cb_count == 1

# Check every 1 second for up to 15 seconds for callback count to increase
# Call back failure causes a refresh attempt after 10 secs, so ideally 2 callbacks should happen within 15 secs
max_wait_sec = 15
elapsed_sec = 0
while oauth_cb_count <= 2 and elapsed_sec < max_wait_sec:
time.sleep(1)
elapsed_sec += 1

while oauth_cb_count < 2:
kc.poll(timeout=0.1)
kc.close()
assert oauth_cb_count > 2


def skip_interceptors():
Expand Down