Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 39 additions & 12 deletions scripts/check-magic
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
#

import re
import math
import pathlib

from sympy import simplify, sympify, Function
from sympy import simplify, sympify, Function, Rational

def get_c_source_files():
return get_files("mlkem/**/*.c")
Expand All @@ -20,6 +21,17 @@ def get_header_files():
def get_files(pattern):
return list(map(str, pathlib.Path().glob(pattern)))

# Standard color definitions
GREEN="\033[32m"
RED="\033[31m"
BLUE="\033[94m"
BOLD="\033[1m"
NORMAL="\033[0m"

CHECKED = f"{GREEN}✓{NORMAL}"
FAIL = f"{RED}✗{NORMAL}"
REMEMBERED = f"{BLUE}⊢{NORMAL}"

def check_magic_numbers():
mlkem_q = 3329
exceptions = [mlkem_q,
Expand Down Expand Up @@ -64,9 +76,21 @@ def check_magic_numbers():
y = int(y)
m = int(m)
return signed_mod(pow(x,y,m),m)
def safe_round(x):
if x - math.floor(x) == Rational(1, 2):
raise ValueError(f"Ambiguous rounding: {x} is an odd multiple of 0.5 and it is unclear if round-up or round-down is desired")
return round(x)
def safe_floordiv(x, y):
x = int(x)
y = int(y)
if x % y != 0:
raise ValueError(f"Non-integral division: {x} // {y} has remainder {x % y}")
return x // y
locals_dict = {'signed_mod': signed_mod,
'unsigned_mod': unsigned_mod,
'pow': pow_mod }
'pow': pow_mod,
'round': safe_round,
'intdiv': safe_floordiv }
locals_dict.update(known_magics)
return sympify(m, locals=locals_dict)

Expand All @@ -82,6 +106,7 @@ def check_magic_numbers():
enabled = True
magic_dict = {'MLKEM_Q': mlkem_q}
magic_expr = None
verified_magics = {}
for i, l in enumerate(content):
if enabled is True and disable_marker in l:
enabled = False
Expand All @@ -94,6 +119,12 @@ def check_magic_numbers():
l, g = get_magic(l)
if g is not None:
magic_val, magic_expr = g
magic_val_check = evaluate_magic(magic_expr, magic_dict)
if magic_val != magic_val_check:
print(f"{FAIL}:{filename}:{i}: Mismatching magic annotation: {magic_val} != {magic_expr} (= {magic_val_check})")
exit(1)
print(f"{REMEMBERED}:{filename}:{i}: Verified explanation {magic_val} == {magic_expr}")
verified_magics[magic_val] = magic_expr

found = next(re.finditer(pattern, l), None)
if found is None:
Expand All @@ -103,16 +134,12 @@ def check_magic_numbers():
if is_exception(filename, l, magic):
continue

if magic_expr is not None:
val = evaluate_magic(magic_expr, magic_dict)
if magic_val != val:
raise Exception(f"{filename}:{i}: Mismatching magic annotation: {magic_val} != {val}")
if val == magic:
print(f"[OK] {filename}:{i}: Verified magic constant {magic} == {magic_expr}")
else:
raise Exception(f"{filename}:{i}: Magic constant mismatch {magic} != {magic_expr}")
else:
raise Exception(f"{filename}:{i}: No explanation for magic value {magic}")
explanation = verified_magics.get(magic, None)
if explanation is None:
print(f"{FAIL}:{filename}:{i}: No explanation for magic value {magic}")
exit(1)

print(f"{CHECKED}:{filename}:{i}: {magic} previously explained as {explanation}")

# If this is a #define's clause, remember it
define = get_define(l)
Expand Down
Loading