Skip to content

Commit 8f52ccc

Browse files
committed
fixed failing tests
1 parent 175ea9e commit 8f52ccc

File tree

2 files changed

+173
-1
lines changed

2 files changed

+173
-1
lines changed

ninja_extra/controllers/model/service.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import traceback
22
import typing as t
3+
from functools import wraps
34

5+
import django
46
from asgiref.sync import sync_to_async
57
from django.db.models import Model, QuerySet
68
from pydantic import BaseModel as PydanticModel
@@ -10,6 +12,30 @@
1012

1113
from .interfaces import AsyncModelServiceBase, ModelServiceBase
1214

15+
django_version_greater_than_4_2 = django.VERSION >= (4, 2)
16+
17+
18+
def _async_django_support(sync_method_name: str) -> t.Callable[..., t.Any]:
19+
"""
20+
Ensures that django version supports async orm methods.
21+
If not, it will use sync_to_async to call the sync method with thread_sensitive=True.
22+
"""
23+
24+
def decorator(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Coroutine]:
25+
@wraps(func)
26+
async def wrapper(self: "ModelService", *args: t.Any, **kwargs: t.Any) -> t.Any:
27+
if not django_version_greater_than_4_2:
28+
alternate_method = getattr(self, sync_method_name)
29+
return await sync_to_async(alternate_method, thread_sensitive=True)(
30+
*args, **kwargs
31+
)
32+
33+
return await func(self, *args, **kwargs)
34+
35+
return wrapper
36+
37+
return decorator
38+
1339

1440
class ModelService(ModelServiceBase, AsyncModelServiceBase):
1541
"""
@@ -21,7 +47,7 @@ class ModelService(ModelServiceBase, AsyncModelServiceBase):
2147
def __init__(self, model: t.Type[Model]) -> None:
2248
self.model = model
2349

24-
# --- Synchonous Methods ---
50+
# --- Synchronous Methods ---
2551

2652
def get_one(self, pk: t.Any, **kwargs: t.Any) -> t.Any:
2753
obj = get_object_or_exception(
@@ -75,6 +101,7 @@ def delete(self, instance: Model, **kwargs: t.Any) -> t.Any:
75101

76102
# --- Asynchronous Methods (using native async ORM where possible) ---
77103

104+
@_async_django_support("get_one")
78105
async def get_one_async(self, pk: t.Any, **kwargs: t.Any) -> t.Any:
79106
obj = await aget_object_or_exception(
80107
klass=self.model, error_message=None, exception=NotFound, pk=pk
@@ -84,6 +111,7 @@ async def get_one_async(self, pk: t.Any, **kwargs: t.Any) -> t.Any:
84111
async def get_all_async(self, **kwargs: t.Any) -> t.Union[QuerySet, t.List[t.Any]]:
85112
return await sync_to_async(self.get_all, thread_sensitive=True)(**kwargs)
86113

114+
@_async_django_support("create")
87115
async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
88116
data = schema.model_dump(by_alias=True)
89117
data.update(kwargs)
@@ -111,6 +139,7 @@ async def create_async(self, schema: PydanticModel, **kwargs: t.Any) -> t.Any:
111139
)
112140
raise TypeError(msg) from tex
113141

142+
@_async_django_support("update")
114143
async def update_async(
115144
self, instance: Model, schema: PydanticModel, **kwargs: t.Any
116145
) -> t.Any:
@@ -121,10 +150,12 @@ async def update_async(
121150
await instance.asave()
122151
return instance
123152

153+
@_async_django_support("patch")
124154
async def patch_async(
125155
self, instance: Model, schema: PydanticModel, **kwargs: t.Any
126156
) -> t.Any:
127157
return await self.update_async(instance=instance, schema=schema, **kwargs)
128158

159+
@_async_django_support("delete")
129160
async def delete_async(self, instance: Model, **kwargs: t.Any) -> t.Any:
130161
await instance.adelete()
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
from unittest.mock import AsyncMock, Mock, patch
2+
3+
import pytest
4+
from pydantic import BaseModel
5+
6+
from ninja_extra.controllers.model.service import ModelService, _async_django_support
7+
from ninja_extra.exceptions import NotFound
8+
9+
from ..models import Event
10+
11+
12+
class EventTestSchema(BaseModel):
13+
title: str
14+
start_date: str
15+
end_date: str
16+
17+
18+
class TestAsyncDjangoSupport:
19+
"""Test the _async_django_support decorator functionality"""
20+
21+
def setup_method(self):
22+
"""Set up test fixtures"""
23+
self.service = ModelService(Event)
24+
self.test_schema = EventTestSchema(
25+
title="Test Event", start_date="2020-01-01", end_date="2020-01-02"
26+
)
27+
28+
@pytest.mark.asyncio
29+
@patch(
30+
"ninja_extra.controllers.model.service.django_version_greater_than_4_2", False
31+
)
32+
@patch("ninja_extra.controllers.model.service.sync_to_async")
33+
async def test_async_django_support_with_exception_in_sync_method(
34+
self, mock_sync_to_async
35+
):
36+
"""Test decorator behavior when sync method raises an exception in Django < 4.2"""
37+
38+
# Mock sync method that raises an exception
39+
sync_method = Mock(side_effect=NotFound("Test error"))
40+
mock_async_wrapper = AsyncMock(side_effect=NotFound("Test error"))
41+
mock_sync_to_async.return_value = mock_async_wrapper
42+
43+
# Add the sync method to the service instance
44+
self.service.test_sync_method = sync_method
45+
46+
# Create a dummy async method
47+
async_method = AsyncMock()
48+
49+
# Apply the decorator
50+
decorated_method = _async_django_support("test_sync_method")(async_method)
51+
52+
# Verify the exception is propagated
53+
with pytest.raises(NotFound, match="Test error"):
54+
await decorated_method(self.service, "test_arg")
55+
56+
@pytest.mark.asyncio
57+
@patch(
58+
"ninja_extra.controllers.model.service.django_version_greater_than_4_2", True
59+
)
60+
async def test_async_django_support_with_exception_in_async_method(self):
61+
"""Test decorator behavior when async method raises an exception in Django >= 4.2"""
62+
63+
# Mock an async method that raises an exception
64+
original_async_method = AsyncMock(side_effect=NotFound("Async test error"))
65+
66+
# Apply the decorator
67+
decorated_method = _async_django_support("sync_method_name")(
68+
original_async_method
69+
)
70+
71+
# Verify the exception is propagated
72+
with pytest.raises(NotFound, match="Async test error"):
73+
await decorated_method(self.service, "test_arg")
74+
75+
@pytest.mark.asyncio
76+
@pytest.mark.django_db
77+
async def test_actual_model_service_methods_with_django_4_2_true(self):
78+
"""Integration test with actual ModelService methods when Django >= 4.2"""
79+
80+
with patch(
81+
"ninja_extra.controllers.model.service.django_version_greater_than_4_2",
82+
True,
83+
):
84+
# Test create_async
85+
event = await self.service.create_async(self.test_schema)
86+
assert event.title == "Test Event"
87+
assert str(event.start_date) == "2020-01-01"
88+
89+
# Test get_one_async
90+
retrieved_event = await self.service.get_one_async(event.pk)
91+
assert retrieved_event.id == event.id
92+
assert retrieved_event.title == "Test Event"
93+
94+
# Test update_async
95+
update_schema = EventTestSchema(
96+
title="Updated Event", start_date="2020-01-01", end_date="2020-01-02"
97+
)
98+
updated_event = await self.service.update_async(event, update_schema)
99+
assert updated_event.title == "Updated Event"
100+
101+
# Test delete_async
102+
await self.service.delete_async(event)
103+
104+
# Verify event is deleted
105+
with pytest.raises(NotFound):
106+
await self.service.get_one_async(event.pk)
107+
108+
@pytest.mark.asyncio
109+
@pytest.mark.django_db
110+
async def test_actual_model_service_methods_with_django_4_2_false(self):
111+
"""Integration test with actual ModelService methods when Django < 4.2"""
112+
113+
with patch(
114+
"ninja_extra.controllers.model.service.django_version_greater_than_4_2",
115+
False,
116+
):
117+
# Test create_async (should fall back to sync method)
118+
event = await self.service.create_async(self.test_schema)
119+
assert event.title == "Test Event"
120+
assert str(event.start_date) == "2020-01-01"
121+
122+
# Test get_one_async (should fall back to sync method)
123+
retrieved_event = await self.service.get_one_async(event.pk)
124+
assert retrieved_event.id == event.id
125+
assert retrieved_event.title == "Test Event"
126+
127+
# Test update_async (should fall back to sync method)
128+
update_schema = EventTestSchema(
129+
title="Updated Event Sync",
130+
start_date="2020-01-01",
131+
end_date="2020-01-02",
132+
)
133+
updated_event = await self.service.update_async(event, update_schema)
134+
assert updated_event.title == "Updated Event Sync"
135+
136+
# Test delete_async (should fall back to sync method)
137+
await self.service.delete_async(event)
138+
139+
# Verify event is deleted
140+
with pytest.raises(NotFound):
141+
await self.service.get_one_async(event.pk)

0 commit comments

Comments
 (0)