Skip to content

Commit c78806f

Browse files
wanda-phiwhitequark
authored andcommitted
sim: make driving parts of a signal from distinct modules possible.
Fixes (part of) #1454.
1 parent a154b64 commit c78806f

File tree

6 files changed

+178
-17
lines changed

6 files changed

+178
-17
lines changed

amaranth/hdl/_xfrm.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"FragmentTransformer",
1818
"TransformedElaboratable",
1919
"DomainCollector", "DomainRenamer", "DomainLowerer",
20+
"LHSMaskCollector",
2021
"ResetInserter", "EnableInserter"]
2122

2223

@@ -601,6 +602,71 @@ def on_fragment(self, fragment):
601602
return super().on_fragment(fragment)
602603

603604

605+
class LHSMaskCollector:
606+
def __init__(self):
607+
self.lhs = SignalDict()
608+
609+
def visit_stmt(self, stmt):
610+
if type(stmt) is Assign:
611+
self.visit_value(stmt.lhs, ~0)
612+
elif type(stmt) is Switch:
613+
for (_, substmt, _) in stmt.cases:
614+
self.visit_stmt(substmt)
615+
elif type(stmt) in (Property, Print):
616+
pass
617+
elif isinstance(stmt, Iterable):
618+
for substmt in stmt:
619+
self.visit_stmt(substmt)
620+
else:
621+
assert False # :nocov:
622+
623+
def visit_value(self, value, mask):
624+
if type(value) in (Signal, ClockSignal, ResetSignal):
625+
mask &= (1 << len(value)) - 1
626+
self.lhs.setdefault(value, 0)
627+
self.lhs[value] |= mask
628+
elif type(value) is Operator:
629+
assert value.operator in ("s", "u")
630+
self.visit_value(value.operands[0], mask)
631+
elif type(value) is Slice:
632+
slice_mask = (1 << value.stop) - (1 << value.start)
633+
mask <<= value.start
634+
mask &= slice_mask
635+
self.visit_value(value.value, mask)
636+
elif type(value) is Part:
637+
# Could be more accurate, but if you're relying on such details, you're not seeing
638+
# the Light of Heaven.
639+
self.visit_value(value.value, ~0)
640+
elif type(value) is Concat:
641+
for part in value.parts:
642+
self.visit_value(part, mask)
643+
mask >>= len(part)
644+
elif type(value) is SwitchValue:
645+
for (_, subvalue) in value.cases:
646+
self.visit_value(subvalue, mask)
647+
else:
648+
assert False # :nocov:
649+
650+
def chunks(self):
651+
for signal, mask in self.lhs.items():
652+
if mask == (1 << len(signal)) - 1:
653+
yield signal, 0, None
654+
else:
655+
start = 0
656+
while start < len(signal):
657+
if ((mask >> start) & 1) == 0:
658+
start += 1
659+
else:
660+
stop = start
661+
while stop < len(signal) and ((mask >> stop) & 1) == 1:
662+
stop += 1
663+
yield (signal, start, stop)
664+
start = stop
665+
666+
def masks(self):
667+
yield from self.lhs.items()
668+
669+
604670
class _ControlInserter(FragmentTransformer):
605671
def __init__(self, controls):
606672
self.src_loc = None
@@ -615,10 +681,9 @@ def on_fragment(self, fragment):
615681
for domain, statements in fragment.statements.items():
616682
if domain == "comb" or domain not in self.controls:
617683
continue
618-
signals = SignalSet()
619-
for stmt in statements:
620-
signals |= stmt._lhs_signals()
621-
self._insert_control(new_fragment, domain, signals)
684+
lhs_masks = LHSMaskCollector()
685+
lhs_masks.visit_stmt(statements)
686+
self._insert_control(new_fragment, domain, lhs_masks)
622687
return new_fragment
623688

624689
def _insert_control(self, fragment, domain, signals):
@@ -630,13 +695,20 @@ def __call__(self, value, *, src_loc_at=0):
630695

631696

632697
class ResetInserter(_ControlInserter):
633-
def _insert_control(self, fragment, domain, signals):
634-
stmts = [s.eq(Const(s.init, s.shape())) for s in signals if not s.reset_less]
698+
def _insert_control(self, fragment, domain, lhs_masks):
699+
stmts = []
700+
for signal, start, stop in lhs_masks.chunks():
701+
if signal.reset_less:
702+
continue
703+
if start == 0 and stop is None:
704+
stmts.append(signal.eq(Const(signal.init, signal.shape())))
705+
else:
706+
stmts.append(signal[start:stop].eq(Const(signal.init, signal.shape())[start:stop]))
635707
fragment.add_statements(domain, Switch(self.controls[domain], [(1, stmts, None)], src_loc=self.src_loc))
636708

637709

638710
class EnableInserter(_ControlInserter):
639-
def _insert_control(self, fragment, domain, signals):
711+
def _insert_control(self, fragment, domain, _lhs_masks):
640712
if domain in fragment.statements:
641713
fragment.statements[domain] = _StatementList([Switch(
642714
self.controls[domain],

amaranth/sim/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class BaseSignalState:
2323
curr = NotImplemented
2424
next = NotImplemented
2525

26-
def update(self, value):
26+
def update(self, value, mask=~0):
2727
raise NotImplementedError # :nocov:
2828

2929

amaranth/sim/_pyrtl.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from ..hdl import *
77
from ..hdl._ast import SignalSet, _StatementList, Property
8-
from ..hdl._xfrm import ValueVisitor, StatementVisitor
8+
from ..hdl._xfrm import ValueVisitor, StatementVisitor, LHSMaskCollector
99
from ..hdl._mem import MemoryInstance
1010
from ._base import BaseProcess
1111
from ._pyeval import value_to_string
@@ -487,19 +487,20 @@ def __call__(self, fragment):
487487
for domain_name in domains:
488488
domain_stmts = fragment.statements.get(domain_name, _StatementList())
489489
domain_process = PyRTLProcess(is_comb=domain_name == "comb")
490-
domain_signals = domain_stmts._lhs_signals()
490+
lhs_masks = LHSMaskCollector()
491+
lhs_masks.visit_stmt(domain_stmts)
491492

492493
if isinstance(fragment, MemoryInstance):
493494
for port in fragment._read_ports:
494495
if port._domain == domain_name:
495-
domain_signals.update(port._data._lhs_signals())
496+
lhs_masks.visit_value(port._data, ~0)
496497

497498
emitter = _PythonEmitter()
498499
emitter.append(f"def run():")
499500
emitter._level += 1
500501

501502
if domain_name == "comb":
502-
for signal in domain_signals:
503+
for (signal, _) in lhs_masks.masks():
503504
signal_index = self.state.get_signal(signal)
504505
self.state.slots[signal_index].is_comb = True
505506
emitter.append(f"next_{signal_index} = {signal.init}")
@@ -533,7 +534,7 @@ def __call__(self, fragment):
533534
if domain.async_reset and domain.rst is not None:
534535
self.state.add_signal_waker(domain.rst, edge_waker(domain_process, 1))
535536

536-
for signal in domain_signals:
537+
for (signal, _) in lhs_masks.masks():
537538
signal_index = self.state.get_signal(signal)
538539
emitter.append(f"next_{signal_index} = slots[{signal_index}].next")
539540

@@ -546,7 +547,7 @@ def __call__(self, fragment):
546547
emitter.append(f"if {rst}:")
547548
with emitter.indent():
548549
emitter.append("pass")
549-
for signal in domain_signals:
550+
for (signal, _) in lhs_masks.masks():
550551
if not signal.reset_less:
551552
signal_index = self.state.get_signal(signal)
552553
emitter.append(f"next_{signal_index} = {signal.init}")
@@ -592,9 +593,11 @@ def __call__(self, fragment):
592593

593594
lhs(port._data)(data)
594595

595-
for signal in domain_signals:
596+
for (signal, mask) in lhs_masks.masks():
597+
if signal.shape().signed and (mask & 1 << (len(signal) - 1)):
598+
mask |= -1 << len(signal)
596599
signal_index = self.state.get_signal(signal)
597-
emitter.append(f"slots[{signal_index}].update(next_{signal_index})")
600+
emitter.append(f"slots[{signal_index}].update(next_{signal_index}, {mask})")
598601

599602
# There shouldn't be any exceptions raised by the generated code, but if there are
600603
# (almost certainly due to a bug in the code generator), use this environment variable

amaranth/sim/pysim.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,8 @@ def add_waker(self, waker):
369369
assert waker not in self.wakers
370370
self.wakers.append(waker)
371371

372-
def update(self, value):
372+
def update(self, value, mask=~0):
373+
value = (self.next & ~mask) | (value & mask)
373374
if self.next != value:
374375
self.next = value
375376
self.pending.add(self)

tests/test_hdl_xfrm.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def setUp(self):
227227
self.s1 = Signal()
228228
self.s2 = Signal(init=1)
229229
self.s3 = Signal(init=1, reset_less=True)
230+
self.s4 = Signal(8, init=0x3a)
230231
self.c1 = Signal()
231232

232233
def test_reset_default(self):
@@ -281,6 +282,40 @@ def test_reset_value(self):
281282
)
282283
""")
283284

285+
def test_reset_mask(self):
286+
f = Fragment()
287+
f.add_statements("sync", self.s4[2:4].eq(0))
288+
289+
f = ResetInserter(self.c1)(f)
290+
self.assertRepr(f.statements["sync"], """
291+
(
292+
(eq (slice (sig s4) 2:4) (const 1'd0))
293+
(switch (sig c1)
294+
(case 1 (eq (slice (sig s4) 2:4) (slice (const 8'd58) 2:4)))
295+
)
296+
)
297+
""")
298+
299+
f = Fragment()
300+
f.add_statements("sync", self.s4[2:4].eq(0))
301+
f.add_statements("sync", self.s4[3:5].eq(0))
302+
f.add_statements("sync", self.s4[6:10].eq(0))
303+
304+
f = ResetInserter(self.c1)(f)
305+
self.assertRepr(f.statements["sync"], """
306+
(
307+
(eq (slice (sig s4) 2:4) (const 1'd0))
308+
(eq (slice (sig s4) 3:5) (const 1'd0))
309+
(eq (slice (sig s4) 6:8) (const 1'd0))
310+
(switch (sig c1)
311+
(case 1
312+
(eq (slice (sig s4) 2:5) (slice (const 8'd58) 2:5))
313+
(eq (slice (sig s4) 6:8) (slice (const 8'd58) 6:8))
314+
)
315+
)
316+
)
317+
""")
318+
284319
def test_reset_less(self):
285320
f = Fragment()
286321
f.add_statements("sync", self.s3.eq(0))
@@ -423,3 +458,31 @@ def test_composition(self):
423458
)
424459
)
425460
""")
461+
462+
class LHSMaskCollectorTestCase(FHDLTestCase):
463+
def test_slice(self):
464+
s = Signal(8)
465+
lhs = LHSMaskCollector()
466+
lhs.visit_value(s[2:5], ~0)
467+
self.assertEqual(lhs.lhs[s], 0x1c)
468+
469+
def test_slice_slice(self):
470+
s = Signal(8)
471+
lhs = LHSMaskCollector()
472+
lhs.visit_value(s[2:7][1:3], ~0)
473+
self.assertEqual(lhs.lhs[s], 0x18)
474+
475+
def test_slice_concat(self):
476+
s1 = Signal(8)
477+
s2 = Signal(8)
478+
lhs = LHSMaskCollector()
479+
lhs.visit_value(Cat(s1, s2)[4:11], ~0)
480+
self.assertEqual(lhs.lhs[s1], 0xf0)
481+
self.assertEqual(lhs.lhs[s2], 0x07)
482+
483+
def test_slice_part(self):
484+
s = Signal(8)
485+
idx = Signal(8)
486+
lhs = LHSMaskCollector()
487+
lhs.visit_value(s.bit_select(idx, 5)[1:3], ~0)
488+
self.assertEqual(lhs.lhs[s], 0xff)

tests/test_sim.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,28 @@ async def testbench(ctx):
14131413
with self.assertSimulation(Module(), traces=[mem1, mem2, mem3]) as sim:
14141414
sim.add_testbench(testbench)
14151415

1416+
def test_multiple_modules(self):
1417+
m = Module()
1418+
m.submodules.m1 = m1 = Module()
1419+
m.submodules.m2 = m2 = Module()
1420+
a = Signal(8)
1421+
b = Signal(8)
1422+
m1.d.comb += b[0:2].eq(a[0:2])
1423+
m1.d.comb += b[4:6].eq(a[4:6])
1424+
m2.d.comb += b[2:4].eq(a[2:4])
1425+
m2.d.comb += b[6:8].eq(a[6:8])
1426+
with self.assertSimulation(m) as sim:
1427+
async def testbench(ctx):
1428+
ctx.set(a, 0)
1429+
self.assertEqual(ctx.get(b), 0)
1430+
ctx.set(a, 0x12)
1431+
self.assertEqual(ctx.get(b), 0x12)
1432+
ctx.set(a, 0x34)
1433+
self.assertEqual(ctx.get(b), 0x34)
1434+
ctx.set(a, 0xdb)
1435+
self.assertEqual(ctx.get(b), 0xdb)
1436+
sim.add_testbench(testbench)
1437+
14161438

14171439
class SimulatorTracesTestCase(FHDLTestCase):
14181440
def assertDef(self, traces, flat_traces):

0 commit comments

Comments
 (0)