diff --git a/netbox/ipam/graphql/filters.py b/netbox/ipam/graphql/filters.py index 50509d6a2b..b8326cc43a 100644 --- a/netbox/ipam/graphql/filters.py +++ b/netbox/ipam/graphql/filters.py @@ -79,12 +79,28 @@ class ASNRangeFilter(TenancyFilterMixin, OrganizationalModelFilterMixin): @strawberry_django.filter_type(models.Aggregate, lookups=True) class AggregateFilter(ContactFilterMixin, TenancyFilterMixin, PrimaryModelFilterMixin): - prefix: Annotated['PrefixFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field() - prefix_id: ID | None = strawberry_django.filter_field() + prefix: FilterLookup[str] | None = strawberry_django.filter_field() rir: Annotated['RIRFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field() rir_id: ID | None = strawberry_django.filter_field() date_added: DateFilterLookup[date] | None = strawberry_django.filter_field() + @strawberry_django.filter_field() + def contains(self, value: list[str], prefix) -> Q: + """ + Return aggregates whose `prefix` contains any of the supplied networks. + Mirrors PrefixFilter.contains but operates on the Aggregate.prefix field itself. + """ + if not value: + return Q() + q = Q() + for subnet in value: + try: + query = str(netaddr.IPNetwork(subnet.strip()).cidr) + except (AddrFormatError, ValueError): + continue + q |= Q(**{f"{prefix}prefix__net_contains": query}) + return q + @strawberry_django.filter_type(models.FHRPGroup, lookups=True) class FHRPGroupFilter(PrimaryModelFilterMixin): @@ -119,28 +135,28 @@ class FHRPGroupAssignmentFilter(BaseObjectTypeFilterMixin, ChangeLogFilterMixin) ) @strawberry_django.filter_field() - def device_id(self, queryset, value: list[str], prefix) -> Q: - return self.filter_device('id', value) + def device_id(self, value: list[str], prefix) -> Q: + return self.filter_device('id', value, prefix) @strawberry_django.filter_field() def device(self, value: list[str], prefix) -> Q: - return self.filter_device('name', value) + return self.filter_device('name', value, prefix) @strawberry_django.filter_field() def virtual_machine_id(self, value: list[str], prefix) -> Q: - return Q(interface_id__in=VMInterface.objects.filter(virtual_machine_id__in=value)) + return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine_id__in=value)}) @strawberry_django.filter_field() def virtual_machine(self, value: list[str], prefix) -> Q: - return Q(interface_id__in=VMInterface.objects.filter(virtual_machine__name__in=value)) + return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine__name__in=value)}) - def filter_device(self, field, value) -> Q: + def filter_device(self, field, value, prefix) -> Q: """Helper to standardize logic for device and device_id filters""" devices = Device.objects.filter(**{f'{field}__in': value}) interface_ids = [] for device in devices: interface_ids.extend(device.vc_interfaces().values_list('id', flat=True)) - return Q(interface_id__in=interface_ids) + return Q(**{f"{prefix}interface_id__in": interface_ids}) @strawberry_django.filter_type(models.IPAddress, lookups=True) @@ -180,9 +196,9 @@ def parent(self, value: list[str], prefix) -> Q: for subnet in value: try: query = str(netaddr.IPNetwork(subnet.strip()).cidr) - q |= Q(address__net_host_contained=query) except (AddrFormatError, ValueError): - return Q() + continue + q |= Q(**{f"{prefix}address__net_host_contained": query}) return q @strawberry_django.filter_field() @@ -217,9 +233,14 @@ def parent(self, value: list[str], prefix) -> Q: for subnet in value: try: query = str(netaddr.IPNetwork(subnet.strip()).cidr) - q |= Q(start_address__net_host_contained=query, end_address__net_host_contained=query) except (AddrFormatError, ValueError): - return Q() + continue + q |= Q( + **{ + f"{prefix}start_address__net_host_contained": query, + f"{prefix}end_address__net_host_contained": query, + } + ) return q @strawberry_django.filter_field() @@ -228,10 +249,17 @@ def contains(self, value: list[str], prefix) -> Q: return Q() q = Q() for subnet in value: - net = netaddr.IPNetwork(subnet.strip()) + try: + net = netaddr.IPNetwork(subnet.strip()) + query_start = str(netaddr.IPAddress(net.first)) + query_end = str(netaddr.IPAddress(net.last)) + except (AddrFormatError, ValueError): + continue q |= Q( - start_address__host__inet__lte=str(netaddr.IPAddress(net.first)), - end_address__host__inet__gte=str(netaddr.IPAddress(net.last)), + **{ + f"{prefix}start_address__host__inet__lte": query_start, + f"{prefix}end_address__host__inet__gte": query_end, + } ) return q @@ -257,8 +285,11 @@ def contains(self, value: list[str], prefix) -> Q: return Q() q = Q() for subnet in value: - query = str(netaddr.IPNetwork(subnet.strip()).cidr) - q |= Q(prefix__net_contains=query) + try: + query = str(netaddr.IPNetwork(subnet.strip()).cidr) + except (AddrFormatError, ValueError): + continue + q |= Q(**{f"{prefix}prefix__net_contains": query}) return q diff --git a/netbox/ipam/tests/test_api.py b/netbox/ipam/tests/test_api.py index 16a24b7732..f36b2a96b4 100644 --- a/netbox/ipam/tests/test_api.py +++ b/netbox/ipam/tests/test_api.py @@ -323,6 +323,55 @@ def setUpTestData(cls): }, ] + @tag('regression') + def test_graphql_aggregate_prefix_exact(self): + """ + Test case to verify aggregate prefix equality via field lookup in GraphQL API. + """ + + self.add_permissions('ipam.view_aggregate', 'ipam.view_rir') + + rir = RIR.objects.create(name='RFC6598', slug='rfc6598', is_private=True) + aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir) + Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir) + + url = reverse('graphql') + query = """{ + aggregate_list(filters: { prefix: { exact: "100.64.0.0/10" } }) { prefix } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + prefixes = {row['prefix'] for row in data['data']['aggregate_list']} + self.assertIn(str(aggregate1.prefix), prefixes) + + @tag('regression') + def test_graphql_aggregate_contains_skips_invalid(self): + """ + Test the GraphQL API Aggregate `contains` filter skips invalid input. + """ + + self.add_permissions('ipam.view_aggregate', 'ipam.view_rir') + + rir = RIR.objects.create(name='RIR 3', slug='rir-3', is_private=False) + aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir) + Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir) + + url = reverse('graphql') + query = """{ + aggregate_list(filters: { contains: ["100.64.16.0/24", "not-a-cidr", ""] }) { prefix } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + prefixes = {row['prefix'] for row in data['data']['aggregate_list']} + self.assertIn(str(aggregate1.prefix), prefixes) + # No exception occurred; invalid entries were ignored + class RoleTest(APIViewTestCases.APIViewTestCase): model = Role @@ -546,6 +595,30 @@ def test_create_multiple_available_ips(self): self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertEqual(len(response.data), 8) + @tag('regression') + def test_graphql_tenant_prefixes_contains_nested_skips_invalid(self): + """ + Test the GraphQL API Tenant nested Prefix `contains` filter skips invalid input. + """ + + self.add_permissions('ipam.view_prefix', 'ipam.view_vrf', 'tenancy.view_tenant') + + tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1') + vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1') + Prefix.objects.create(prefix='10.20.0.0/16', vrf=vrf, tenant=tenant) + Prefix.objects.create(prefix='198.51.100.0/24', vrf=vrf) # non-tenant + + url = reverse('graphql') + query = """{ + tenant_list(filters: { prefixes: { contains: ["10.20.1.0/24", "not-a-cidr"] } }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + self.assertTrue(data['data']['tenant_list']) # tenant returned + class IPRangeTest(APIViewTestCases.APIViewTestCase): model = IPRange @@ -645,6 +718,65 @@ def test_create_multiple_available_ips(self): self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertEqual(len(response.data), 8) + @tag('regression') + def test_graphql_tenant_ip_ranges_parent_nested_skips_invalid(self): + """ + Test the GraphQL API Tenant nested IP Range `parent` filter skips invalid input. + """ + + self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf') + + tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1') + vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1') + IPRange.objects.create( + start_address=IPNetwork('10.30.0.1/24'), end_address=IPNetwork('10.30.0.255/24'), vrf=vrf, tenant=tenant + ) + IPRange.objects.create( + start_address=IPNetwork('10.31.0.1/24'), end_address=IPNetwork('10.31.0.255/24'), vrf=vrf, tenant=tenant + ) + + url = reverse('graphql') + query = """{ + tenant_list(filters: { + name: { exact: "Tenant 1" } + ip_ranges: { parent: ["10.30.0.0/24", "bogus"] } + }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + self.assertTrue(data['data']['tenant_list']) # tenant returned + # No exception occurred; invalid entries were ignored + + @tag('regression') + def test_graphql_tenant_ip_ranges_contains_nested_skips_invalid(self): + """ + Test the GraphQL API Tenant nested IP Range `contains` filter skips invalid input. + """ + + self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf') + + tenant = Tenant.objects.create(name='Tenant 2', slug='tenant-2') + vrf = VRF.objects.create(name='Test VRF 1', rd='64512:2') + IPRange.objects.create( + start_address=IPNetwork('10.40.0.1/24'), end_address=IPNetwork('10.40.0.255/24'), vrf=vrf, tenant=tenant + ) + + url = reverse('graphql') + query = """{ + tenant_list(filters: { + name: { exact: "Tenant 2" } + ip_ranges: { contains: ["10.40.0.128/25", "###"] } + }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + self.assertTrue(data['data']['tenant_list']) # tenant returned + # No exception occurred; invalid entries were ignored + class IPAddressTest(APIViewTestCases.APIViewTestCase): model = IPAddress @@ -731,6 +863,75 @@ def test_assign_object(self): response = self.client.patch(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) + @tag('regression') + def test_graphql_device_primary_ip4_assigned_nested(self): + """ + Test the GraphQL API Device nested IP Address `primary_ip4` filter. + """ + + self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress') + + site = Site.objects.create(name='Site 1') + manufacturer = Manufacturer.objects.create(name='Manufacturer 1') + device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer) + role = DeviceRole.objects.create(name='Switch') + + device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active') + interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset') + ip1 = IPAddress.objects.create(address='10.0.0.1/24') + ip1.assigned_object = interface1 + ip1.save() + device1.primary_ip4 = ip1 + device1.save() + + device2 = Device.objects.create(name='Device 2', site=site, device_type=device_type, role=role, status='active') + + url = reverse('graphql') + query = """{ + device_list(filters: { primary_ip4: { assigned: true } }) { id name } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + ids = {row['id'] for row in data['data']['device_list']} + self.assertIn(str(device1.pk), ids) + self.assertNotIn(str(device2.pk), ids) + + @tag('regression') + def test_graphql_device_primary_ip4_parent_nested_skips_invalid(self): + """ + Test the GraphQL API Device nested IP Address `parent` filter skips invalid input. + """ + + self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress') + + site = Site.objects.create(name='Site 1') + manufacturer = Manufacturer.objects.create(name='Manufacturer 1') + device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer) + role = DeviceRole.objects.create(name='Switch') + + device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active') + interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset') + ip1 = IPAddress.objects.create(address='192.0.2.10/24') + ip1.assigned_object = interface1 + ip1.save() + device1.primary_ip4 = ip1 + device1.save() + + url = reverse('graphql') + query = """{ + device_list(filters: { primary_ip4: { parent: ["192.0.2.0/24", "bad-cidr"] } }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + ids = {row['id'] for row in data['data']['device_list']} + self.assertIn(str(device1.pk), ids) + class FHRPGroupTest(APIViewTestCases.APIViewTestCase): model = FHRPGroup