Skip to content

Commit 1180914

Browse files
authored
python credential provider (#215)
AwsCredentialsProvider.new_delegate(...) allows custom credentials provider to be defined with python code
1 parent 2027d61 commit 1180914

File tree

7 files changed

+191
-36
lines changed

7 files changed

+191
-36
lines changed

awscrt/auth.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -101,44 +101,29 @@ def __deepcopy__(self, memo):
101101

102102

103103
class AwsCredentialsProviderBase(NativeResource):
104-
"""
105-
Base class for providers that source the AwsCredentials needed to sign an authenticated AWS request.
106-
107-
NOTE: Custom subclasses of AwsCredentialsProviderBase are not yet supported.
108-
"""
109-
__slots__ = ()
110-
111-
def __init__(self, binding=None):
112-
super().__init__()
113-
114-
if binding is None:
115-
# TODO: create binding type that lets native code call into python subclass
116-
raise NotImplementedError("Custom subclasses of AwsCredentialsProviderBase are not yet supported")
117-
118-
self._binding = binding
119-
120-
def get_credentials(self):
121-
"""
122-
Asynchronously fetch AwsCredentials.
123-
124-
Returns:
125-
concurrent.futures.Future: A Future which will contain
126-
:class:`AwsCredentials` (or an exception) when the operation completes.
127-
The operation may complete on a different thread.
128-
"""
129-
raise NotImplementedError()
104+
# Pointless base class, kept for backwards compatibility.
105+
# AwsCredentialsProvider is (and always will be) the only subclass.
106+
#
107+
# Originally created with the thought that, when we supported
108+
# custom python providers, they would inherit from this class.
109+
# We ended up supporting custom python providers via
110+
# AwsCredentialsProvider.new_delegate() instead.
111+
pass
130112

131113

132114
class AwsCredentialsProvider(AwsCredentialsProviderBase):
133115
"""
134116
Credentials providers source the AwsCredentials needed to sign an authenticated AWS request.
135117
136-
Base class: AwsCredentialsProviderBase
137-
138118
This class provides `new()` functions for several built-in provider types.
119+
To define a custom provider, use the `new_delegate()` function.
139120
"""
140121
__slots__ = ()
141122

123+
def __init__(self, binding):
124+
super().__init__()
125+
self._binding = binding
126+
142127
@classmethod
143128
def new_default_chain(cls, client_bootstrap):
144129
"""
@@ -289,7 +274,35 @@ def new_chain(cls, providers):
289274
binding = _awscrt.credentials_provider_new_chain(providers)
290275
return cls(binding)
291276

277+
@classmethod
278+
def new_delegate(cls, get_credentials):
279+
"""
280+
Creates a provider that sources credentials from a custom
281+
synchronous callback.
282+
283+
Args:
284+
get_credentials: Callable which takes no arguments and returns
285+
:class:`AwsCredentials`.
286+
287+
Returns:
288+
AwsCredentialsProvider:
289+
"""
290+
# TODO: support async delegates
291+
292+
assert callable(get_credentials)
293+
294+
binding = _awscrt.credentials_provider_new_delegate(get_credentials)
295+
return cls(binding)
296+
292297
def get_credentials(self):
298+
"""
299+
Asynchronously fetch AwsCredentials.
300+
301+
Returns:
302+
concurrent.futures.Future: A Future which will contain
303+
:class:`AwsCredentials` (or an exception) when the operation completes.
304+
The operation may complete on a different thread.
305+
"""
293306
future = Future()
294307

295308
def _on_complete(error_code, binding):
@@ -306,7 +319,7 @@ def _on_complete(error_code, binding):
306319
try:
307320
_awscrt.credentials_provider_get_credentials(self._binding, _on_complete)
308321
except Exception as e:
309-
future.set_result(e)
322+
future.set_exception(e)
310323

311324
return future
312325

@@ -383,7 +396,7 @@ class AwsSigningConfig(NativeResource):
383396
signature_type (AwsSignatureType): Which sort of signature should be
384397
computed from the signable.
385398
386-
credentials_provider (AwsCredentialsProviderBase): Credentials provider
399+
credentials_provider (AwsCredentialsProvider): Credentials provider
387400
to fetch signing credentials with.
388401
389402
region (str): The region to sign against.
@@ -470,7 +483,7 @@ def __init__(self,
470483

471484
assert isinstance(algorithm, AwsSigningAlgorithm)
472485
assert isinstance(signature_type, AwsSignatureType)
473-
assert isinstance(credentials_provider, AwsCredentialsProviderBase)
486+
assert isinstance(credentials_provider, AwsCredentialsProvider)
474487
assert isinstance(region, str)
475488
assert isinstance(service, str)
476489
assert callable(should_sign_header) or should_sign_header is None
@@ -533,7 +546,7 @@ def signature_type(self):
533546

534547
@property
535548
def credentials_provider(self):
536-
"""AwsCredentialsProviderBase: Credentials provider to fetch signing credentials with"""
549+
"""AwsCredentialsProvider: Credentials provider to fetch signing credentials with"""
537550
return _awscrt.signing_config_get_credentials_provider(self._binding)
538551

539552
@property

awscrt/awsiot_mqtt_connection_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def websockets_with_default_aws_signing(region, credentials_provider, websocket_
226226
Arguments:
227227
region (str): AWS region to use when signing.
228228
229-
credentials_provider (awscrt.auth.AwsCredentialsProviderBase): Source of AWS credentials to use when signing.
229+
credentials_provider (awscrt.auth.AwsCredentialsProvider): Source of AWS credentials to use when signing.
230230
231231
websocket_proxy_options (awscrt.http.HttpProxyOptions): If specified, a proxy is used when connecting.
232232

source/auth.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ PyObject *aws_py_credentials_provider_new_profile(PyObject *self, PyObject *args
2121
PyObject *aws_py_credentials_provider_new_process(PyObject *self, PyObject *args);
2222
PyObject *aws_py_credentials_provider_new_environment(PyObject *self, PyObject *args);
2323
PyObject *aws_py_credentials_provider_new_chain(PyObject *self, PyObject *args);
24+
PyObject *aws_py_credentials_provider_new_delegate(PyObject *self, PyObject *args);
2425

2526
PyObject *aws_py_signing_config_new(PyObject *self, PyObject *args);
2627
PyObject *aws_py_signing_config_get_algorithm(PyObject *self, PyObject *args);

source/auth_credentials.c

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,16 @@ PyObject *aws_py_credentials_expiration_timestamp_seconds(PyObject *self, PyObje
135135
*/
136136
struct credentials_provider_binding {
137137
struct aws_credentials_provider *native;
138+
139+
/* Python get_credentials() callable.
140+
* Only used by "delegate" provider type */
141+
PyObject *py_delegate;
138142
};
139143

140144
/* Finally clean up binding (after capsule destructor runs and credentials provider shutdown completes) */
141145
static void s_credentials_provider_binding_clean_up(struct credentials_provider_binding *binding) {
146+
Py_XDECREF(binding->py_delegate);
147+
142148
aws_mem_release(aws_py_get_allocator(), binding);
143149
}
144150

@@ -166,7 +172,7 @@ struct aws_credentials_provider *aws_py_get_credentials_provider(PyObject *crede
166172
AWS_PY_RETURN_NATIVE_FROM_BINDING(
167173
credentials_provider,
168174
s_capsule_name_credentials_provider,
169-
"AwsCredentialsProviderBase",
175+
"AwsCredentialsProvider",
170176
credentials_provider_binding);
171177
}
172178

@@ -559,3 +565,103 @@ PyObject *aws_py_credentials_provider_new_chain(PyObject *self, PyObject *args)
559565
Py_XDECREF(capsule);
560566
return NULL;
561567
}
568+
569+
static int s_credentials_provider_delegate_get_credentials(
570+
void *delegate_user_data,
571+
aws_on_get_credentials_callback_fn callback,
572+
void *callback_user_data) {
573+
574+
struct credentials_provider_binding *binding = delegate_user_data;
575+
576+
PyGILState_STATE state;
577+
if (aws_py_gilstate_ensure(&state)) {
578+
/* Python has shut down. Nothing matters anymore, but don't crash */
579+
return aws_raise_error(AWS_ERROR_INVALID_STATE);
580+
}
581+
582+
struct aws_credentials *native_credentials = NULL;
583+
584+
PyObject *py_result = PyObject_CallFunction(binding->py_delegate, "()");
585+
if (!py_result) {
586+
AWS_LOGF_ERROR(
587+
AWS_LS_AUTH_CREDENTIALS_PROVIDER,
588+
"(id=%p) Exception in get_credentials() delegate callback",
589+
(void *)binding->native);
590+
591+
PyErr_WriteUnraisable(binding->py_delegate);
592+
goto done;
593+
}
594+
595+
/* Expect py_result to be AwsCredentials (which wraps native aws_credentials). */
596+
native_credentials = aws_py_get_credentials(py_result);
597+
if (!native_credentials) {
598+
AWS_LOGF_ERROR(
599+
AWS_LS_AUTH_CREDENTIALS_PROVIDER,
600+
"(id=%p) get_credentials() delegate callback failed to return AwsCredentials",
601+
(void *)binding->native);
602+
603+
PyErr_WriteUnraisable(binding->py_delegate);
604+
goto done;
605+
}
606+
607+
/* Keep native aws_credentials alive until we pass them to callback. */
608+
aws_credentials_acquire(native_credentials);
609+
610+
done:
611+
/* Decref the python AwsCredentials (or whatever else was returned) before releasing the GIL */
612+
Py_XDECREF(py_result);
613+
614+
PyGILState_Release(state);
615+
616+
if (!native_credentials) {
617+
return aws_raise_error(AWS_ERROR_CRT_CALLBACK_EXCEPTION);
618+
}
619+
620+
callback(native_credentials, AWS_ERROR_SUCCESS, callback_user_data);
621+
aws_credentials_release(native_credentials);
622+
return AWS_OP_SUCCESS;
623+
}
624+
625+
PyObject *aws_py_credentials_provider_new_delegate(PyObject *self, PyObject *args) {
626+
(void)self;
627+
struct aws_allocator *allocator = aws_py_get_allocator();
628+
629+
PyObject *py_delegate;
630+
631+
if (!PyArg_ParseTuple(args, "O", &py_delegate)) {
632+
return NULL;
633+
}
634+
635+
struct credentials_provider_binding *binding;
636+
PyObject *capsule = s_new_credentials_provider_binding_and_capsule(&binding);
637+
if (!capsule) {
638+
return NULL;
639+
}
640+
641+
binding->py_delegate = py_delegate;
642+
Py_INCREF(py_delegate);
643+
644+
/* From hereon, we need to clean up if errors occur.
645+
* Fortunately, the capsule destructor will clean up anything stored inside the binding */
646+
647+
struct aws_credentials_provider_delegate_options options = {
648+
.get_credentials = s_credentials_provider_delegate_get_credentials,
649+
.delegate_user_data = binding,
650+
.shutdown_options =
651+
{
652+
.shutdown_callback = s_credentials_provider_shutdown_complete,
653+
.shutdown_user_data = binding,
654+
},
655+
};
656+
657+
binding->native = aws_credentials_provider_new_delegate(allocator, &options);
658+
if (!binding->native) {
659+
PyErr_SetAwsLastError();
660+
goto error;
661+
}
662+
663+
return capsule;
664+
error:
665+
Py_DECREF(capsule);
666+
return NULL;
667+
}

source/module.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ int aws_py_gilstate_ensure(PyGILState_STATE *out_state) {
382382

383383
void *aws_py_get_binding(PyObject *obj, const char *capsule_name, const char *class_name) {
384384
if (!obj || obj == Py_None) {
385-
return PyErr_Format(PyExc_TypeError, "Excepted '%s', received 'NoneType'", class_name);
385+
return PyErr_Format(PyExc_TypeError, "Expected '%s', received 'NoneType'", class_name);
386386
}
387387

388388
PyObject *py_binding = PyObject_GetAttrString(obj, "_binding"); /* new reference */
@@ -554,6 +554,7 @@ static PyMethodDef s_module_methods[] = {
554554
AWS_PY_METHOD_DEF(credentials_provider_new_process, METH_VARARGS),
555555
AWS_PY_METHOD_DEF(credentials_provider_new_environment, METH_VARARGS),
556556
AWS_PY_METHOD_DEF(credentials_provider_new_chain, METH_VARARGS),
557+
AWS_PY_METHOD_DEF(credentials_provider_new_delegate, METH_VARARGS),
557558
AWS_PY_METHOD_DEF(signing_config_new, METH_VARARGS),
558559
AWS_PY_METHOD_DEF(signing_config_get_algorithm, METH_VARARGS),
559560
AWS_PY_METHOD_DEF(signing_config_get_signature_type, METH_VARARGS),

test/test_auth.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,40 @@ def test_process_provider(self):
181181
self.assertTrue('process_secret_access_key' == credentials.secret_access_key)
182182
self.assertTrue(credentials.session_token is None)
183183

184+
def test_delegate_provider(self):
185+
def delegate_get_credentials():
186+
return awscrt.auth.AwsCredentials("accesskey", "secretAccessKey", "sessionToken")
187+
188+
provider = awscrt.auth.AwsCredentialsProvider.new_delegate(delegate_get_credentials)
189+
credentials = provider.get_credentials().result(TIMEOUT)
190+
191+
# Don't use assertEqual(), which could log actual credentials if test fails.
192+
self.assertTrue('accesskey' == credentials.access_key_id)
193+
self.assertTrue('secretAccessKey' == credentials.secret_access_key)
194+
self.assertTrue('sessionToken' == credentials.session_token)
195+
196+
def test_delegate_provider_exception(self):
197+
# delegate that raises exception should result in exception
198+
def delegate_get_credentials():
199+
raise Exception("purposefully thrown exception")
200+
201+
provider = awscrt.auth.AwsCredentialsProvider.new_delegate(delegate_get_credentials)
202+
203+
with self.assertRaises(Exception):
204+
credentials_future = provider.get_credentials()
205+
credentials = credentials_future.result(TIMEOUT)
206+
207+
def test_delegate_provider_exception_from_bad_return_type(self):
208+
# delegate that returns wrong type should result in exception
209+
def delegate_get_credentials():
210+
return "purposefully return wrong type"
211+
212+
provider = awscrt.auth.AwsCredentialsProvider.new_delegate(delegate_get_credentials)
213+
214+
with self.assertRaises(Exception):
215+
credentials_future = provider.get_credentials()
216+
credentials = credentials_future.result(TIMEOUT)
217+
184218

185219
class TestSigningConfig(NativeResourceTest):
186220
def test_create(self):

0 commit comments

Comments
 (0)