diff --git a/Makefile b/Makefile index 18ded35..7864c76 100644 --- a/Makefile +++ b/Makefile @@ -39,3 +39,4 @@ doc-deploy:clean ## Run Deploy Documentation doc-serve: ## Launch doc local server mkdocs serve + diff --git a/README.md b/README.md index 9bb62af..e81cfdc 100644 --- a/README.md +++ b/README.md @@ -43,8 +43,8 @@ Django Ninja Extra is a powerful extension for [Django Ninja](https://django-nin ## Requirements -- Python >= 3.6 -- Django >= 2.1 +- Python >= 3.8 +- Django >= 4.0 - Pydantic >= 1.6 - Django-Ninja >= 0.16.1 diff --git a/ninja_extra/controllers/model/service.py b/ninja_extra/controllers/model/service.py index 56aeb57..b4dcf4d 100644 --- a/ninja_extra/controllers/model/service.py +++ b/ninja_extra/controllers/model/service.py @@ -1,15 +1,41 @@ import traceback import typing as t +from functools import wraps +import django from asgiref.sync import sync_to_async from django.db.models import Model, QuerySet from pydantic import BaseModel as PydanticModel from ninja_extra.exceptions import NotFound -from ninja_extra.shortcuts import get_object_or_exception +from ninja_extra.shortcuts import aget_object_or_exception, get_object_or_exception from .interfaces import AsyncModelServiceBase, ModelServiceBase +django_version_greater_than_4_2 = django.VERSION > (4, 2) + + +def _async_django_support(sync_method_name: str) -> t.Callable[..., t.Any]: + """ + Ensures that django version supports async orm methods. + If not, it will use sync_to_async to call the sync method with thread_sensitive=True. + """ + + def decorator(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Coroutine]: + @wraps(func) + async def wrapper(self: "ModelService", *args: t.Any, **kwargs: t.Any) -> t.Any: + if not django_version_greater_than_4_2: + alternate_method = getattr(self, sync_method_name) + return await sync_to_async(alternate_method, thread_sensitive=True)( + *args, **kwargs + ) + + return await func(self, *args, **kwargs) + + return wrapper + + return decorator + class ModelService(ModelServiceBase, AsyncModelServiceBase): """ @@ -21,21 +47,17 @@ class ModelService(ModelServiceBase, AsyncModelServiceBase): def __init__(self, model: t.Type[Model]) -> None: self.model = model + # --- Synchronous Methods --- + def get_one(self, pk: t.Any, **kwargs: t.Any) -> t.Any: obj = get_object_or_exception( klass=self.model, error_message=None, exception=NotFound, pk=pk ) return obj - async def get_one_async(self, pk: t.Any, **kwargs: t.Any) -> t.Any: - return await sync_to_async(self.get_one, thread_sensitive=True)(pk, **kwargs) - def get_all(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]: return self.model.objects.all() - async def get_all_async(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]: - return await sync_to_async(self.get_all, thread_sensitive=True)(**kwargs) - def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any: data = schema.model_dump(by_alias=True) data.update(kwargs) @@ -63,9 +85,6 @@ def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any: ) raise TypeError(msg) from tex - async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any: - return await sync_to_async(self.create, thread_sensitive=True)(schema, **kwargs) - def update(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any: data = schema.model_dump(exclude_none=True) data.update(kwargs) @@ -74,23 +93,69 @@ def update(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.A instance.save() return instance + def patch(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any: + return self.update(instance=instance, schema=schema, **kwargs) + + def delete(self, instance: Model, **kwargs: t.Any) -> t.Any: + instance.delete() + + # --- Asynchronous Methods (using native async ORM where possible) --- + + @_async_django_support("get_one") + async def get_one_async(self, pk: t.Any, **kwargs: t.Any) -> t.Any: + obj = await aget_object_or_exception( + klass=self.model, error_message=None, exception=NotFound, pk=pk + ) + return obj + + async def get_all_async(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]: + return await sync_to_async(self.get_all, thread_sensitive=True)(**kwargs) + + @_async_django_support("create") + async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any: + data = schema.model_dump(by_alias=True) + data.update(kwargs) + + try: + instance = await self.model._default_manager.acreate(**data) + return instance + except TypeError as tex: # pragma: no cover + tb = traceback.format_exc() + msg = ( + "Got a `TypeError` when calling `%s.%s.create()`. " + "This may be because you have a writable field on the " + "serializer class that is not a valid argument to " + "`%s.%s.create()`. You may need to make the field " + "read-only, or override the %s.create() method to handle " + "this correctly.\nOriginal exception was:\n %s" + % ( + self.model.__name__, + self.model._default_manager.name, + self.model.__name__, + self.model._default_manager.name, + self.__class__.__name__, + tb, + ) + ) + raise TypeError(msg) from tex + + @_async_django_support("update") async def update_async( self, instance: Model, schema: PydanticModel, **kwargs: t.Any ) -> t.Any: - return await sync_to_async(self.update, thread_sensitive=True)( - instance, schema, **kwargs - ) - - def patch(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any: - return self.update(instance=instance, schema=schema, **kwargs) + data = schema.model_dump(exclude_none=True) + data.update(kwargs) + for attr, value in data.items(): + setattr(instance, attr, value) + await instance.asave() + return instance + @_async_django_support("patch") async def patch_async( self, instance: Model, schema: PydanticModel, **kwargs: t.Any ) -> t.Any: return await self.update_async(instance=instance, schema=schema, **kwargs) - def delete(self, instance: Model, **kwargs: t.Any) -> t.Any: - instance.delete() - + @_async_django_support("delete") async def delete_async(self, instance: Model, **kwargs: t.Any) -> t.Any: - return await sync_to_async(self.delete, thread_sensitive=True)(instance) + await instance.adelete() diff --git a/pyproject.toml b/pyproject.toml index f0411d0..4dcf1d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ requires = [ "contextlib2" ] description-file = "README.md" -requires-python = ">=3.7" +requires-python = ">=3.8" [tool.flit.metadata.urls] diff --git a/tests/test_model_controller/test_async_django_support.py b/tests/test_model_controller/test_async_django_support.py new file mode 100644 index 0000000..47caa7d --- /dev/null +++ b/tests/test_model_controller/test_async_django_support.py @@ -0,0 +1,141 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from pydantic import BaseModel + +from ninja_extra.controllers.model.service import ModelService, _async_django_support +from ninja_extra.exceptions import NotFound + +from ..models import Event + + +class EventTestSchema(BaseModel): + title: str + start_date: str + end_date: str + + +class TestAsyncDjangoSupport: + """Test the _async_django_support decorator functionality""" + + def setup_method(self): + """Set up test fixtures""" + self.service = ModelService(Event) + self.test_schema = EventTestSchema( + title="Test Event", start_date="2020-01-01", end_date="2020-01-02" + ) + + @pytest.mark.asyncio + @patch( + "ninja_extra.controllers.model.service.django_version_greater_than_4_2", False + ) + @patch("ninja_extra.controllers.model.service.sync_to_async") + async def test_async_django_support_with_exception_in_sync_method( + self, mock_sync_to_async + ): + """Test decorator behavior when sync method raises an exception in Django < 4.2""" + + # Mock sync method that raises an exception + sync_method = Mock(side_effect=NotFound("Test error")) + mock_async_wrapper = AsyncMock(side_effect=NotFound("Test error")) + mock_sync_to_async.return_value = mock_async_wrapper + + # Add the sync method to the service instance + self.service.test_sync_method = sync_method + + # Create a dummy async method + async_method = AsyncMock() + + # Apply the decorator + decorated_method = _async_django_support("test_sync_method")(async_method) + + # Verify the exception is propagated + with pytest.raises(NotFound, match="Test error"): + await decorated_method(self.service, "test_arg") + + @pytest.mark.asyncio + @patch( + "ninja_extra.controllers.model.service.django_version_greater_than_4_2", True + ) + async def test_async_django_support_with_exception_in_async_method(self): + """Test decorator behavior when async method raises an exception in Django >= 4.2""" + + # Mock an async method that raises an exception + original_async_method = AsyncMock(side_effect=NotFound("Async test error")) + + # Apply the decorator + decorated_method = _async_django_support("sync_method_name")( + original_async_method + ) + + # Verify the exception is propagated + with pytest.raises(NotFound, match="Async test error"): + await decorated_method(self.service, "test_arg") + + @pytest.mark.asyncio + @pytest.mark.django_db + async def test_actual_model_service_methods_with_django_4_2_true(self): + """Integration test with actual ModelService methods when Django >= 4.2""" + + with patch( + "ninja_extra.controllers.model.service.django_version_greater_than_4_2", + True, + ): + # Test create_async + event = await self.service.create_async(self.test_schema) + assert event.title == "Test Event" + assert str(event.start_date) == "2020-01-01" + + # Test get_one_async + retrieved_event = await self.service.get_one_async(event.pk) + assert retrieved_event.id == event.id + assert retrieved_event.title == "Test Event" + + # Test update_async + update_schema = EventTestSchema( + title="Updated Event", start_date="2020-01-01", end_date="2020-01-02" + ) + updated_event = await self.service.update_async(event, update_schema) + assert updated_event.title == "Updated Event" + + # Test delete_async + await self.service.delete_async(event) + + # Verify event is deleted + with pytest.raises(NotFound): + await self.service.get_one_async(event.pk) + + @pytest.mark.asyncio + @pytest.mark.django_db + async def test_actual_model_service_methods_with_django_4_2_false(self): + """Integration test with actual ModelService methods when Django < 4.2""" + + with patch( + "ninja_extra.controllers.model.service.django_version_greater_than_4_2", + False, + ): + # Test create_async (should fall back to sync method) + event = await self.service.create_async(self.test_schema) + assert event.title == "Test Event" + assert str(event.start_date) == "2020-01-01" + + # Test get_one_async (should fall back to sync method) + retrieved_event = await self.service.get_one_async(event.pk) + assert retrieved_event.id == event.id + assert retrieved_event.title == "Test Event" + + # Test update_async (should fall back to sync method) + update_schema = EventTestSchema( + title="Updated Event Sync", + start_date="2020-01-01", + end_date="2020-01-02", + ) + updated_event = await self.service.update_async(event, update_schema) + assert updated_event.title == "Updated Event Sync" + + # Test delete_async (should fall back to sync method) + await self.service.delete_async(event) + + # Verify event is deleted + with pytest.raises(NotFound): + await self.service.get_one_async(event.pk)