Skip to content

Commit 93fcb34

Browse files
authored
Updates from master (#781)
* Caching for the LLVM compilation * more random-running functionality * `assume` cheatcode * Discriminant bit width fix * Optimisations to eliminate `transmute` thunks
2 parents a1a43e8 + 145951f commit 93fcb34

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+26069
-167
lines changed

kmir/src/kmir/__main__.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pyk.proof.show import APRProofShow
1515
from pyk.proof.tui import APRProofViewer
1616

17-
from .build import HASKELL_DEF_DIR, LLVM_DEF_DIR, LLVM_LIB_DIR
17+
from .build import HASKELL_DEF_DIR, LLVM_LIB_DIR
1818
from .cargo import CargoProject
1919
from .kmir import KMIR, KMIRAPRNodePrinter
2020
from .linker import link
@@ -42,8 +42,6 @@
4242

4343

4444
def _kmir_run(opts: RunOpts) -> None:
45-
kmir = KMIR(HASKELL_DEF_DIR) if opts.haskell_backend else KMIR(LLVM_DEF_DIR)
46-
4745
if opts.file:
4846
smir_info = SMIRInfo.from_file(Path(opts.file))
4947
else:
@@ -54,11 +52,17 @@ def _kmir_run(opts: RunOpts) -> None:
5452
# target = opts.bin if opts.bin else cargo.default_target
5553
smir_info = cargo.smir_for_project(clean=False)
5654

57-
with tempfile.TemporaryDirectory() as work_dir:
58-
kmir = KMIR.from_kompiled_kore(smir_info, symbolic=opts.haskell_backend, target_dir=work_dir)
55+
def run(target_dir: Path):
56+
kmir = KMIR.from_kompiled_kore(smir_info, symbolic=opts.haskell_backend, target_dir=target_dir)
5957
result = kmir.run_smir(smir_info, start_symbol=opts.start_symbol, depth=opts.depth)
6058
print(kmir.kore_to_pretty(result))
6159

60+
if opts.target_dir:
61+
run(target_dir=opts.target_dir)
62+
else:
63+
with tempfile.TemporaryDirectory() as target_dir:
64+
run(target_dir=Path(target_dir))
65+
6266

6367
def _kmir_prove_rs(opts: ProveRSOpts) -> None:
6468
proof = KMIR.prove_rs(opts)
@@ -185,6 +189,7 @@ def _arg_parser() -> ArgumentParser:
185189
run_target_selection = run_parser.add_mutually_exclusive_group()
186190
run_target_selection.add_argument('--bin', metavar='TARGET', help='Target to run')
187191
run_target_selection.add_argument('--file', metavar='SMIR', help='SMIR json file to execute')
192+
run_parser.add_argument('--target-dir', type=Path, metavar='TARGET_DIR', help='SMIR kompilation target directory')
188193
run_parser.add_argument('--depth', type=int, metavar='DEPTH', help='Depth to execute')
189194
run_parser.add_argument(
190195
'--start-symbol', type=str, metavar='SYMBOL', default='main', help='Symbol name to begin execution from'
@@ -313,6 +318,7 @@ def _parse_args(ns: Namespace) -> KMirOpts:
313318
return RunOpts(
314319
bin=ns.bin,
315320
file=ns.file,
321+
target_dir=ns.target_dir,
316322
depth=ns.depth,
317323
start_symbol=ns.start_symbol,
318324
haskell_backend=ns.haskell_backend,

kmir/src/kmir/decoding.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@
2525
UintT,
2626
)
2727
from .value import (
28-
NO_METADATA,
28+
NO_SIZE,
2929
AggregateValue,
3030
AllocRefValue,
3131
BoolValue,
3232
DynamicSize,
3333
IntValue,
34+
Metadata,
3435
RangeValue,
3536
StaticSize,
3637
StrValue,
@@ -120,18 +121,30 @@ def _decode_memory_alloc_or_unable(
120121
except KeyError:
121122
return UnableToDecodeValue(f'Unknown pointee type: {pointee_ty}')
122123

123-
metadata = _metadata(pointee_type_info)
124+
metadata_size = _metadata_size(pointee_type_info)
124125

125126
if len(data) == 8:
126127
# single slim pointer (assumes usize == u64)
127-
return AllocRefValue(alloc_id=alloc_id, metadata=metadata)
128+
return AllocRefValue(
129+
alloc_id=alloc_id,
130+
metadata=Metadata(
131+
size=metadata_size,
132+
pointer_offset=0,
133+
origin_size=metadata_size,
134+
),
135+
)
128136

129-
if len(data) == 16 and metadata == DynamicSize(1):
137+
if len(data) == 16 and metadata_size == DynamicSize(1):
130138
# sufficient data to decode dynamic size (assumes usize == u64)
131139
# expect fat pointer
140+
actual_size = DynamicSize(int.from_bytes(data[8:16], byteorder='little', signed=False))
132141
return AllocRefValue(
133142
alloc_id=alloc_id,
134-
metadata=DynamicSize(int.from_bytes(data[8:16], byteorder='little', signed=False)),
143+
metadata=Metadata(
144+
size=actual_size,
145+
pointer_offset=0,
146+
origin_size=actual_size,
147+
),
135148
)
136149

137150
return UnableToDecodeValue(f'Unable to decode alloc: {data!r}, of type: {type_info}')
@@ -145,14 +158,14 @@ def _pointee_ty(type_info: TypeMetadata) -> Ty | None:
145158
return None
146159

147160

148-
def _metadata(type_info: TypeMetadata) -> MetadataSize:
161+
def _metadata_size(type_info: TypeMetadata) -> MetadataSize:
149162
match type_info:
150163
case ArrayT(length=None):
151164
return DynamicSize(1) # 1 is a placeholder, the actual size is inferred from the slice data
152165
case ArrayT(length=int() as length):
153166
return StaticSize(length)
154167
case _:
155-
return NO_METADATA
168+
return NO_SIZE
156169

157170

158171
def decode_value_or_unable(data: bytes, type_info: TypeMetadata, types: Mapping[Ty, TypeMetadata]) -> Value:

kmir/src/kmir/kast.py

Lines changed: 161 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
from itertools import count
45
from typing import TYPE_CHECKING, NamedTuple
56

67
from pyk.kast.inner import KApply, KSort, KVariable, Subst, build_cons
@@ -11,10 +12,23 @@
1112
from pyk.kast.prelude.utils import token
1213

1314
from .ty import ArrayT, BoolT, EnumT, IntT, PtrT, RefT, StructT, TupleT, Ty, UintT, UnionT
14-
from .value import BoolValue, IntValue
15+
from .value import (
16+
NO_SIZE,
17+
AggregateValue,
18+
BoolValue,
19+
DynamicSize,
20+
IntValue,
21+
Local,
22+
Metadata,
23+
Place,
24+
PtrLocalValue,
25+
RangeValue,
26+
RefValue,
27+
StaticSize,
28+
)
1529

1630
if TYPE_CHECKING:
17-
from collections.abc import Iterable, Mapping, Sequence
31+
from collections.abc import Iterable, Iterator, Mapping, Sequence
1832
from random import Random
1933
from typing import Any, Final
2034

@@ -29,6 +43,9 @@
2943
_LOGGER: Final = logging.getLogger(__name__)
3044

3145

46+
RANDOM_MAX_ARRAY_LEN: Final = 32
47+
48+
3249
LOCAL_0: Final = KApply('newLocal', KApply('ty', token(0)), KApply('Mutability::Not'))
3350

3451

@@ -474,25 +491,7 @@ def to_kast(self) -> KInner:
474491

475492

476493
def _random_locals(random: Random, args: Sequence[_Local], types: Mapping[Ty, TypeMetadata]) -> list[KInner]:
477-
res: list[KInner] = [LOCAL_0]
478-
pointees: list[KInner] = []
479-
480-
next_ref = len(args) + 1
481-
for arg in args:
482-
rvres = _random_value(
483-
random=random,
484-
local=arg,
485-
types=types,
486-
next_ref=next_ref,
487-
)
488-
res.append(rvres.value.to_kast())
489-
match rvres:
490-
case PointerRes(pointee=pointee):
491-
pointees.append(pointee.to_kast())
492-
next_ref += 1
493-
494-
res += pointees
495-
return res
494+
return _RandomArgGen(random=random, args=args, types=types).run()
496495

497496

498497
class SimpleRes(NamedTuple):
@@ -501,55 +500,155 @@ class SimpleRes(NamedTuple):
501500

502501
class ArrayRes(NamedTuple):
503502
value: TypedValue
504-
metadata: MetadataSize
503+
metadata_size: MetadataSize
505504

506505

507-
class PointerRes(NamedTuple):
508-
value: TypedValue
509-
pointee: TypedValue
506+
RandomValueRes = SimpleRes | ArrayRes
510507

511508

512-
RandomValueRes = SimpleRes | ArrayRes | PointerRes
509+
class _RandomArgGen:
510+
_random: Random
511+
_args: Sequence[_Local]
512+
_types: Mapping[Ty, TypeMetadata]
513+
_pointees: list[TypedValue]
514+
_ref: Iterator[int]
513515

516+
def __init__(self, *, random: Random, args: Sequence[_Local], types: Mapping[Ty, TypeMetadata]):
517+
self._random = random
518+
self._args = args
519+
self._types = types
520+
self._pointees = []
521+
self._ref = count(len(args) + 1)
514522

515-
def _random_value(
516-
*,
517-
random: Random,
518-
local: _Local,
519-
types: Mapping[Ty, TypeMetadata],
520-
next_ref: int,
521-
) -> RandomValueRes:
522-
try:
523-
type_info = types[local.ty]
524-
except KeyError as err:
525-
raise ValueError(f'Unknown type: {local.ty}') from err
526-
527-
match type_info:
528-
case BoolT():
529-
return SimpleRes(
530-
TypedValue.from_local(
531-
value=_random_bool_value(random),
532-
local=local,
523+
def run(self) -> list[KInner]:
524+
res: list[KInner] = [LOCAL_0]
525+
res.extend(self._random_value(arg).value.to_kast() for arg in self._args)
526+
res.extend(pointee.to_kast() for pointee in self._pointees)
527+
return res
528+
529+
def _random_value(self, local: _Local) -> RandomValueRes:
530+
try:
531+
type_info = self._types[local.ty]
532+
except KeyError as err:
533+
raise ValueError(f'Unknown type: {local.ty}') from err
534+
535+
match type_info:
536+
case BoolT():
537+
return SimpleRes(
538+
TypedValue.from_local(
539+
value=self._random_bool_value(),
540+
local=local,
541+
)
533542
)
534-
)
535-
case IntT() | UintT():
536-
return SimpleRes(
537-
TypedValue.from_local(
538-
value=_random_int_value(random, type_info),
539-
local=local,
540-
),
541-
)
542-
case _:
543-
raise ValueError(f'Type unsupported for random value generator: {type_info}')
543+
case IntT() | UintT():
544+
return SimpleRes(
545+
TypedValue.from_local(
546+
value=self._random_int_value(type_info),
547+
local=local,
548+
),
549+
)
550+
case EnumT(discriminants=discriminants, fields=fields):
551+
return SimpleRes(
552+
TypedValue.from_local(
553+
value=self._random_enum_value(mut=local.mut, discriminants=discriminants, fields=fields),
554+
local=local,
555+
),
556+
)
557+
case StructT(fields=tys) | TupleT(components=tys):
558+
return SimpleRes(
559+
TypedValue.from_local(
560+
value=self._random_struct_or_tuple_value(mut=local.mut, tys=tys),
561+
local=local,
562+
),
563+
)
564+
case ArrayT(element_type=elem_ty, length=length):
565+
value, metadata_size = self._random_array_value(mut=local.mut, elem_ty=elem_ty, length=length)
566+
return ArrayRes(
567+
value=TypedValue.from_local(
568+
value=value,
569+
local=local,
570+
),
571+
metadata_size=metadata_size,
572+
)
573+
case PtrT() | RefT():
574+
return SimpleRes(
575+
value=TypedValue.from_local(
576+
value=self._random_ptr_value(mut=local.mut, type_info=type_info),
577+
local=local,
578+
),
579+
)
580+
case _:
581+
raise ValueError(f'Type unsupported for random value generator: {type_info}')
544582

583+
def _random_bool_value(self) -> BoolValue:
584+
return BoolValue(bool(self._random.getrandbits(1)))
545585

546-
def _random_bool_value(random: Random) -> BoolValue:
547-
return BoolValue(bool(random.getrandbits(1)))
586+
def _random_int_value(self, type_info: IntT | UintT) -> IntValue:
587+
return IntValue(
588+
value=self._random.randint(type_info.min, type_info.max),
589+
nbits=type_info.nbits,
590+
signed=isinstance(type_info, IntT),
591+
)
548592

593+
def _random_enum_value(
594+
self,
595+
*,
596+
mut: bool,
597+
discriminants: list[int],
598+
fields: list[list[Ty]],
599+
) -> AggregateValue:
600+
variant_idx = self._random.randrange(len(discriminants))
601+
values = self._random_fields(tys=fields[variant_idx], mut=mut)
602+
return AggregateValue(variant_idx, values)
603+
604+
def _random_struct_or_tuple_value(self, *, mut: bool, tys: list[Ty]) -> AggregateValue:
605+
return AggregateValue(0, fields=self._random_fields(tys=tys, mut=mut))
606+
607+
def _random_fields(self, *, tys: list[Ty], mut: bool) -> tuple[Value, ...]:
608+
return tuple(self._random_value(local=_Local(ty=ty, mut=mut)).value.value for ty in tys)
609+
610+
def _random_array_value(self, *, mut: bool, elem_ty: Ty, length: int | None) -> tuple[RangeValue, MetadataSize]:
611+
metadata_size: MetadataSize
612+
if length is None:
613+
length = self._random.randint(0, RANDOM_MAX_ARRAY_LEN)
614+
metadata_size = DynamicSize(length)
615+
else:
616+
metadata_size = StaticSize(length)
617+
618+
elems = tuple(self._random_value(local=_Local(ty=elem_ty, mut=mut)).value.value for _ in range(length))
619+
value = RangeValue(elems)
620+
return value, metadata_size
621+
622+
def _random_ptr_value(self, mut: bool, type_info: PtrT | RefT) -> PtrLocalValue | RefValue:
623+
pointee_local = _Local(ty=type_info.pointee_type, mut=mut)
624+
pointee_res = self._random_value(pointee_local)
625+
self._pointees.append(pointee_res.value)
626+
627+
metadata_size: MetadataSize
628+
match pointee_res:
629+
case ArrayRes(metadata_size=metadata_size):
630+
pass
631+
case _:
632+
metadata_size = NO_SIZE
549633

550-
def _random_int_value(random: Random, type_info: IntT | UintT) -> IntValue:
551-
return IntValue(
552-
value=random.randint(type_info.min, type_info.max),
553-
nbits=type_info.nbits,
554-
signed=isinstance(type_info, IntT),
555-
)
634+
metadata = Metadata(size=metadata_size, pointer_offset=0, origin_size=metadata_size)
635+
636+
ref = next(self._ref)
637+
638+
match type_info:
639+
case PtrT():
640+
return PtrLocalValue(
641+
stack_depth=0,
642+
place=Place(local=Local(ref)),
643+
mut=mut,
644+
metadata=metadata,
645+
)
646+
case RefT():
647+
return RefValue(
648+
stack_depth=0,
649+
place=Place(local=Local(ref)),
650+
mut=mut,
651+
metadata=metadata,
652+
)
653+
case _:
654+
raise AssertionError()

0 commit comments

Comments
 (0)