diff --git a/agentic_security/core/test_app.py b/agentic_security/core/test_app.py
index 1f03239..56cce45 100644
--- a/agentic_security/core/test_app.py
+++ b/agentic_security/core/test_app.py
@@ -6,12 +6,30 @@
 
 
 @pytest.fixture(autouse=True)
+def reset_globals():
+    """
+    Reset globals (_secrets, current_run, tools_inbox, stop_event) before each test.
+    This ensures tests run in a clean state.
+    """
+    from agentic_security.core.app import _secrets, current_run, get_tools_inbox, get_stop_event
+    _secrets.clear()
+    current_run["spec"] = ""
+    current_run["id"] = ""
+    # Clear tools_inbox queue
+    queue = get_tools_inbox()
+    while not queue.empty():
+        queue.get_nowait()
+    # Reset stop_event if it is set
+    event = get_stop_event()
+    if event.is_set():
+        event.clear()
 def setup_env_vars():
     # Set up environment variables for testing
     os.environ["TEST_ENV_VAR"] = "test_value"
 
 
 def test_expand_secrets_with_env_var():
+    os.environ["TEST_ENV_VAR"] = "test_value"
     secrets = {"secret_key": "$TEST_ENV_VAR"}
     expand_secrets(secrets)
     assert secrets["secret_key"] == "test_value"
@@ -27,3 +45,180 @@ def test_expand_secrets_without_dollar_sign():
     secrets = {"secret_key": "plain_value"}
     expand_secrets(secrets)
     assert secrets["secret_key"] == "plain_value"
+
+import asyncio
+from fastapi import FastAPI
+from fastapi.responses import ORJSONResponse
+from agentic_security.core.app import create_app, get_tools_inbox, get_stop_event, get_current_run, set_current_run, get_secrets, set_secrets, expand_secrets
+
+class DummyLLMSpec:
+    """A dummy LLMSpec for testing purposes."""
+    pass
+
+def test_create_app():
+    """Test that create_app returns a FastAPI app with ORJSONResponse."""
+    app = create_app()
+    assert isinstance(app, FastAPI)
+    assert app.router.default_response_class == ORJSONResponse
+
+def test_get_tools_inbox():
+    """Test that get_tools_inbox returns a Queue instance."""
+    queue = get_tools_inbox()
+    from asyncio import Queue
+    assert isinstance(queue, Queue)
+
+def test_get_stop_event():
+    """Test that get_stop_event returns an Event instance."""
+    event = get_stop_event()
+    from asyncio import Event
+    assert isinstance(event, Event)
+
+def test_get_current_run_initial():
+    """Test that get_current_run returns the initial current run dictionary."""
+    current = get_current_run()
+    # The initial dictionary should have an empty spec and id.
+    assert current["spec"] == ""
+    assert current["id"] == ""
+
+def test_set_current_run():
+    """Test that set_current_run updates the current run with the dummy LLMSpec."""
+    dummy_spec = DummyLLMSpec()
+    updated = set_current_run(dummy_spec)
+    assert updated["spec"] is dummy_spec
+    # Ensure that the id is computed as hash(id(dummy_spec))
+    expected_id = hash(id(dummy_spec))
+    assert updated["id"] == expected_id
+
+def test_get_and_set_secrets():
+    """Test that set_secrets updates the secrets dictionary and get_secrets returns the updated values."""
+    # Snapshot the current secrets before the update (not used by the assertions below)
+    secrets_before = get_secrets().copy()
+    os.environ["MY_SECRET"] = "secret_value"
+    new_secrets = {"key1": "$MY_SECRET", "key2": "plain"}
+    updated = set_secrets(new_secrets)
+    assert updated["key1"] == "secret_value"
+    assert updated["key2"] == "plain"
+
+def test_expand_secrets_multiple_keys():
+    """Test expand_secrets with multiple keys, including one with an environment variable,
+    one with a non-existent variable, and one that is plain."""
+    os.environ["TEST_ENV_VAR"] = "test_value"
+    secrets = {"env_key": "$TEST_ENV_VAR", "nonexistent_key": "$NON_EXISTENT", "plain_key": "value"}
+    expand_secrets(secrets)
+    assert secrets["env_key"] == "test_value"
+    # For a non-existent environment variable, os.getenv returns None
+    assert secrets["nonexistent_key"] is None
+    # Plain values should not be changed.
+    assert secrets["plain_key"] == "value"
+def test_expand_secrets_with_space_after_dollar():
+    """Test expand_secrets when the value has a dollar sign followed by a space.
+    The value starts with "$", so it is treated as an environment variable reference.
+    Assuming expand_secrets strips only the dollar sign (strip("$")), the looked-up key
+    is " SPACED_VAR" with its leading space, which does not match the SPACED_VAR set here.
+    """
+    os.environ["SPACED_VAR"] = "spaced_value"
+    secrets = {"key": "$ SPACED_VAR"}
+    expand_secrets(secrets)
+    # "$ SPACED_VAR" after strip("$") becomes " SPACED_VAR", which is not a set env key, so the test expects None.
+    assert secrets["key"] is None
+
+def test_set_secrets_update_existing():
+    """Test that set_secrets updates an existing secret and retains previously set keys."""
+    os.environ["VAR1"] = "value1"
+    os.environ["VAR2"] = "value2"
+    result_first = set_secrets({"a": "$VAR1", "b": "b_val"})
+    assert result_first["a"] == "value1"
+    # Change VAR1 in environment and update secret "a", and add secret "c"
+    os.environ["VAR1"] = "new_value1"
+    result_second = set_secrets({"a": "$VAR1", "c": "$VAR2"})
+    assert result_second["a"] == "new_value1"
+    assert result_second["b"] == "b_val"
+    assert result_second["c"] == "value2"
+
+def test_tools_inbox_state():
+    """Test that get_tools_inbox returns the same queue instance
+    and that the queue state persists across multiple calls.
+    """
+    from asyncio import Queue
+    inbox1 = get_tools_inbox()
+    inbox1.put_nowait("message")
+    inbox2 = get_tools_inbox()
+    # inbox2 should contain the "message" from inbox1
+    msg = inbox2.get_nowait()
+    assert msg == "message"
+
+def test_stop_event_state():
+    """Test that stop_event can be set and cleared, and its state persists."""
+    event = get_stop_event()
+    # Initially the event should not be set
+    assert not event.is_set()
+    event.set()
+    assert event.is_set()
+    event.clear()
+    assert not event.is_set()
+
+def test_set_current_run_returns_global_dict():
+    """Test that set_current_run returns the same global current_run dictionary
+    as returned by get_current_run.
+    """
+    dummy_spec = DummyLLMSpec()
+    updated = set_current_run(dummy_spec)
+    current = get_current_run()
+    assert updated is current
+def test_get_secrets_initial():
+    """Test that get_secrets returns an empty dictionary initially."""
+    assert get_secrets() == {}
+
+def test_set_secrets_empty():
+    """Test that setting an empty secrets dictionary does not modify existing secrets."""
+    # first set initial secrets
+    initial = {"key": "value"}
+    set_secrets(initial)
+    # update with an empty dict – the existing keys remain
+    result = set_secrets({})
+    assert result == initial
+
+def test_update_current_run_twice():
+    """Test updating current run twice with different LLMSpec values."""
+    dummy1 = DummyLLMSpec()
+    dummy2 = DummyLLMSpec()
+    set_current_run(dummy1)
+    first = get_current_run().copy()
+    set_current_run(dummy2)
+    second = get_current_run().copy()
+    # first update should hold dummy1, second should hold dummy2
+    assert first["spec"] is dummy1
+    assert second["spec"] is dummy2
+    # Ensure that id has changed (using hash(id(dummy_spec)))
+    assert first["id"] != second["id"]
+
+def test_expand_secrets_trailing_whitespace():
+    """Test expand_secrets when the secret value has trailing whitespace after the variable name.
+    The trailing whitespace remains after stripping only the dollar sign, so the looked-up environment variable key will not match.
+    """
+    os.environ["TRIM_TEST"] = "trimmed"
+    secrets = {"key": "$TRIM_TEST "}
+    expand_secrets(secrets)
+    # Since "TRIM_TEST " (with trailing space) is not set in the environment, the secret should be None.
+    assert secrets["key"] is None
+def test_expand_secrets_empty_dict():
+    """Test expand_secrets with an empty dictionary does nothing."""
+    secrets = {}
+    expand_secrets(secrets)
+    assert secrets == {}
+
+def test_expand_secrets_with_non_string_value():
+    """Test that expand_secrets raises an AttributeError when a secret value is not a string."""
+    secrets = {"key": 123}
+    with pytest.raises(AttributeError):
+        expand_secrets(secrets)
+
+def test_expand_secrets_multiple_dollar_signs():
+    """Test expand_secrets with a value that contains multiple leading dollar signs.
+    The extra dollar signs are removed by the strip method.
+    """
+    os.environ["MULTI_DOLLAR_VAR"] = "multi_value"
+    secrets = {"key": "$$MULTI_DOLLAR_VAR"}
+    expand_secrets(secrets)
+    # After stripping, "$$MULTI_DOLLAR_VAR".strip("$") returns "MULTI_DOLLAR_VAR"
+    assert secrets["key"] == "multi_value"
\ No newline at end of file
diff --git a/agentic_security/probe_actor/test_refusal.py b/agentic_security/probe_actor/test_refusal.py
new file mode 100644
index 0000000..f53000d
--- /dev/null
+++ b/agentic_security/probe_actor/test_refusal.py
@@ -0,0 +1,51 @@
+from agentic_security.probe_actor.refusal import DefaultRefusalClassifier
+from agentic_security.probe_actor.refusal import refusal_heuristic, RefusalClassifierManager, RefusalClassifierPlugin
+
+
+class TestCheckRefusal:
+    # The function correctly identifies a refusal phrase in the response.
+    def test_identify_refusal_phrase(self):
+        response = "I'm sorry, but I cannot provide that information."
+        assert DefaultRefusalClassifier().is_refusal(response)
+
+    # The response is an empty string.
+    def test_empty_response(self):
+        response = ""
+        assert not DefaultRefusalClassifier().is_refusal(response)
+    def test_non_refusal_phrase(self):
+        """Test that a response without refusal phrases returns False."""
+        response = "Here is some helpful information."
+        assert not DefaultRefusalClassifier().is_refusal(response)
+
+    def test_refusal_heuristic(self):
+        """Test that the refusal_heuristic function correctly identifies refusal phrases."""
+        response = "Sorry, I cannot assist with your request."
+        # refusal_heuristic will cast the argument to a string and use the plugin manager to check for refusal markers
+        assert refusal_heuristic(response)
+
+    def test_plugin_manager_with_dummy_plugin(self):
+        """Test that RefusalClassifierManager returns True if any registered plugin returns True."""
+        class AlwaysFalsePlugin(RefusalClassifierPlugin):
+            def is_refusal(self, response: str) -> bool:
+                return False
+
+        class AlwaysTruePlugin(RefusalClassifierPlugin):
+            def is_refusal(self, response: str) -> bool:
+                return True
+
+        manager = RefusalClassifierManager()
+        manager.register_plugin("false", AlwaysFalsePlugin())
+        manager.register_plugin("true", AlwaysTruePlugin())
+        response = "Any response text"
+        assert manager.is_refusal(response)
+
+    def test_default_classifier_with_empty_phrases(self):
+        """Test that DefaultRefusalClassifier returns False when provided with an empty refusal phrase list."""
+        classifier = DefaultRefusalClassifier(refusal_phrases=[])
+        response = "I do not have any info."
+        assert not classifier.is_refusal(response)
+
+    def test_case_sensitivity(self):
+        """Test that string matching is case-sensitive."""
+        response = "i'm sorry, but can you help me?"  # lower-case "i'm sorry" does not match "I'm sorry" and no extra refusal phrases are present
+        assert not DefaultRefusalClassifier().is_refusal(response)
diff --git a/codebeaver.yml b/codebeaver.yml
new file mode 100644
index 0000000..145ae1d
--- /dev/null
+++ b/codebeaver.yml
@@ -0,0 +1,2 @@
+from: python-pytest-poetry
+# This file was generated automatically by CodeBeaver based on your repository. Learn how to customize it here: https://docs.codebeaver.ai/configuration/
\ No newline at end of file
diff --git a/tests/test_report_chart.py b/tests/test_report_chart.py
new file mode 100644
index 0000000..03785e8
--- /dev/null
+++ b/tests/test_report_chart.py
@@ -0,0 +1,136 @@
+import io
+import string
+import pytest
+import pandas as pd
+import numpy as np
+from agentic_security.report_chart import plot_security_report, generate_identifiers
+
+class TestReportChart:
+    """Test suite for agentic_security.report_chart module."""
+
+    def test_generate_identifiers_short(self):
+        """Test generate_identifiers with a small dataset."""
+        df = pd.DataFrame([{'dummy': i} for i in range(5)])
+        identifiers = generate_identifiers(df)
+        expected = ['A1', 'A2', 'A3', 'A4', 'A5']
+        assert identifiers == expected
+
+    def test_generate_identifiers_edge(self):
+        """Test generate_identifiers with more than 26 items to cover cycling over the alphabet."""
+        n = 30
+        df = pd.DataFrame([{'dummy': i} for i in range(n)])
+        identifiers = generate_identifiers(df)
+        # For i=25, identifier should be A26, and for i=26, identifier should be B1
+        assert identifiers[25] == 'A26'
+        assert identifiers[26] == 'B1'
+        assert len(identifiers) == n
+
+    def test_generate_identifiers_empty(self):
+        """Test generate_identifiers with an empty dataframe."""
+        df = pd.DataFrame([])
+        identifiers = generate_identifiers(df)
+        assert identifiers == []
+
+    def test_plot_security_report_png_output(self):
+        """Test plot_security_report returns valid PNG output."""
+        # Create a sample table with required columns
+        table = [
+            {"failureRate": 10, "tokens": 100, "module": "Module1"},
+            {"failureRate": 30, "tokens": 200, "module": "Module2"},
+            {"failureRate": 20, "tokens": 150, "module": "Module3"},
+        ]
+        buf = plot_security_report(table)
+        # Check that buf is a BytesIO object and starts with PNG header bytes
+        assert isinstance(buf, io.BytesIO)
+        buf.seek(0)
+        header = buf.read(8)
+        assert header.startswith(b'\x89PNG')
+
+    def test_plot_security_report_ordering(self, monkeypatch):
+        """Test that the table embedded in the plot has a "Threat" header row and lists ModuleB with its failure rate."""
+        table = [
+            {"failureRate": 15, "tokens": 110, "module": "ModuleA"},
+            {"failureRate": 25, "tokens": 210, "module": "ModuleB"},
+            {"failureRate": 5, "tokens": 90, "module": "ModuleC"},
+        ]
+        result_holder = {}
+        from matplotlib.axes import Axes
+        original_table = Axes.table
+        def fake_table(self, *args, **kwargs):
+            result_holder['cellText'] = kwargs.get('cellText')
+            return original_table(self, *args, **kwargs)
+        monkeypatch.setattr(Axes, "table", fake_table)
+        plot_security_report(table)
+        cell_text = result_holder.get('cellText')
+        assert cell_text is not None
+        # Verify header row in the table
+        assert cell_text[0] == ["Threat"]
+        # Since the data are sorted (highest failure rate first), ModuleB (25.0%) should appear in one of the rows.
+        found = any("ModuleB (25.0%)" in row[0] for row in cell_text[1:])
+        assert found
+
+    def test_plot_security_report_one_entry(self):
+        """Test plot_security_report with a single entry."""
+        table = [{"failureRate": 50, "tokens": 300, "module": "OnlyModule"}]
+        buf = plot_security_report(table)
+        assert isinstance(buf, io.BytesIO)
+        buf.seek(0)
+        content = buf.read()
+        assert content.startswith(b'\x89PNG')
+    def test_generate_identifiers_many(self):
+        """Test generate_identifiers with 52 items to verify identifier sequence."""
+        n = 52
+        df = pd.DataFrame([{'dummy': i} for i in range(n)])
+        identifiers = generate_identifiers(df)
+        assert identifiers[0] == "A1"
+        assert identifiers[25] == "A26"
+        assert identifiers[26] == "B1"
+        assert identifiers[51] == "B26"
+
+    def test_plot_security_report_missing_failureRate(self):
+        """Test plot_security_report raises KeyError when 'failureRate' column is missing."""
+        table = [{"tokens": 100, "module": "Mod1"}]  # Missing 'failureRate'
+        with pytest.raises(KeyError):
+            plot_security_report(table)
+
+    def test_plot_security_report_missing_tokens(self):
+        """Test plot_security_report raises KeyError when 'tokens' column is missing."""
+        table = [{"failureRate": 10, "module": "Mod1"}]  # Missing 'tokens'
+        with pytest.raises(KeyError):
+            plot_security_report(table)
+
+    def test_plot_security_report_empty_table(self):
+        """Test plot_security_report raises KeyError when the table is empty."""
+        table = []
+        with pytest.raises(KeyError):
+            plot_security_report(table)
+    def test_plot_security_report_missing_module(self):
+        """Test plot_security_report raises KeyError when 'module' column is missing."""
+        table = [{"failureRate": 10, "tokens": 100}]  # Missing 'module'
+        with pytest.raises(KeyError):
+            plot_security_report(table)
+
+    def test_plot_security_report_failure_rate_labels(self, monkeypatch):
+        """Test that plot_security_report calls ax.text for each failure rate bar label."""
+        table = [
+            {"failureRate": 10, "tokens": 100, "module": "Mod1"},
+            {"failureRate": 20, "tokens": 150, "module": "Mod2"},
+            {"failureRate": 30, "tokens": 200, "module": "Mod3"},
+        ]
+        # Count the number of times ax.text is called for drawing failure rate labels.
+        call_count = [0]
+        from matplotlib.axes import Axes
+        original_text = Axes.text
+        def fake_text(self, *args, **kwargs):
+            call_count[0] += 1
+            return original_text(self, *args, **kwargs)
+        monkeypatch.setattr(Axes, "text", fake_text)
+        plot_security_report(table)
+        # The loop inside plot_security_report calls ax.text once for each data point.
+        assert call_count[0] == len(table)
+
+    def test_plot_security_report_non_numeric_failureRate(self):
+        """Test that plot_security_report raises an exception when failureRate is non-numeric."""
+        table = [{"failureRate": "invalid", "tokens": 100, "module": "ModX"}]
+        with pytest.raises(Exception):
+            plot_security_report(table)
\ No newline at end of file
diff --git a/tests/test_scan.py b/tests/test_scan.py
new file mode 100644
index 0000000..4781962
--- /dev/null
+++ b/tests/test_scan.py
@@ -0,0 +1,126 @@
+import io
+import asyncio
+import json
+from datetime import datetime, timedelta
+from threading import Event
+import pytest
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+from agentic_security.routes import scan
+
+# Dummy LLMSpec for success tests
+class DummyLLMSpec:
+    def __init__(self, spec_string):
+        self.spec_string = spec_string
+    async def verify(self):
+        class DummyResponse:
+            status_code = 200
+            text = "verification succeeded"
+            elapsed = timedelta(seconds=0.5)
+        return DummyResponse()
+    @classmethod
+    def from_string(cls, spec_string):
+        return DummyLLMSpec(spec_string)
+
+# Dummy scan_router generator to simulate streaming responses
+async def dummy_scan_router(request_factory, scan_parameters, tools_inbox, stop_event):
+    for i in range(2):
+        yield f"result {i}"
+
+# Define a dummy Secrets class for testing purposes.
+class DummySecrets:
+    def __init__(self):
+        self.secrets = {}
+
+# Create FastAPI app for testing and include the scan router.
+@pytest.fixture
+def app():
+    app = FastAPI()
+    app.include_router(scan.router)
+    return app
+
+@pytest.fixture
+def client(app):
+    return TestClient(app)
+
+@pytest.fixture(autouse=True)
+def patch_dependencies(monkeypatch):
+    # Patch LLMSpec used in the routes with our dummy implementation.
+    monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpec)
+    # Patch fuzzer.scan_router to use our dummy scanning generator.
+    monkeypatch.setattr(scan.fuzzer, "scan_router", dummy_scan_router)
+    # Patch get_stop_event to return a dummy Event.
+    dummy_event = Event()
+    monkeypatch.setattr(scan, "get_stop_event", lambda: dummy_event)
+    # Patch get_tools_inbox to return None.
+    monkeypatch.setattr(scan, "get_tools_inbox", lambda: None)
+    # Patch set_current_run to be a no-op.
+    monkeypatch.setattr(scan, "set_current_run", lambda x: None)
+    # Patch get_in_memory_secrets to return a DummySecrets instance.
+    monkeypatch.setattr(scan, "get_in_memory_secrets", lambda: DummySecrets())
+    # Ensure Scan.with_secrets is a no-op if not already implemented.
+    if not hasattr(scan.Scan, "with_secrets"):
+        monkeypatch.setattr(scan.Scan, "with_secrets", lambda self, secrets: None)
+
+def test_verify_success(client):
+    """Test /verify endpoint for a successful verification."""
+    data = {"spec": "dummy"}
+    response = client.post("/verify", json=data)
+    res_json = response.json()
+    assert response.status_code == 200
+    assert res_json["status_code"] == 200
+    assert res_json["body"] == "verification succeeded"
+    assert "elapsed" in res_json
+    assert "timestamp" in res_json
+
+def test_verify_failure(client, monkeypatch):
+    """Test /verify endpoint when verification fails."""
+    class DummyLLMSpecFailure:
+        def __init__(self, spec_string):
+            self.spec_string = spec_string
+        async def verify(self):
+            raise Exception("verification error")
+        @classmethod
+        def from_string(cls, spec_string):
+            return DummyLLMSpecFailure(spec_string)
+    monkeypatch.setattr(scan, "LLMSpec", DummyLLMSpecFailure)
+    data = {"spec": "bad"}
+    response = client.post("/verify", json=data)
+    assert response.status_code == 400
+    assert "verification error" in response.text
+
+def test_scan(client):
+    """Test /scan endpoint to ensure streaming response works."""
+    data = {"llmSpec": "dummy", "optimize": False, "maxBudget": 10, "enableMultiStepAttack": False}
+    response = client.post("/scan", json=data)
+    assert response.status_code == 200
+    content = list(response.iter_lines())
+    expected = ["result 0", "result 1"]
+    assert content == expected
+
+def test_stop_scan(client):
+    """Test /stop endpoint to ensure scan stopping functionality."""
+    dummy_event = scan.get_stop_event()
+    dummy_event.clear()
+    response = client.post("/stop")
+    assert response.status_code == 200
+    assert response.json() == {"status": "Scan stopped"}
+    assert dummy_event.is_set()
+
+def test_scan_csv(client):
+    """Test /scan-csv endpoint with CSV file and llmSpec upload."""
+    csv_content = b"col1,col2\nvalue1,value2"
+    llm_spec_content = b"dummy"
+    files = {
+        "file": ("dummy.csv", csv_content, "text/csv"),
+        "llmSpec": ("spec.txt", llm_spec_content, "text/plain"),
+    }
+    response = client.post(
+        "/scan-csv",
+        files=files,
+        data={"optimize": "false", "maxBudget": "10", "enableMultiStepAttack": "false"},
+    )
+    assert response.status_code == 200
+    content = list(response.iter_lines())
+    expected = ["result 0", "result 1"]
+    assert content == expected
\ No newline at end of file