diff --git a/pyoaev/apis/endpoint.py b/pyoaev/apis/endpoint.py index ead70a6..f9683cc 100644 --- a/pyoaev/apis/endpoint.py +++ b/pyoaev/apis/endpoint.py @@ -1,6 +1,7 @@ from typing import Any, Dict from pyoaev import exceptions as exc +from pyoaev.apis.inputs.search import SearchPaginationInput from pyoaev.base import RESTManager, RESTObject from pyoaev.utils import RequiredOptional @@ -36,3 +37,11 @@ def upsert(self, endpoint: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: path = f"{self.path}/agentless/upsert" result = self.openaev.http_post(path, post_data=endpoint, **kwargs) return result + + @exc.on_http_error(exc.OpenAEVUpdateError) + def searchTargets( + self, input: SearchPaginationInput, **kwargs: Any + ) -> Dict[str, Any]: + path = f"{self.path}/targets" + result = self.openaev.http_post(path, post_data=input.to_dict(), **kwargs) + return result diff --git a/pyoaev/contracts/contract_config.py b/pyoaev/contracts/contract_config.py index 60d6937..56de12f 100644 --- a/pyoaev/contracts/contract_config.py +++ b/pyoaev/contracts/contract_config.py @@ -32,6 +32,11 @@ class ContractFieldType(str, Enum): Payload: str = "payload" +class ContractFieldKey(str, Enum): + Asset: str = "assets" + AssetGroup: str = "asset_groups" + + class ContractOutputType(str, Enum): Text: str = "text" Number: str = "number" @@ -270,6 +275,7 @@ def get_type(self) -> str: @dataclass class ContractAsset(ContractCardinalityElement): + key: str = field(default=ContractFieldKey.Asset.value, init=False) @property def get_type(self) -> str: @@ -278,6 +284,7 @@ def get_type(self) -> str: @dataclass class ContractAssetGroup(ContractCardinalityElement): + key: str = field(default=ContractFieldKey.AssetGroup.value, init=False) @property def get_type(self) -> str: diff --git a/test/apis/endpoint/__init__.py b/test/apis/endpoint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/apis/endpoint/test_endpoint.py b/test/apis/endpoint/test_endpoint.py new file mode 100644 index 0000000..dcf081e --- /dev/null +++ b/test/apis/endpoint/test_endpoint.py @@ -0,0 +1,57 @@ +from unittest import TestCase, main, mock +from unittest.mock import ANY + +from pyoaev import OpenAEV +from pyoaev.apis.inputs.search import Filter, FilterGroup, SearchPaginationInput + + +def mock_response(**kwargs): + class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + self.history = None + self.content = None + self.headers = {"Content-Type": "application/json"} + + def json(self): + return self.json_data + + return MockResponse(None, 200) + + +class TestInjectorContract(TestCase): + @mock.patch("requests.Session.request", side_effect=mock_response) + def test_search_input_correctly_serialised(self, mock_request): + api_client = OpenAEV("url", "token") + + search_input = SearchPaginationInput( + 0, + 20, + FilterGroup( + "or", + [Filter("targets", "and", "eq", ["target_1", "target_2", "target_3"])], + ), + None, + None, + ) + + expected_json = search_input.to_dict() + api_client.endpoint.searchTargets(search_input) + + mock_request.assert_called_once_with( + method="post", + url="url/api/endpoints/targets", + params={}, + data=None, + timeout=None, + stream=False, + verify=True, + json=expected_json, + headers=ANY, + auth=ANY, + ) + + +if __name__ == "__main__": + main()