diff --git a/magma/frontend/pyverilog_importer.py b/magma/frontend/pyverilog_importer.py index 34452a462..ba1a94f09 100644 --- a/magma/frontend/pyverilog_importer.py +++ b/magma/frontend/pyverilog_importer.py @@ -1,4 +1,6 @@ from collections import OrderedDict +import io + import hwtypes import pyverilog.dataflow.visit import pyverilog.vparser.parser @@ -8,12 +10,15 @@ from ..circuit import Circuit from ..interface import IO from ..passes.tsort import tsort -from ..t import In, Out, InOut +from ..t import In, Out, InOut, Type from .verilog_importer import (ImportMode, MultipleModuleDeclarationError, MultiplePortDeclarationError, VerilogImporter, VerilogImportError) from .verilog_utils import int_const_str_to_int from magma.bitutils import clog2 +from magma.conversions import concat +from magma.stubify import stubify +from magma.wire import wire class PyverilogImportError(VerilogImportError): @@ -55,10 +60,13 @@ def _evaluate_node(node, params): raise PyverilogImportError(f"Unsupported expression: {type(node)}") -def _get_width(width, param_map): +def _get_width(width, param_map, metadata=None): """Evaluates width.msb, width.lsb and returns their difference""" msb = _evaluate_node(width.msb, param_map) lsb = _evaluate_node(width.lsb, param_map) + if metadata is not None: + metadata["msb"] = msb + metadata["lsb"] = lsb return msb - lsb + 1 @@ -123,6 +131,10 @@ def visit_ModuleDef(self, defn): class PyverilogImporter(VerilogImporter): """Implementation of VerilogImporter using pyverilog""" + def __init__(self, type_map): + super().__init__(type_map) + self._magma_defn_to_pyverilog_defn = {} + def _import_defn(self, defn, mode): ports = {} default_params = {} @@ -172,6 +184,7 @@ def import_(self, src, mode): visitor.visit(ast) for name, defn in visitor.defns.items(): circ, default_params = self._import_defn(defn, mode) + self._magma_defn_to_pyverilog_defn[circ] = defn if mode is ImportMode.DEFINE: circ.verilogFile = _get_lines(src, defn.lineno, defn.end_lineno) circ.verilog_source = src @@ -181,3 +194,163 @@ def import_(self, src, mode): } circ.default_kwargs = default_params self.add_module(circ) + + +def _show(node, **kwargs): + buf = io.StringIO() + node.show(buf=buf, **kwargs) + return buf.getvalue() + + +def _update_unique(dct, mapping): + for k, v in mapping.items(): + assert k not in dct, (k) + dct[k] = v + + +def _get_instances(defn): + for child in defn.children(): + if isinstance(child, pyverilog.vparser.parser.Instance): + yield child + elif isinstance(child, pyverilog.vparser.parser.InstanceList): + yield from child.instances + + +def _get_wires(defn): + for child in defn.children(): + if not isinstance(child, pyverilog.vparser.parser.Decl): + continue + for subchild in child.children(): + if not isinstance(subchild, pyverilog.vparser.parser.Wire): + continue + yield subchild + + +def _wire_directed_to_undirected(directed, undirected): + if isinstance(directed, Bits[1]): + directed = directed[0] + if isinstance(undirected, Bits[1]): + undirected = undirected[0] + if directed.is_input(): + if isinstance(undirected, Type) and undirected.is_input(): + return False + directed @= undirected + return True + if directed.is_output(): + undirected @= directed + return True + if directed.is_inout(): + wire(directed, undirected) + return True + raise Exception() + + +def _get_offset(name, metadata): + try: + msb, lsb = metadata[name].values() + except KeyError: + return 0 + return lsb + + +def _evaluate_arg(node, values, metadata={}): + if isinstance(node, pyverilog.vparser.parser.IntConst): + return _evaluate_node(node, {}) + if isinstance(node, pyverilog.vparser.parser.Identifier): + return values[node.name] + if isinstance(node, pyverilog.vparser.parser.Pointer): + offset = _get_offset(node.var.name, metadata) + value = _evaluate_arg(node.var, values) + index = _evaluate_node(node.ptr, {}) + return value[index - offset] + if isinstance(node, pyverilog.vparser.parser.Partselect): + offset = _get_offset(node.var.name, metadata) + value = _evaluate_arg(node.var, values) + msb = _evaluate_node(node.msb, {}) + lsb = _evaluate_node(node.lsb, {}) + return value[lsb - offset : msb + 1 - offset] + if isinstance(node, pyverilog.vparser.parser.Concat): + return concat(*(_evaluate_arg(n, values) for n in reversed(node.list))) + raise Exception() + + +def _process_port_connections( + pyverilog_inst, + magma_inst, + containing_magma_defn, + wires, + metadata, +): + stash = [] + magma_values = {} + magma_values.update(wires) + magma_values.update(containing_magma_defn.interface.ports) + for child in pyverilog_inst.children(): + if not isinstance(child, pyverilog.vparser.parser.PortArg): + continue + magma_port = getattr(magma_inst, child.portname) + magma_arg = _evaluate_arg(child.argname, magma_values, metadata) + wired = _wire_directed_to_undirected(magma_port, magma_arg) + if not wired: + stash.append((magma_port, magma_arg)) + return stash + + +class PyverilogNetlistImporter(PyverilogImporter): + def __init__(self, type_map, stdcells): + super().__init__(type_map) + self._stdcells = stdcells + + def import_(self, src, mode): + super().import_(src, mode) + if mode is not ImportMode.DEFINE: + return + modules = { + defn.name: defn + for defn in self._magma_defn_to_pyverilog_defn.keys() + } + try: + _update_unique(modules, self._stdcells) + except AssertionError as e: + raise MultipleModuleDeclarationError(e.args[0]) + + for magma_defn, pyverilog_defn in ( + self._magma_defn_to_pyverilog_defn.items() + ): + metadata = {} + mod = False + wires = {} + for wire in _get_wires(pyverilog_defn): + assert ( + not wire.signed + and wire.dimensions is None + and wire.value is None + ) + T = Bit + if wire.width is not None: + width = _get_width(wire.width, {}, metadata.setdefault(wire.name, {})) + T = Bits[width] + mod = True + assert wire.name not in wires + wires[wire.name] = T(name=wire.name) + stash = [] + for pyverilog_inst in _get_instances(pyverilog_defn): + instance_module = modules[pyverilog_inst.module] + mod = True + with magma_defn.open(): + magma_inst = instance_module(name=pyverilog_inst.name) + stash += _process_port_connections( + pyverilog_inst, + magma_inst, + magma_defn, + wires, + metadata, + ) + stubify(magma_inst.interface) + for inst_port, driver in stash: + inst_port @= driver.trace() + if mod: + magma_defn.verilogFile = "" + magma_defn.verilog_source = "" + magma_defn._is_definition = True + stubify(magma_defn) diff --git a/magma/stubify.py b/magma/stubify.py index d2a3c3de5..3d227e143 100644 --- a/magma/stubify.py +++ b/magma/stubify.py @@ -113,7 +113,7 @@ def _( # that we *can't* do this in the class itself, since we need to call open() # to tie the outputs first (in stubify()). Afterwards, we can override the # method. - setattr(ckt, "open", classmethod(_stub_open)) + # setattr(ckt, "open", classmethod(_stub_open)) def circuit_stub(