Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 31 additions & 38 deletions api/subscriptions/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
RegistrationProvider,
AbstractProvider,
AbstractNode,
Preprint,
OSFUser,
)
from osf.models.notification_type import NotificationType
from osf.models.notification_subscription import NotificationSubscription
Expand All @@ -48,10 +46,6 @@ def get_queryset(self):
user_guid = self.request.user._id
provider_ct = ContentType.objects.get(app_label='osf', model='abstractprovider')

provider_subquery = AbstractProvider.objects.filter(
id=Cast(OuterRef('object_id'), IntegerField()),
).values('_id')[:1]

node_subquery = AbstractNode.objects.filter(
id=Cast(OuterRef('object_id'), IntegerField()),
).values('guids___id')[:1]
Expand All @@ -66,17 +60,17 @@ def get_queryset(self):
).annotate(
event_name=Case(
When(
notification_type=NotificationType.Type.USER_FILE_UPDATED.instance,
notification_type=NotificationType.Type.NODE_FILE_UPDATED.instance,
then=Value('files_updated'),
),
When(
notification_type=NotificationType.Type.USER_FILE_UPDATED.instance,
then=Value(f'{user_guid}_global_file_updated'),
then=Value('global_file_updated'),
),
When(
Q(notification_type=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.instance) &
Q(content_type=provider_ct),
then=Value('new_pending_submissions'),
then=Value('global_reviews'),
),
),
legacy_id=Case(
Expand All @@ -91,7 +85,7 @@ def get_queryset(self):
When(
Q(notification_type=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.instance) &
Q(content_type=provider_ct),
then=Concat(Subquery(provider_subquery), Value('_new_pending_submissions')),
then=Value(f'{user_guid}_global_reviews'),
),
),
)
Expand Down Expand Up @@ -133,41 +127,17 @@ def get_object(self):
provider_ct = ContentType.objects.get(app_label='osf', model='abstractprovider')
node_ct = ContentType.objects.get(app_label='osf', model='abstractnode')

provider_subquery = AbstractProvider.objects.filter(
id=Cast(OuterRef('object_id'), IntegerField()),
).values('_id')[:1]

node_subquery = AbstractNode.objects.filter(
id=Cast(OuterRef('object_id'), IntegerField()),
).values('guids___id')[:1]

guid_id, *event_parts = subscription_id.split('_')
event = '_'.join(event_parts) if event_parts else ''

subscription_obj = AbstractNode.load(guid_id) or Preprint.load(guid_id) or OSFUser.load(guid_id)

if event != 'global':
if subscription_obj is None:
subscription_obj = PreprintProvider.objects.get(_id=guid_id)
obj_filter = Q(
object_id=getattr(subscription_obj, 'id', None),
content_type=ContentType.objects.get_for_model(subscription_obj.__class__),
notification_type__in=[
NotificationType.Type.USER_FILE_UPDATED.instance,
NotificationType.Type.NODE_FILE_UPDATED.instance,
NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.instance,
],
)
else:
obj_filter = Q()

try:
obj = NotificationSubscription.objects.annotate(
annotated_obj_qs = NotificationSubscription.objects.filter(user=self.request.user).annotate(
legacy_id=Case(
When(
notification_type__name=NotificationType.Type.NODE_FILE_UPDATED.value,
content_type=node_ct,
then=Concat(Subquery(node_subquery), Value('_file_updated')),
then=Concat(Subquery(node_subquery), Value('_files_updated')),
),
When(
notification_type__name=NotificationType.Type.USER_FILE_UPDATED.value,
Expand All @@ -176,12 +146,13 @@ def get_object(self):
When(
notification_type__name=NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.value,
content_type=provider_ct,
then=Concat(Subquery(provider_subquery), Value('_new_pending_submissions')),
then=Value(f'{user_guid}_global_reviews'),
),
default=Value(f'{user_guid}_global'),
output_field=CharField(),
),
).filter(obj_filter)
)
obj = annotated_obj_qs.filter(legacy_id=subscription_id)

except ObjectDoesNotExist:
raise NotFound
Expand All @@ -194,6 +165,28 @@ def get_object(self):
self.check_object_permissions(self.request, obj)
return obj

def update(self, request, *args, **kwargs):
"""
Update a notification subscription
"""
ret = super().update(request, *args, **kwargs)
# Copy global_reviews subscription changes to new_pending_submissions subscriptions [ENG-9666]
if self.get_object().notification_type.name == NotificationType.Type.PROVIDER_NEW_PENDING_SUBMISSIONS.value:
qs = NotificationSubscription.objects.filter(
user=self.request.user,
notification_type__name__in=[
NotificationType.Type.PROVIDER_REVIEWS_SUBMISSION_CONFIRMATION.value,
NotificationType.Type.PROVIDER_REVIEWS_RESUBMISSION_CONFIRMATION.value,
NotificationType.Type.PROVIDER_NEW_PENDING_WITHDRAW_REQUESTS.value,
NotificationType.Type.REVIEWS_SUBMISSION_STATUS.value,
],
)
for instance in qs:
serializer = self.get_serializer(instance=instance, data=request.data, partial=True)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
return ret


class AbstractProviderSubscriptionDetail(SubscriptionDetail):
view_name = 'provider-notification-subscription-detail'
Expand Down
8 changes: 5 additions & 3 deletions api_tests/subscriptions/views/test_subscriptions_detail.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def notification(self, user):
)

@pytest.fixture()
def url(self, notification):
return f'/{API_BASE}subscriptions/{notification._id}/'
def url(self, user):
return f'/{API_BASE}subscriptions/{user._id}_global_file_updated/'

@pytest.fixture()
def url_invalid(self):
Expand Down Expand Up @@ -119,8 +119,9 @@ def test_subscription_detail_invalid_payload_400(
url,
payload_invalid,
auth=user.auth,
expect_errors=True
expect_errors=True,
)

assert res.status_code == 400
assert res.json['errors'][0]['detail'] == ('"invalid-frequency" is not a valid choice.')

Expand Down Expand Up @@ -151,6 +152,7 @@ def test_subscription_detail_patch_no_user(
def test_subscription_detail_patch(
self, app, user, user_no_auth, notification, url, url_invalid, payload, payload_invalid
):

res = app.patch_json_api(url, payload, auth=user.auth)
assert res.status_code == 200
assert res.json['data']['attributes']['frequency'] == 'none'
4 changes: 2 additions & 2 deletions api_tests/subscriptions/views/test_subscriptions_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def test_list_complete(
# There should only be 3 notifications: users' global, node's file updates and provider's preprint added.
assert len(notification_ids) == 3
assert f'{user._id}_global_file_updated' in notification_ids
assert f'{provider._id}_new_pending_submissions' in notification_ids
assert f'{user._id}_global_reviews' in notification_ids
assert f'{node._id}_file_updated' in notification_ids

def test_unauthenticated(self, app, url):
Expand Down Expand Up @@ -122,5 +122,5 @@ def test_value_filter_id(

# Confirm it’s the expected subscription object
attributes = data[0]['attributes']
assert attributes['event_name'] is None # event names are legacy
assert attributes['event_name'] == 'files_updated' # event names are legacy
assert attributes['frequency'] in ['instantly', 'daily', 'none']
Loading