diff --git a/minos/api_gateway/discovery/database/client.py b/minos/api_gateway/discovery/database/client.py index 952d11b..932fac1 100644 --- a/minos/api_gateway/discovery/database/client.py +++ b/minos/api_gateway/discovery/database/client.py @@ -16,6 +16,7 @@ class OrdersMinosApiRouter(MinosApiRouter): import logging from typing import ( Any, + Optional, ) import aioredis @@ -41,17 +42,14 @@ class MinosRedisClient: __slots__ = "address", "port", "password", "redis" - def __init__(self, config: MinosConfig, pool_size: int = 50): + def __init__(self, config: MinosConfig, pool_size: Optional[int] = None): """Perform initial configuration and connection to Redis""" address = config.discovery.database.host port = config.discovery.database.port password = config.discovery.database.password - pool = aioredis.ConnectionPool.from_url( - f"redis://{address}:{port}", password=password, max_connections=pool_size - ) - self.redis = aioredis.Redis(connection_pool=pool) + self.redis = aioredis.from_url(f"redis://{address}:{port}", password=password, max_connections=pool_size) async def get_data(self, key: str) -> str: """Get redis value by key""" @@ -78,13 +76,7 @@ async def get_all(self) -> list: return data async def set_data(self, key: str, data: dict): - async with self.redis as r: - await r.set(key, json.dumps(data)) - await r.save() - - async def update_data(self): # pragma: no cover - """Update specific value""" - pass + await self.redis.set(key, json.dumps(data)) async def delete_data(self, key: str): deleted_elements = await self.redis.delete(key) diff --git a/tests/test_api_gateway/test_discovery/dataset.py b/tests/test_api_gateway/test_discovery/dataset.py new file mode 100644 index 0000000..448aeab --- /dev/null +++ b/tests/test_api_gateway/test_discovery/dataset.py @@ -0,0 +1,39 @@ +import random +import socket +import struct +from uuid import ( + uuid4, +) + + +def generate_random_microservice_names(quantity: int): + random_names = [f"test_endpoint_{str(uuid4())}" for x in range(quantity)] + + return random_names + + +def generate_record(name): + ip = socket.inet_ntoa(struct.pack(">I", random.randint(1, 0xFFFFFFFF))) + port = random.randint(1, 9999) + + record = { + "address": f"{ip}", + "port": port, + "endpoints": [["GET", f"test_endpoint_{name}"], ["POST", f"test_endpoint_{name}"]], + } + + return record + + +def generate_record_old(x): + ip = socket.inet_ntoa(struct.pack(">I", random.randint(1, 0xFFFFFFFF))) + port = random.randint(1, 9999) + name = f"microservice_{x}" + + record = { + "address": f"{ip}", + "port": port, + "endpoints": [["GET", f"test_endpoint_{name}"], ["POST", f"test_endpoint_{name}"]], + } + + return name, record diff --git a/tests/test_api_gateway/test_discovery/test_views/test_microservice.py b/tests/test_api_gateway/test_discovery/test_views/test_microservice.py index 623b957..501b671 100644 --- a/tests/test_api_gateway/test_discovery/test_views/test_microservice.py +++ b/tests/test_api_gateway/test_discovery/test_views/test_microservice.py @@ -1,3 +1,5 @@ +import asyncio + from aiohttp.test_utils import ( AioHTTPTestCase, unittest_run_loop, @@ -9,6 +11,11 @@ from minos.api_gateway.discovery import ( DiscoveryService, ) +from tests.test_api_gateway.test_discovery.dataset import ( + generate_random_microservice_names, + generate_record, + generate_record_old, +) from tests.utils import ( BASE_PATH, ) @@ -37,13 +44,83 @@ async def test_post(self): self.assertEqual(201, response.status) - async def test_bulk_post(self): - name = "test_name" - body = {"address": "1.1.1.1", "port": 1, "endpoints": [["GET", "test_endpoint_1"], ["POST", "test_endpoint_2"]]} + async def test_bulk_update(self): + names = generate_random_microservice_names(50) - response = await self.client.post(f"/microservices/{name}", json=body) + tasks = list() + # Create new records + for name in names: + body = generate_record(name) + tasks.append(self.client.post(f"/microservices/{name}", json=body)) - self.assertEqual(201, response.status) + results = await asyncio.gather(*tasks) + + for result in results: + self.assertEqual(201, result.status) + + # Update existing records + expected = list() + tasks = list() + for name in names: + body = generate_record(name) + expected.append({"name": name, "path": f"/microservices/{name}", "body": body}) + tasks.append(self.client.post(f"/microservices/{name}", json=body)) + + results = await asyncio.gather(*tasks) + + for result in results: + self.assertEqual(201, result.status) + + # Check updated records are correct + for record in expected: + response = await self.client.get( + f"/microservices?verb={record['body']['endpoints'][0][0]}&path={record['body']['endpoints'][0][1]}" + ) + + self.assertEqual(200, response.status) + + body = await response.json() + + self.assertEqual(record["body"]["address"], body["address"]) + self.assertEqual(record["body"]["port"], int(body["port"])) + self.assertEqual(record["name"], body["name"]) + + async def test_bulk_update_2(self): + expected = list() + tasks = list() + # Create new records + for x in range(50): + name, body = generate_record_old(x) + tasks.append(self.client.post(f"/microservices/{name}", json=body)) + + results = await asyncio.gather(*tasks) + + for result in results: + self.assertEqual(201, result.status) + + tasks = list() + for x in range(50): + name, body = generate_record_old(x) + expected.append({"name": name, "path": f"/microservices/{name}", "body": body}) + tasks.append(self.client.post(f"/microservices/{name}", json=body)) + + results = await asyncio.gather(*tasks) + + for result in results: + self.assertEqual(201, result.status) + + for record in expected: + response = await self.client.get( + f"/microservices?verb={record['body']['endpoints'][0][0]}&path={record['body']['endpoints'][0][1]}" + ) + + self.assertEqual(200, response.status) + + body = await response.json() + + self.assertEqual(record["body"]["address"], body["address"]) + self.assertEqual(int(record["body"]["port"]), int(body["port"])) + self.assertEqual(record["name"], body["name"]) @unittest_run_loop async def test_post_missing_param(self):