diff --git a/oqs/oqs.py b/oqs/oqs.py index 1320a37..d3cc628 100644 --- a/oqs/oqs.py +++ b/oqs/oqs.py @@ -466,6 +466,11 @@ def get_supported_kem_mechanisms() -> tuple[str, ...]: return _supported_KEMs +# Register the OQS_SIG_supports_ctx_str function from the C library +native().OQS_SIG_supports_ctx_str.restype = ct.c_bool +native().OQS_SIG_supports_ctx_str.argtypes = [ct.c_char_p] + + class Signature(ct.Structure): """ An OQS Signature wraps native/C liboqs OQS_SIG structs. @@ -485,7 +490,6 @@ class Signature(ct.Structure): ("alg_version", ct.c_char_p), ("claimed_nist_level", ct.c_ubyte), ("euf_cma", ct.c_ubyte), - ("sig_with_ctx_support", ct.c_ubyte), ("length_public_key", ct.c_size_t), ("length_secret_key", ct.c_size_t), ("length_signature", ct.c_size_t), @@ -515,7 +519,7 @@ def __init__(self, alg_name: str, secret_key: Union[int, bytes, None] = None) -> self.alg_version = self._sig.contents.alg_version self.claimed_nist_level = self._sig.contents.claimed_nist_level self.euf_cma = self._sig.contents.euf_cma - self.sig_with_ctx_support = self._sig.contents.sig_with_ctx_support + self.sig_with_ctx_support = native().OQS_SIG_supports_ctx_str(self.method_name) self.length_public_key = self._sig.contents.length_public_key self.length_secret_key = self._sig.contents.length_secret_key self.length_signature = self._sig.contents.length_signature @@ -634,7 +638,7 @@ def sign_with_ctx_str(self, message: bytes, context: bytes) -> bytes: :param context: the context string. :param message: the message to sign. """ - if context and not self._sig.contents.sig_with_ctx_support: + if context and not self.sig_with_ctx_support: msg = "Signing with context string not supported" raise RuntimeError(msg) @@ -681,7 +685,7 @@ def verify_with_ctx_str( :param context: the context string. :param public_key: the signer's public key. """ - if context and not self._sig.contents.sig_with_ctx_support: + if context and not self.sig_with_ctx_support: msg = "Verifying with context string not supported" raise RuntimeError(msg) diff --git a/tests/test_sig.py b/tests/test_sig.py index b579e1a..185f6a2 100644 --- a/tests/test_sig.py +++ b/tests/test_sig.py @@ -2,7 +2,7 @@ import random import oqs -from oqs.oqs import Signature +from oqs.oqs import Signature, native # Sigs for which unit testing is disabled disabled_sig_patterns = [] @@ -44,6 +44,31 @@ def check_correctness_with_ctx_str(alg_name: str) -> None: assert sig.verify_with_ctx_str(message, signature, context, public_key) # noqa: S101 +def test_sig_with_ctx_support_detection() -> None: + """ + Test that sig_with_ctx_support matches the C API and that sign_with_ctx_str + raises on unsupported algorithms. + """ + for alg_name in oqs.get_enabled_sig_mechanisms(): + with Signature(alg_name) as sig: + # Check Python attribute matches C API + c_api_result = native().OQS_SIG_supports_ctx_str(sig.method_name) + assert bool(sig.sig_with_ctx_support) == bool(c_api_result), ( # noqa: S101 + f"sig_with_ctx_support mismatch for {alg_name}" + ) + # If not supported, sign_with_ctx_str should raise + if not sig.sig_with_ctx_support: + try: + sig.sign_with_ctx_str(b"msg", b"context") + except RuntimeError as e: + if "not supported" not in str(e): + msg = f"Unexpected exception message: {e}" + raise AssertionError(msg) from e + else: + msg = f"sign_with_ctx_str did not raise for {alg_name} without context support" + raise AssertionError(msg) + + def test_wrong_message() -> tuple[None, str]: for alg_name in oqs.get_enabled_sig_mechanisms(): if any(item in alg_name for item in disabled_sig_patterns):