Skip to content

Commit 84573c9

Browse files
authored
Implement decoding for niche-encoded enums (#743)
1 parent dc978a0 commit 84573c9

File tree

14 files changed

+10546
-24
lines changed

14 files changed

+10546
-24
lines changed

kmir/src/kmir/alloc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def from_dict(dct: dict[str, Any]) -> GlobalAlloc:
3535
case {'Memory': _}:
3636
return Memory.from_dict(dct)
3737
case _:
38-
raise ValueError('Unsupported or invalid GlobalAlloc data: {dct}')
38+
raise ValueError(f'Unsupported or invalid GlobalAlloc data: {dct}')
3939

4040

4141
@dataclass

kmir/src/kmir/decoding.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
ArbitraryFields,
1212
ArrayT,
1313
BoolT,
14-
Direct,
1514
EnumT,
1615
Initialized,
1716
IntT,
@@ -42,7 +41,7 @@
4241

4342
from pyk.kast import KInner
4443

45-
from .ty import FieldsShape, LayoutShape, MachineSize, Scalar, TagEncoding, Ty, TypeMetadata, UintTy
44+
from .ty import FieldsShape, IntegerLength, LayoutShape, MachineSize, Scalar, TagEncoding, Ty, TypeMetadata, UintTy
4645
from .value import Metadata
4746

4847

@@ -241,7 +240,7 @@ def _decode_enum(
241240
fields=fields,
242241
offsets=offsets,
243242
# ---
244-
tag_index=index,
243+
index=index,
245244
# ---
246245
types=types,
247246
)
@@ -282,15 +281,15 @@ def _decode_enum_single(
282281
discriminants: list[int],
283282
fields: list[list[Ty]],
284283
offsets: list[MachineSize],
285-
tag_index: int,
284+
index: int,
286285
types: Mapping[Ty, TypeMetadata],
287286
) -> Value:
287+
assert index == 0, 'Assumed index to always be 0 for Single(index)'
288+
288289
assert len(fields) == 1, 'Expected a single list of field types for single-variant enum'
289290
tys = fields[0]
290291

291292
assert len(discriminants) == 1, 'Expected a single discriminant for single-variant enum'
292-
discriminant = discriminants[0]
293-
assert tag_index == discriminant, 'Assumed tag_index to be the same as the discriminant'
294293

295294
field_values = _decode_fields(data=data, tys=tys, offsets=offsets, types=types)
296295
return AggregateValue(0, field_values)
@@ -310,18 +309,16 @@ def _decode_enum_multiple(
310309
# ---
311310
types: Mapping[Ty, TypeMetadata],
312311
) -> Value:
313-
if not isinstance(tag_encoding, Direct):
314-
raise ValueError(f'Unsupported encoding: {tag_encoding}')
315-
316-
assert tag_field == 0, 'Assumed tag field to be zero'
317312
assert len(offsets) == 1, 'Assumed offsets to only contain the tag offset'
318-
tag_offset = offsets[0]
319-
tag_value = _extract_tag_value(data=data, tag_offset=tag_offset, tag=tag)
313+
assert tag_field == 0, 'Assumed tag field to be zero accordingly'
314+
tag_offset = offsets[tag_field]
315+
tag_value, width = _extract_tag(data=data, tag_offset=tag_offset, tag=tag)
316+
discriminant = tag_encoding.decode(tag_value, width=width)
320317

321318
try:
322-
variant_idx = discriminants.index(tag_value)
319+
variant_idx = discriminants.index(discriminant)
323320
except ValueError as err:
324-
raise ValueError(f'Tag not found: {tag_value}') from err
321+
raise ValueError(f'Discriminant not found: {discriminant}') from err
325322

326323
tys = fields[variant_idx]
327324

@@ -350,16 +347,17 @@ def _decode_fields(
350347
return res
351348

352349

353-
def _extract_tag_value(*, data: bytes, tag_offset: MachineSize, tag: Scalar) -> int:
350+
def _extract_tag(*, data: bytes, tag_offset: MachineSize, tag: Scalar) -> tuple[int, IntegerLength]:
354351
match tag:
355352
case Initialized(
356353
value=PrimitiveInt(
357354
length=length,
358-
signed=signed,
355+
signed=False,
359356
),
360357
valid_range=_,
361358
):
362359
tag_data = data[tag_offset.in_bytes : tag_offset.in_bytes + length.value]
363-
return int.from_bytes(tag_data, byteorder='little', signed=signed)
360+
tag_value = int.from_bytes(tag_data, byteorder='little', signed=False)
361+
return tag_value, length
364362
case _:
365-
raise ValueError('Unsupported tag: {tag}')
363+
raise ValueError(f'Unsupported tag: {tag}')

kmir/src/kmir/ty.py

Lines changed: 70 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from abc import ABC
3+
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
55
from enum import Enum
66
from functools import cached_property
@@ -151,6 +151,13 @@ def from_raw(data: Any) -> EnumT:
151151
case _:
152152
raise _cannot_parse_as('EnumT', data)
153153

154+
def nbytes(self, types: Mapping[Ty, TypeMetadata]) -> int:
155+
match self.layout:
156+
case None:
157+
raise ValueError(f'Cannot determine size, layout is missing for: {self}')
158+
case LayoutShape(size=size):
159+
return size.in_bytes
160+
154161

155162
@dataclass
156163
class LayoutShape:
@@ -349,6 +356,11 @@ class IntegerLength(Enum):
349356
I64 = 8
350357
I128 = 16
351358

359+
def wrapping_sub(self, x: int, y: int) -> int:
360+
bit_width = 8 * self.value
361+
mask = (1 << bit_width) - 1
362+
return (x - y) & mask
363+
352364

353365
@dataclass
354366
class Float(Primitive): ...
@@ -364,18 +376,71 @@ def from_raw(data: Any) -> TagEncoding:
364376
match data:
365377
case 'Direct':
366378
return Direct()
367-
case {'Niche': _}:
368-
return Niche()
379+
case {
380+
'Niche': {
381+
'untagged_variant': untagged_variant,
382+
'niche_variants': niche_variants,
383+
'niche_start': niche_start,
384+
},
385+
}:
386+
return Niche(
387+
untagged_variant=int(untagged_variant),
388+
niche_variants=RangeInclusive.from_raw(niche_variants),
389+
niche_start=int(niche_start),
390+
)
369391
case _:
370392
raise _cannot_parse_as('TagEncoding', data)
371393

394+
@abstractmethod
395+
def decode(self, tag: int, *, width: IntegerLength) -> int: ...
396+
372397

373398
@dataclass
374-
class Direct(TagEncoding): ...
399+
class Direct(TagEncoding):
400+
def decode(self, tag: int, *, width: IntegerLength) -> int:
401+
# The tag directly stores the discriminant.
402+
return tag
375403

376404

377405
@dataclass
378-
class Niche(TagEncoding): ...
406+
class Niche(TagEncoding):
    """Tag encoding where the tag is stored in a niche rather than directly.

    For this encoding, the discriminant and the variant index of each
    variant coincide, so ``decode`` returns a variant index.
    """

    # Variant index returned whenever the decoded index falls outside ``niche_variants``.
    untagged_variant: int
    # Inclusive range of variant indices that are encoded through the tag.
    niche_variants: RangeInclusive
    # Tag value that maps to ``niche_variants.start``.
    niche_start: int

    def decode(self, tag: int, *, width: IntegerLength) -> int:
        """Recover the variant index encoded by ``tag``.

        Args:
            tag: The raw tag value read from the enum's data.
            width: Integer width of the tag; the subtraction wraps at this width.

        Returns:
            The variant index ``tag.wrapping_sub(niche_start) + niche_variants.start``
            if it lies within ``niche_variants``, otherwise ``untagged_variant``.
        """
        # The subtraction wraps at the tag's width, so tags below ``niche_start``
        # become large values; only the addition of ``niche_variants.start`` is unbounded.
        i = width.wrapping_sub(tag, self.niche_start) + self.niche_variants.start
        # If i ends up outside niche_variants, the tag must have encoded the untagged_variant.
        if i not in self.niche_variants:
            return self.untagged_variant
        return i
420+
421+
422+
class RangeInclusive(NamedTuple):
423+
start: int
424+
end: int
425+
426+
@staticmethod
427+
def from_raw(data: Any) -> RangeInclusive:
428+
match data:
429+
case {
430+
'start': start,
431+
'end': end,
432+
}:
433+
return RangeInclusive(
434+
start=int(start),
435+
end=int(end),
436+
)
437+
case _:
438+
raise _cannot_parse_as('RangeInclusive', data)
439+
440+
def __contains__(self, x: object) -> bool:
441+
if isinstance(x, int):
442+
return self.start <= x <= self.end
443+
raise TypeError('Method RangeInclusive.__contains__ is only supported for int, got: {x}')
379444

380445

381446
class WrappingRange(NamedTuple):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Aggregate ( variantIdx ( 0 ) , .List )
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
{
2+
"bytes": [
3+
0
4+
],
5+
"types": [
6+
[
7+
0,
8+
{
9+
"PrimitiveType": {
10+
"Uint": "U8"
11+
}
12+
}
13+
]
14+
],
15+
"typeInfo": {
16+
"EnumType": {
17+
"name": "core::option::Option<core::num::NonZero<u8>>",
18+
"adt_def": 100,
19+
"discriminants": [
20+
0,
21+
1
22+
],
23+
"fields": [
24+
[],
25+
[
26+
0
27+
]
28+
],
29+
"layout": {
30+
"fields": {
31+
"Arbitrary": {
32+
"offsets": [
33+
{
34+
"num_bits": 0
35+
}
36+
]
37+
}
38+
},
39+
"variants": {
40+
"Multiple": {
41+
"tag": {
42+
"Initialized": {
43+
"value": {
44+
"Int": {
45+
"length": "I8",
46+
"signed": false
47+
}
48+
},
49+
"valid_range": {
50+
"start": 1,
51+
"end": 0
52+
}
53+
}
54+
},
55+
"tag_encoding": {
56+
"Niche": {
57+
"untagged_variant": 1,
58+
"niche_variants": {
59+
"start": 0,
60+
"end": 0
61+
},
62+
"niche_start": 0
63+
}
64+
},
65+
"tag_field": 0,
66+
"variants": [
67+
{
68+
"fields": {
69+
"Arbitrary": {
70+
"offsets": []
71+
}
72+
},
73+
"variants": {
74+
"Single": {
75+
"index": 0
76+
}
77+
},
78+
"abi": {
79+
"Aggregate": {
80+
"sized": true
81+
}
82+
},
83+
"abi_align": 1,
84+
"size": {
85+
"num_bits": 0
86+
}
87+
},
88+
{
89+
"fields": {
90+
"Arbitrary": {
91+
"offsets": [
92+
{
93+
"num_bits": 0
94+
}
95+
]
96+
}
97+
},
98+
"variants": {
99+
"Single": {
100+
"index": 1
101+
}
102+
},
103+
"abi": {
104+
"Scalar": {
105+
"Initialized": {
106+
"value": {
107+
"Int": {
108+
"length": "I8",
109+
"signed": false
110+
}
111+
},
112+
"valid_range": {
113+
"start": 1,
114+
"end": 255
115+
}
116+
}
117+
}
118+
},
119+
"abi_align": 1,
120+
"size": {
121+
"num_bits": 8
122+
}
123+
}
124+
]
125+
}
126+
},
127+
"abi": {
128+
"Scalar": {
129+
"Initialized": {
130+
"value": {
131+
"Int": {
132+
"length": "I8",
133+
"signed": false
134+
}
135+
},
136+
"valid_range": {
137+
"start": 1,
138+
"end": 0
139+
}
140+
}
141+
}
142+
},
143+
"abi_align": 1,
144+
"size": {
145+
"num_bits": 8
146+
}
147+
}
148+
}
149+
}
150+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Aggregate ( variantIdx ( 1 ) , ListItem ( Integer ( 123 , 8 , false ) ) )

0 commit comments

Comments
 (0)