From b959b37cdb3e1c00220b2e76e0644c123e030e9d Mon Sep 17 00:00:00 2001 From: Martin Hauser Date: Wed, 22 Oct 2025 11:27:46 +0200 Subject: [PATCH] feat(ipam): Enhance filter methods with dynamic prefixing Refactors filter methods to enable dynamic query prefixing for greater flexibility in nested queries. Improves error handling for invalid network inputs and ensures consistency across IPAM filter implementations. Fixes #20541 --- netbox/ipam/graphql/filters.py | 67 ++++++++--- netbox/ipam/tests/test_api.py | 201 +++++++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 18 deletions(-) diff --git a/netbox/ipam/graphql/filters.py b/netbox/ipam/graphql/filters.py index 50509d6a2be..b8326cc43ac 100644 --- a/netbox/ipam/graphql/filters.py +++ b/netbox/ipam/graphql/filters.py @@ -79,12 +79,28 @@ class ASNRangeFilter(TenancyFilterMixin, OrganizationalModelFilterMixin): @strawberry_django.filter_type(models.Aggregate, lookups=True) class AggregateFilter(ContactFilterMixin, TenancyFilterMixin, PrimaryModelFilterMixin): - prefix: Annotated['PrefixFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field() - prefix_id: ID | None = strawberry_django.filter_field() + prefix: FilterLookup[str] | None = strawberry_django.filter_field() rir: Annotated['RIRFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field() rir_id: ID | None = strawberry_django.filter_field() date_added: DateFilterLookup[date] | None = strawberry_django.filter_field() + @strawberry_django.filter_field() + def contains(self, value: list[str], prefix) -> Q: + """ + Return aggregates whose `prefix` contains any of the supplied networks. + Mirrors PrefixFilter.contains but operates on the Aggregate.prefix field itself. + """ + if not value: + return Q() + q = Q() + for subnet in value: + try: + query = str(netaddr.IPNetwork(subnet.strip()).cidr) + except (AddrFormatError, ValueError): + continue + q |= Q(**{f"{prefix}prefix__net_contains": query}) + return q + @strawberry_django.filter_type(models.FHRPGroup, lookups=True) class FHRPGroupFilter(PrimaryModelFilterMixin): @@ -119,28 +135,28 @@ class FHRPGroupAssignmentFilter(BaseObjectTypeFilterMixin, ChangeLogFilterMixin) ) @strawberry_django.filter_field() - def device_id(self, queryset, value: list[str], prefix) -> Q: - return self.filter_device('id', value) + def device_id(self, value: list[str], prefix) -> Q: + return self.filter_device('id', value, prefix) @strawberry_django.filter_field() def device(self, value: list[str], prefix) -> Q: - return self.filter_device('name', value) + return self.filter_device('name', value, prefix) @strawberry_django.filter_field() def virtual_machine_id(self, value: list[str], prefix) -> Q: - return Q(interface_id__in=VMInterface.objects.filter(virtual_machine_id__in=value)) + return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine_id__in=value)}) @strawberry_django.filter_field() def virtual_machine(self, value: list[str], prefix) -> Q: - return Q(interface_id__in=VMInterface.objects.filter(virtual_machine__name__in=value)) + return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine__name__in=value)}) - def filter_device(self, field, value) -> Q: + def filter_device(self, field, value, prefix) -> Q: """Helper to standardize logic for device and device_id filters""" devices = Device.objects.filter(**{f'{field}__in': value}) interface_ids = [] for device in devices: interface_ids.extend(device.vc_interfaces().values_list('id', flat=True)) - return Q(interface_id__in=interface_ids) + return Q(**{f"{prefix}interface_id__in": interface_ids}) @strawberry_django.filter_type(models.IPAddress, lookups=True) @@ -180,9 +196,9 @@ def parent(self, value: list[str], prefix) -> Q: for subnet in value: try: query = str(netaddr.IPNetwork(subnet.strip()).cidr) - q |= Q(address__net_host_contained=query) except (AddrFormatError, ValueError): - return Q() + continue + q |= Q(**{f"{prefix}address__net_host_contained": query}) return q @strawberry_django.filter_field() @@ -217,9 +233,14 @@ def parent(self, value: list[str], prefix) -> Q: for subnet in value: try: query = str(netaddr.IPNetwork(subnet.strip()).cidr) - q |= Q(start_address__net_host_contained=query, end_address__net_host_contained=query) except (AddrFormatError, ValueError): - return Q() + continue + q |= Q( + **{ + f"{prefix}start_address__net_host_contained": query, + f"{prefix}end_address__net_host_contained": query, + } + ) return q @strawberry_django.filter_field() @@ -228,10 +249,17 @@ def contains(self, value: list[str], prefix) -> Q: return Q() q = Q() for subnet in value: - net = netaddr.IPNetwork(subnet.strip()) + try: + net = netaddr.IPNetwork(subnet.strip()) + query_start = str(netaddr.IPAddress(net.first)) + query_end = str(netaddr.IPAddress(net.last)) + except (AddrFormatError, ValueError): + continue q |= Q( - start_address__host__inet__lte=str(netaddr.IPAddress(net.first)), - end_address__host__inet__gte=str(netaddr.IPAddress(net.last)), + **{ + f"{prefix}start_address__host__inet__lte": query_start, + f"{prefix}end_address__host__inet__gte": query_end, + } ) return q @@ -257,8 +285,11 @@ def contains(self, value: list[str], prefix) -> Q: return Q() q = Q() for subnet in value: - query = str(netaddr.IPNetwork(subnet.strip()).cidr) - q |= Q(prefix__net_contains=query) + try: + query = str(netaddr.IPNetwork(subnet.strip()).cidr) + except (AddrFormatError, ValueError): + continue + q |= Q(**{f"{prefix}prefix__net_contains": query}) return q diff --git a/netbox/ipam/tests/test_api.py b/netbox/ipam/tests/test_api.py index 16a24b77325..f36b2a96b4b 100644 --- a/netbox/ipam/tests/test_api.py +++ b/netbox/ipam/tests/test_api.py @@ -323,6 +323,55 @@ def setUpTestData(cls): }, ] + @tag('regression') + def test_graphql_aggregate_prefix_exact(self): + """ + Test case to verify aggregate prefix equality via field lookup in GraphQL API. + """ + + self.add_permissions('ipam.view_aggregate', 'ipam.view_rir') + + rir = RIR.objects.create(name='RFC6598', slug='rfc6598', is_private=True) + aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir) + Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir) + + url = reverse('graphql') + query = """{ + aggregate_list(filters: { prefix: { exact: "100.64.0.0/10" } }) { prefix } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + prefixes = {row['prefix'] for row in data['data']['aggregate_list']} + self.assertIn(str(aggregate1.prefix), prefixes) + + @tag('regression') + def test_graphql_aggregate_contains_skips_invalid(self): + """ + Test the GraphQL API Aggregate `contains` filter skips invalid input. + """ + + self.add_permissions('ipam.view_aggregate', 'ipam.view_rir') + + rir = RIR.objects.create(name='RIR 3', slug='rir-3', is_private=False) + aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir) + Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir) + + url = reverse('graphql') + query = """{ + aggregate_list(filters: { contains: ["100.64.16.0/24", "not-a-cidr", ""] }) { prefix } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + prefixes = {row['prefix'] for row in data['data']['aggregate_list']} + self.assertIn(str(aggregate1.prefix), prefixes) + # No exception occurred; invalid entries were ignored + class RoleTest(APIViewTestCases.APIViewTestCase): model = Role @@ -546,6 +595,30 @@ def test_create_multiple_available_ips(self): self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertEqual(len(response.data), 8) + @tag('regression') + def test_graphql_tenant_prefixes_contains_nested_skips_invalid(self): + """ + Test the GraphQL API Tenant nested Prefix `contains` filter skips invalid input. + """ + + self.add_permissions('ipam.view_prefix', 'ipam.view_vrf', 'tenancy.view_tenant') + + tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1') + vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1') + Prefix.objects.create(prefix='10.20.0.0/16', vrf=vrf, tenant=tenant) + Prefix.objects.create(prefix='198.51.100.0/24', vrf=vrf) # non-tenant + + url = reverse('graphql') + query = """{ + tenant_list(filters: { prefixes: { contains: ["10.20.1.0/24", "not-a-cidr"] } }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + self.assertTrue(data['data']['tenant_list']) # tenant returned + class IPRangeTest(APIViewTestCases.APIViewTestCase): model = IPRange @@ -645,6 +718,65 @@ def test_create_multiple_available_ips(self): self.assertHttpStatus(response, status.HTTP_201_CREATED) self.assertEqual(len(response.data), 8) + @tag('regression') + def test_graphql_tenant_ip_ranges_parent_nested_skips_invalid(self): + """ + Test the GraphQL API Tenant nested IP Range `parent` filter skips invalid input. + """ + + self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf') + + tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1') + vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1') + IPRange.objects.create( + start_address=IPNetwork('10.30.0.1/24'), end_address=IPNetwork('10.30.0.255/24'), vrf=vrf, tenant=tenant + ) + IPRange.objects.create( + start_address=IPNetwork('10.31.0.1/24'), end_address=IPNetwork('10.31.0.255/24'), vrf=vrf, tenant=tenant + ) + + url = reverse('graphql') + query = """{ + tenant_list(filters: { + name: { exact: "Tenant 1" } + ip_ranges: { parent: ["10.30.0.0/24", "bogus"] } + }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + self.assertTrue(data['data']['tenant_list']) # tenant returned + # No exception occurred; invalid entries were ignored + + @tag('regression') + def test_graphql_tenant_ip_ranges_contains_nested_skips_invalid(self): + """ + Test the GraphQL API Tenant nested IP Range `contains` filter skips invalid input. + """ + + self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf') + + tenant = Tenant.objects.create(name='Tenant 2', slug='tenant-2') + vrf = VRF.objects.create(name='Test VRF 1', rd='64512:2') + IPRange.objects.create( + start_address=IPNetwork('10.40.0.1/24'), end_address=IPNetwork('10.40.0.255/24'), vrf=vrf, tenant=tenant + ) + + url = reverse('graphql') + query = """{ + tenant_list(filters: { + name: { exact: "Tenant 2" } + ip_ranges: { contains: ["10.40.0.128/25", "###"] } + }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + self.assertTrue(data['data']['tenant_list']) # tenant returned + # No exception occurred; invalid entries were ignored + class IPAddressTest(APIViewTestCases.APIViewTestCase): model = IPAddress @@ -731,6 +863,75 @@ def test_assign_object(self): response = self.client.patch(url, data, format='json', **self.header) self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST) + @tag('regression') + def test_graphql_device_primary_ip4_assigned_nested(self): + """ + Test the GraphQL API Device nested IP Address `primary_ip4` filter. + """ + + self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress') + + site = Site.objects.create(name='Site 1') + manufacturer = Manufacturer.objects.create(name='Manufacturer 1') + device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer) + role = DeviceRole.objects.create(name='Switch') + + device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active') + interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset') + ip1 = IPAddress.objects.create(address='10.0.0.1/24') + ip1.assigned_object = interface1 + ip1.save() + device1.primary_ip4 = ip1 + device1.save() + + device2 = Device.objects.create(name='Device 2', site=site, device_type=device_type, role=role, status='active') + + url = reverse('graphql') + query = """{ + device_list(filters: { primary_ip4: { assigned: true } }) { id name } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + ids = {row['id'] for row in data['data']['device_list']} + self.assertIn(str(device1.pk), ids) + self.assertNotIn(str(device2.pk), ids) + + @tag('regression') + def test_graphql_device_primary_ip4_parent_nested_skips_invalid(self): + """ + Test the GraphQL API Device nested IP Address `parent` filter skips invalid input. + """ + + self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress') + + site = Site.objects.create(name='Site 1') + manufacturer = Manufacturer.objects.create(name='Manufacturer 1') + device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer) + role = DeviceRole.objects.create(name='Switch') + + device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active') + interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset') + ip1 = IPAddress.objects.create(address='192.0.2.10/24') + ip1.assigned_object = interface1 + ip1.save() + device1.primary_ip4 = ip1 + device1.save() + + url = reverse('graphql') + query = """{ + device_list(filters: { primary_ip4: { parent: ["192.0.2.0/24", "bad-cidr"] } }) { id } + }""" + response = self.client.post(url, data={'query': query}, format='json', **self.header) + self.assertHttpStatus(response, status.HTTP_200_OK) + data = response.json() + self.assertNotIn('errors', data) + + ids = {row['id'] for row in data['data']['device_list']} + self.assertIn(str(device1.pk), ids) + class FHRPGroupTest(APIViewTestCases.APIViewTestCase): model = FHRPGroup