diff --git a/django_cloud_tasks/views.py b/django_cloud_tasks/views.py index 2992ba2..376fe63 100644 --- a/django_cloud_tasks/views.py +++ b/django_cloud_tasks/views.py @@ -3,8 +3,13 @@ from django.apps import apps from django.http import JsonResponse +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_exempt from django.views.generic import View +from gcp_pilot.base import DEFAULT_SERVICE_ACCOUNT from gcp_pilot.pubsub import Message +from google.auth.transport import requests +from google.oauth2 import id_token from django_cloud_tasks import exceptions from django_cloud_tasks.exceptions import TaskNotFound @@ -16,8 +21,26 @@ logger = logging.getLogger("django_cloud_tasks") +def verify_oidc_token(request: HttpRequest): + auth_header: str = request.headers.get("Authorization") + + if not auth_header: + raise PermissionDenied("No auth header") + + auth_type, creds = auth_header.split(" ", 1) + if auth_type.capitalize() != "Bearer": + raise PermissionDenied("Wrong auth_type " + auth_type) + + claims = id_token.verify_token(creds, requests.Request()) + if claims['email'] != DEFAULT_SERVICE_ACCOUNT: + raise PermissionDenied("Unauthorised user " + claims['user']) + + class GoogleCloudTaskView(View): + @method_decorator(csrf_exempt) def post(self, request, task_name, *args, **kwargs): + verify_oidc_token(request) + try: task_class = self.get_task(name=task_name) except TaskNotFound: