Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ doc-deploy:clean ## Run Deploy Documentation

doc-serve: ## Launch doc local server
mkdocs serve

4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ Django Ninja Extra is a powerful extension for [Django Ninja](https://django-nin

## Requirements

- Python >= 3.6
- Django >= 2.1
- Python >= 3.8
- Django >= 4.0
- Pydantic >= 1.6
- Django-Ninja >= 0.16.1

Expand Down
105 changes: 85 additions & 20 deletions ninja_extra/controllers/model/service.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,41 @@
import traceback
import typing as t
from functools import wraps

import django
from asgiref.sync import sync_to_async
from django.db.models import Model, QuerySet
from pydantic import BaseModel as PydanticModel

from ninja_extra.exceptions import NotFound
from ninja_extra.shortcuts import get_object_or_exception
from ninja_extra.shortcuts import aget_object_or_exception, get_object_or_exception

from .interfaces import AsyncModelServiceBase, ModelServiceBase

django_version_greater_than_4_2 = django.VERSION > (4, 2)


def _async_django_support(sync_method_name: str) -> t.Callable[..., t.Any]:
"""
Ensures that django version supports async orm methods.
If not, it will use sync_to_async to call the sync method with thread_sensitive=True.
"""

def decorator(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Coroutine]:
@wraps(func)
async def wrapper(self: "ModelService", *args: t.Any, **kwargs: t.Any) -> t.Any:
if not django_version_greater_than_4_2:
alternate_method = getattr(self, sync_method_name)
return await sync_to_async(alternate_method, thread_sensitive=True)(
*args, **kwargs
)

return await func(self, *args, **kwargs)

return wrapper

return decorator


class ModelService(ModelServiceBase, AsyncModelServiceBase):
"""
Expand All @@ -21,21 +47,17 @@ class ModelService(ModelServiceBase, AsyncModelServiceBase):
def __init__(self, model: t.Type[Model]) -> None:
self.model = model

# --- Synchronous Methods ---

def get_one(self, pk: t.Any, **kwargs: t.Any) -> t.Any:
obj = get_object_or_exception(
klass=self.model, error_message=None, exception=NotFound, pk=pk
)
return obj

async def get_one_async(self, pk: t.Any, **kwargs: t.Any) -> t.Any:
return await sync_to_async(self.get_one, thread_sensitive=True)(pk, **kwargs)

def get_all(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]:
return self.model.objects.all()

async def get_all_async(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]:
return await sync_to_async(self.get_all, thread_sensitive=True)(**kwargs)

def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
data = schema.model_dump(by_alias=True)
data.update(kwargs)
Expand Down Expand Up @@ -63,9 +85,6 @@ def create(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
)
raise TypeError(msg) from tex

async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
return await sync_to_async(self.create, thread_sensitive=True)(schema, **kwargs)

def update(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
data = schema.model_dump(exclude_none=True)
data.update(kwargs)
Expand All @@ -74,23 +93,69 @@ def update(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.A
instance.save()
return instance

def patch(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
return self.update(instance=instance, schema=schema, **kwargs)

def delete(self, instance: Model, **kwargs: t.Any) -> t.Any:
instance.delete()

# --- Asynchronous Methods (using native async ORM where possible) ---

@_async_django_support("get_one")
async def get_one_async(self, pk: t.Any, **kwargs: t.Any) -> t.Any:
obj = await aget_object_or_exception(
klass=self.model, error_message=None, exception=NotFound, pk=pk
)
return obj

async def get_all_async(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]:
return await sync_to_async(self.get_all, thread_sensitive=True)(**kwargs)

@_async_django_support("create")
async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
data = schema.model_dump(by_alias=True)
data.update(kwargs)

try:
instance = await self.model._default_manager.acreate(**data)
return instance
except TypeError as tex: # pragma: no cover
tb = traceback.format_exc()
msg = (
"Got a `TypeError` when calling `%s.%s.create()`. "
"This may be because you have a writable field on the "
"serializer class that is not a valid argument to "
"`%s.%s.create()`. You may need to make the field "
"read-only, or override the %s.create() method to handle "
"this correctly.\nOriginal exception was:\n %s"
% (
self.model.__name__,
self.model._default_manager.name,
self.model.__name__,
self.model._default_manager.name,
self.__class__.__name__,
tb,
)
)
raise TypeError(msg) from tex

@_async_django_support("update")
async def update_async(
self, instance: Model, schema: PydanticModel, **kwargs: t.Any
) -> t.Any:
return await sync_to_async(self.update, thread_sensitive=True)(
instance, schema, **kwargs
)

def patch(self, instance: Model, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
return self.update(instance=instance, schema=schema, **kwargs)
data = schema.model_dump(exclude_none=True)
data.update(kwargs)
for attr, value in data.items():
setattr(instance, attr, value)
await instance.asave()
return instance

@_async_django_support("patch")
async def patch_async(
self, instance: Model, schema: PydanticModel, **kwargs: t.Any
) -> t.Any:
return await self.update_async(instance=instance, schema=schema, **kwargs)

def delete(self, instance: Model, **kwargs: t.Any) -> t.Any:
instance.delete()

@_async_django_support("delete")
async def delete_async(self, instance: Model, **kwargs: t.Any) -> t.Any:
return await sync_to_async(self.delete, thread_sensitive=True)(instance)
await instance.adelete()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ requires = [
"contextlib2"
]
description-file = "README.md"
requires-python = ">=3.7"
requires-python = ">=3.8"


[tool.flit.metadata.urls]
Expand Down
141 changes: 141 additions & 0 deletions tests/test_model_controller/test_async_django_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from unittest.mock import AsyncMock, Mock, patch

import pytest
from pydantic import BaseModel

from ninja_extra.controllers.model.service import ModelService, _async_django_support
from ninja_extra.exceptions import NotFound

from ..models import Event


class EventTestSchema(BaseModel):
title: str
start_date: str
end_date: str


class TestAsyncDjangoSupport:
"""Test the _async_django_support decorator functionality"""

def setup_method(self):
"""Set up test fixtures"""
self.service = ModelService(Event)
self.test_schema = EventTestSchema(
title="Test Event", start_date="2020-01-01", end_date="2020-01-02"
)

@pytest.mark.asyncio
@patch(
"ninja_extra.controllers.model.service.django_version_greater_than_4_2", False
)
@patch("ninja_extra.controllers.model.service.sync_to_async")
async def test_async_django_support_with_exception_in_sync_method(
self, mock_sync_to_async
):
"""Test decorator behavior when sync method raises an exception in Django < 4.2"""

# Mock sync method that raises an exception
sync_method = Mock(side_effect=NotFound("Test error"))
mock_async_wrapper = AsyncMock(side_effect=NotFound("Test error"))
mock_sync_to_async.return_value = mock_async_wrapper

# Add the sync method to the service instance
self.service.test_sync_method = sync_method

# Create a dummy async method
async_method = AsyncMock()

# Apply the decorator
decorated_method = _async_django_support("test_sync_method")(async_method)

# Verify the exception is propagated
with pytest.raises(NotFound, match="Test error"):
await decorated_method(self.service, "test_arg")

@pytest.mark.asyncio
@patch(
"ninja_extra.controllers.model.service.django_version_greater_than_4_2", True
)
async def test_async_django_support_with_exception_in_async_method(self):
"""Test decorator behavior when async method raises an exception in Django >= 4.2"""

# Mock an async method that raises an exception
original_async_method = AsyncMock(side_effect=NotFound("Async test error"))

# Apply the decorator
decorated_method = _async_django_support("sync_method_name")(
original_async_method
)

# Verify the exception is propagated
with pytest.raises(NotFound, match="Async test error"):
await decorated_method(self.service, "test_arg")

@pytest.mark.asyncio
@pytest.mark.django_db
async def test_actual_model_service_methods_with_django_4_2_true(self):
"""Integration test with actual ModelService methods when Django >= 4.2"""

with patch(
"ninja_extra.controllers.model.service.django_version_greater_than_4_2",
True,
):
# Test create_async
event = await self.service.create_async(self.test_schema)
assert event.title == "Test Event"
assert str(event.start_date) == "2020-01-01"

# Test get_one_async
retrieved_event = await self.service.get_one_async(event.pk)
assert retrieved_event.id == event.id
assert retrieved_event.title == "Test Event"

# Test update_async
update_schema = EventTestSchema(
title="Updated Event", start_date="2020-01-01", end_date="2020-01-02"
)
updated_event = await self.service.update_async(event, update_schema)
assert updated_event.title == "Updated Event"

# Test delete_async
await self.service.delete_async(event)

# Verify event is deleted
with pytest.raises(NotFound):
await self.service.get_one_async(event.pk)

@pytest.mark.asyncio
@pytest.mark.django_db
async def test_actual_model_service_methods_with_django_4_2_false(self):
"""Integration test with actual ModelService methods when Django < 4.2"""

with patch(
"ninja_extra.controllers.model.service.django_version_greater_than_4_2",
False,
):
# Test create_async (should fall back to sync method)
event = await self.service.create_async(self.test_schema)
assert event.title == "Test Event"
assert str(event.start_date) == "2020-01-01"

# Test get_one_async (should fall back to sync method)
retrieved_event = await self.service.get_one_async(event.pk)
assert retrieved_event.id == event.id
assert retrieved_event.title == "Test Event"

# Test update_async (should fall back to sync method)
update_schema = EventTestSchema(
title="Updated Event Sync",
start_date="2020-01-01",
end_date="2020-01-02",
)
updated_event = await self.service.update_async(event, update_schema)
assert updated_event.title == "Updated Event Sync"

# Test delete_async (should fall back to sync method)
await self.service.delete_async(event)

# Verify event is deleted
with pytest.raises(NotFound):
await self.service.get_one_async(event.pk)
Loading