Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 49 additions & 18 deletions netbox/ipam/graphql/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,28 @@ class ASNRangeFilter(TenancyFilterMixin, OrganizationalModelFilterMixin):

@strawberry_django.filter_type(models.Aggregate, lookups=True)
class AggregateFilter(ContactFilterMixin, TenancyFilterMixin, PrimaryModelFilterMixin):
prefix: Annotated['PrefixFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field()
prefix_id: ID | None = strawberry_django.filter_field()
prefix: FilterLookup[str] | None = strawberry_django.filter_field()
rir: Annotated['RIRFilter', strawberry.lazy('ipam.graphql.filters')] | None = strawberry_django.filter_field()
rir_id: ID | None = strawberry_django.filter_field()
date_added: DateFilterLookup[date] | None = strawberry_django.filter_field()

@strawberry_django.filter_field()
def contains(self, value: list[str], prefix) -> Q:
"""
Return aggregates whose `prefix` contains any of the supplied networks.
Mirrors PrefixFilter.contains but operates on the Aggregate.prefix field itself.
"""
if not value:
return Q()
q = Q()
for subnet in value:
try:
query = str(netaddr.IPNetwork(subnet.strip()).cidr)
except (AddrFormatError, ValueError):
continue
q |= Q(**{f"{prefix}prefix__net_contains": query})
return q


@strawberry_django.filter_type(models.FHRPGroup, lookups=True)
class FHRPGroupFilter(PrimaryModelFilterMixin):
Expand Down Expand Up @@ -119,28 +135,28 @@ class FHRPGroupAssignmentFilter(BaseObjectTypeFilterMixin, ChangeLogFilterMixin)
)

@strawberry_django.filter_field()
def device_id(self, queryset, value: list[str], prefix) -> Q:
return self.filter_device('id', value)
def device_id(self, value: list[str], prefix) -> Q:
return self.filter_device('id', value, prefix)

@strawberry_django.filter_field()
def device(self, value: list[str], prefix) -> Q:
return self.filter_device('name', value)
return self.filter_device('name', value, prefix)

@strawberry_django.filter_field()
def virtual_machine_id(self, value: list[str], prefix) -> Q:
return Q(interface_id__in=VMInterface.objects.filter(virtual_machine_id__in=value))
return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine_id__in=value)})

@strawberry_django.filter_field()
def virtual_machine(self, value: list[str], prefix) -> Q:
return Q(interface_id__in=VMInterface.objects.filter(virtual_machine__name__in=value))
return Q(**{f"{prefix}interface_id__in": VMInterface.objects.filter(virtual_machine__name__in=value)})

def filter_device(self, field, value) -> Q:
def filter_device(self, field, value, prefix) -> Q:
"""Helper to standardize logic for device and device_id filters"""
devices = Device.objects.filter(**{f'{field}__in': value})
interface_ids = []
for device in devices:
interface_ids.extend(device.vc_interfaces().values_list('id', flat=True))
return Q(interface_id__in=interface_ids)
return Q(**{f"{prefix}interface_id__in": interface_ids})


@strawberry_django.filter_type(models.IPAddress, lookups=True)
Expand Down Expand Up @@ -180,9 +196,9 @@ def parent(self, value: list[str], prefix) -> Q:
for subnet in value:
try:
query = str(netaddr.IPNetwork(subnet.strip()).cidr)
q |= Q(address__net_host_contained=query)
except (AddrFormatError, ValueError):
return Q()
continue
q |= Q(**{f"{prefix}address__net_host_contained": query})
return q

@strawberry_django.filter_field()
Expand Down Expand Up @@ -217,9 +233,14 @@ def parent(self, value: list[str], prefix) -> Q:
for subnet in value:
try:
query = str(netaddr.IPNetwork(subnet.strip()).cidr)
q |= Q(start_address__net_host_contained=query, end_address__net_host_contained=query)
except (AddrFormatError, ValueError):
return Q()
continue
q |= Q(
**{
f"{prefix}start_address__net_host_contained": query,
f"{prefix}end_address__net_host_contained": query,
}
)
return q

@strawberry_django.filter_field()
Expand All @@ -228,10 +249,17 @@ def contains(self, value: list[str], prefix) -> Q:
return Q()
q = Q()
for subnet in value:
net = netaddr.IPNetwork(subnet.strip())
try:
net = netaddr.IPNetwork(subnet.strip())
query_start = str(netaddr.IPAddress(net.first))
query_end = str(netaddr.IPAddress(net.last))
except (AddrFormatError, ValueError):
continue
q |= Q(
start_address__host__inet__lte=str(netaddr.IPAddress(net.first)),
end_address__host__inet__gte=str(netaddr.IPAddress(net.last)),
**{
f"{prefix}start_address__host__inet__lte": query_start,
f"{prefix}end_address__host__inet__gte": query_end,
}
)
return q

Expand All @@ -257,8 +285,11 @@ def contains(self, value: list[str], prefix) -> Q:
return Q()
q = Q()
for subnet in value:
query = str(netaddr.IPNetwork(subnet.strip()).cidr)
q |= Q(prefix__net_contains=query)
try:
query = str(netaddr.IPNetwork(subnet.strip()).cidr)
except (AddrFormatError, ValueError):
continue
q |= Q(**{f"{prefix}prefix__net_contains": query})
return q


Expand Down
201 changes: 201 additions & 0 deletions netbox/ipam/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,55 @@ def setUpTestData(cls):
},
]

@tag('regression')
def test_graphql_aggregate_prefix_exact(self):
"""
Test case to verify aggregate prefix equality via field lookup in GraphQL API.
"""

self.add_permissions('ipam.view_aggregate', 'ipam.view_rir')

rir = RIR.objects.create(name='RFC6598', slug='rfc6598', is_private=True)
aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir)
Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir)

url = reverse('graphql')
query = """{
aggregate_list(filters: { prefix: { exact: "100.64.0.0/10" } }) { prefix }
}"""
response = self.client.post(url, data={'query': query}, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = response.json()
self.assertNotIn('errors', data)

prefixes = {row['prefix'] for row in data['data']['aggregate_list']}
self.assertIn(str(aggregate1.prefix), prefixes)

@tag('regression')
def test_graphql_aggregate_contains_skips_invalid(self):
"""
Test the GraphQL API Aggregate `contains` filter skips invalid input.
"""

self.add_permissions('ipam.view_aggregate', 'ipam.view_rir')

rir = RIR.objects.create(name='RIR 3', slug='rir-3', is_private=False)
aggregate1 = Aggregate.objects.create(prefix='100.64.0.0/10', rir=rir)
Aggregate.objects.create(prefix='203.0.113.0/24', rir=rir)

url = reverse('graphql')
query = """{
aggregate_list(filters: { contains: ["100.64.16.0/24", "not-a-cidr", ""] }) { prefix }
}"""
response = self.client.post(url, data={'query': query}, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = response.json()
self.assertNotIn('errors', data)

prefixes = {row['prefix'] for row in data['data']['aggregate_list']}
self.assertIn(str(aggregate1.prefix), prefixes)
# No exception occurred; invalid entries were ignored


class RoleTest(APIViewTestCases.APIViewTestCase):
model = Role
Expand Down Expand Up @@ -546,6 +595,30 @@ def test_create_multiple_available_ips(self):
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(len(response.data), 8)

@tag('regression')
def test_graphql_tenant_prefixes_contains_nested_skips_invalid(self):
"""
Test the GraphQL API Tenant nested Prefix `contains` filter skips invalid input.
"""

self.add_permissions('ipam.view_prefix', 'ipam.view_vrf', 'tenancy.view_tenant')

tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1')
vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1')
Prefix.objects.create(prefix='10.20.0.0/16', vrf=vrf, tenant=tenant)
Prefix.objects.create(prefix='198.51.100.0/24', vrf=vrf) # non-tenant

url = reverse('graphql')
query = """{
tenant_list(filters: { prefixes: { contains: ["10.20.1.0/24", "not-a-cidr"] } }) { id }
}"""
response = self.client.post(url, data={'query': query}, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = response.json()
self.assertNotIn('errors', data)

self.assertTrue(data['data']['tenant_list']) # tenant returned


class IPRangeTest(APIViewTestCases.APIViewTestCase):
model = IPRange
Expand Down Expand Up @@ -645,6 +718,65 @@ def test_create_multiple_available_ips(self):
self.assertHttpStatus(response, status.HTTP_201_CREATED)
self.assertEqual(len(response.data), 8)

@tag('regression')
def test_graphql_tenant_ip_ranges_parent_nested_skips_invalid(self):
"""
Test the GraphQL API Tenant nested IP Range `parent` filter skips invalid input.
"""

self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf')

tenant = Tenant.objects.create(name='Tenant 1', slug='tenant-1')
vrf = VRF.objects.create(name='Test VRF 1', rd='64512:1')
IPRange.objects.create(
start_address=IPNetwork('10.30.0.1/24'), end_address=IPNetwork('10.30.0.255/24'), vrf=vrf, tenant=tenant
)
IPRange.objects.create(
start_address=IPNetwork('10.31.0.1/24'), end_address=IPNetwork('10.31.0.255/24'), vrf=vrf, tenant=tenant
)

url = reverse('graphql')
query = """{
tenant_list(filters: {
name: { exact: "Tenant 1" }
ip_ranges: { parent: ["10.30.0.0/24", "bogus"] }
}) { id }
}"""
response = self.client.post(url, data={'query': query}, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = response.json()
self.assertNotIn('errors', data)
self.assertTrue(data['data']['tenant_list']) # tenant returned
# No exception occurred; invalid entries were ignored

@tag('regression')
def test_graphql_tenant_ip_ranges_contains_nested_skips_invalid(self):
"""
Test the GraphQL API Tenant nested IP Range `contains` filter skips invalid input.
"""

self.add_permissions('tenancy.view_tenant', 'ipam.view_iprange', 'ipam.view_vrf')

tenant = Tenant.objects.create(name='Tenant 2', slug='tenant-2')
vrf = VRF.objects.create(name='Test VRF 1', rd='64512:2')
IPRange.objects.create(
start_address=IPNetwork('10.40.0.1/24'), end_address=IPNetwork('10.40.0.255/24'), vrf=vrf, tenant=tenant
)

url = reverse('graphql')
query = """{
tenant_list(filters: {
name: { exact: "Tenant 2" }
ip_ranges: { contains: ["10.40.0.128/25", "###"] }
}) { id }
}"""
response = self.client.post(url, data={'query': query}, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = response.json()
self.assertNotIn('errors', data)
self.assertTrue(data['data']['tenant_list']) # tenant returned
# No exception occurred; invalid entries were ignored


class IPAddressTest(APIViewTestCases.APIViewTestCase):
model = IPAddress
Expand Down Expand Up @@ -731,6 +863,75 @@ def test_assign_object(self):
response = self.client.patch(url, data, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_400_BAD_REQUEST)

@tag('regression')
def test_graphql_device_primary_ip4_assigned_nested(self):
"""
Test the GraphQL API Device nested IP Address `primary_ip4` filter.
"""

self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress')

site = Site.objects.create(name='Site 1')
manufacturer = Manufacturer.objects.create(name='Manufacturer 1')
device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer)
role = DeviceRole.objects.create(name='Switch')

device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active')
interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset')
ip1 = IPAddress.objects.create(address='10.0.0.1/24')
ip1.assigned_object = interface1
ip1.save()
device1.primary_ip4 = ip1
device1.save()

device2 = Device.objects.create(name='Device 2', site=site, device_type=device_type, role=role, status='active')

url = reverse('graphql')
query = """{
device_list(filters: { primary_ip4: { assigned: true } }) { id name }
}"""
response = self.client.post(url, data={'query': query}, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = response.json()
self.assertNotIn('errors', data)

ids = {row['id'] for row in data['data']['device_list']}
self.assertIn(str(device1.pk), ids)
self.assertNotIn(str(device2.pk), ids)

@tag('regression')
def test_graphql_device_primary_ip4_parent_nested_skips_invalid(self):
"""
Test the GraphQL API Device nested IP Address `parent` filter skips invalid input.
"""

self.add_permissions('dcim.view_device', 'dcim.view_interface', 'ipam.view_ipaddress')

site = Site.objects.create(name='Site 1')
manufacturer = Manufacturer.objects.create(name='Manufacturer 1')
device_type = DeviceType.objects.create(model='Device Type 1', manufacturer=manufacturer)
role = DeviceRole.objects.create(name='Switch')

device1 = Device.objects.create(name='Device 1', site=site, device_type=device_type, role=role, status='active')
interface1 = Interface.objects.create(name='Interface 1', device=device1, type='1000baset')
ip1 = IPAddress.objects.create(address='192.0.2.10/24')
ip1.assigned_object = interface1
ip1.save()
device1.primary_ip4 = ip1
device1.save()

url = reverse('graphql')
query = """{
device_list(filters: { primary_ip4: { parent: ["192.0.2.0/24", "bad-cidr"] } }) { id }
}"""
response = self.client.post(url, data={'query': query}, format='json', **self.header)
self.assertHttpStatus(response, status.HTTP_200_OK)
data = response.json()
self.assertNotIn('errors', data)

ids = {row['id'] for row in data['data']['device_list']}
self.assertIn(str(device1.pk), ids)


class FHRPGroupTest(APIViewTestCases.APIViewTestCase):
model = FHRPGroup
Expand Down