diff --git a/.gitignore b/.gitignore index 24b4f728..ade9905e 100644 --- a/.gitignore +++ b/.gitignore @@ -187,3 +187,4 @@ cobertura.xml docs/build !docs/build/.nojekyll uv.lock +lcov.info diff --git a/Cargo.toml b/Cargo.toml index 63fd698c..f4726cfa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,13 +17,13 @@ crate-type = ["cdylib"] [dependencies] fixedbitset = "0.5" -hashbrown = "0.15" itertools = "0.14" -pyo3 = { version = "0.25", features = ["abi3-py39", "hashbrown"] } +pyo3 = { version = "0.25", features = ["abi3-py39"] } thiserror = "2" tracing = { version = "0.1", features = ["release_max_level_off"] } [dev-dependencies] +maplit = "1" rand = "0.9" rstest = "0.25" rstest_reuse = "0.7" diff --git a/docs/source/swiflow.rst b/docs/source/swiflow.rst index 1bd3f028..1209175d 100644 --- a/docs/source/swiflow.rst +++ b/docs/source/swiflow.rst @@ -6,7 +6,7 @@ swiflow.common module .. automodule:: swiflow.common :members: - :exclude-members: Plane, PPlane, V, P + :exclude-members: Plane, PPlane .. autoclass:: swiflow.common.Plane diff --git a/examples/flow.py b/examples/flow.py index 4f8423bf..20c9738c 100644 --- a/examples/flow.py +++ b/examples/flow.py @@ -14,7 +14,7 @@ # 1 - 3 - 5 # | # 2 - 4 - 6 -g = nx.Graph([(1, 3), (2, 4), (3, 5), (4, 6)]) +g = nx.Graph([(1, 3), (2, 4), (3, 5), (4, 6), (3, 4)]) iset = {1, 2} oset = {5, 6} diff --git a/examples/gflow.py b/examples/gflow.py index 8d37a388..0fe96cb3 100644 --- a/examples/gflow.py +++ b/examples/gflow.py @@ -22,7 +22,7 @@ oset = {4, 5} planes = {0: Plane.XY, 1: Plane.XY, 2: Plane.XZ, 3: Plane.YZ} -result = gflow.find(g, iset, oset, planes) +result = gflow.find(g, iset, oset, planes=planes) # Found assert result is not None diff --git a/examples/pflow.py b/examples/pflow.py index 6dd36758..30bdd3d4 100644 --- a/examples/pflow.py +++ b/examples/pflow.py @@ -20,7 +20,7 @@ oset = {4} pplanes = {0: PPlane.Z, 1: PPlane.Z, 2: PPlane.Y, 3: PPlane.Y} -result = pflow.find(g, iset, oset, pplanes) +result = pflow.find(g, iset, oset, pplanes=pplanes) # Found assert result is not None diff --git a/pyproject.toml b/pyproject.toml index 5f387b17..e7168145 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ python-source = "python" [tool.mypy] python_version = "3.9" strict = true -files = ["docs/source/conf.py", "python", "tests"] +files = ["docs/source/conf.py", "examples", "python", "tests"] [tool.pyright] reportUnknownArgumentType = "information" @@ -123,6 +123,10 @@ required-imports = ["from __future__ import annotations"] [tool.ruff.lint.pydocstyle] convention = "numpy" +[tool.ruff.lint.pylint] +max-positional-args = 4 +max-args = 12 + [tool.ruff.lint.per-file-ignores] "docs/**/*.py" = [ "D1", # undocumented-XXX diff --git a/python/swiflow/__init__.py b/python/swiflow/__init__.py index 126d71c6..8028f3a7 100644 --- a/python/swiflow/__init__.py +++ b/python/swiflow/__init__.py @@ -1 +1 @@ -"""Initialize the swiflow package.""" +"""swiflow: Rust binding of generalized and pauli flow finding algorithms.""" diff --git a/python/swiflow/_common.py b/python/swiflow/_common.py index e3dc2a71..1057719e 100644 --- a/python/swiflow/_common.py +++ b/python/swiflow/_common.py @@ -2,17 +2,19 @@ from __future__ import annotations -from collections.abc import Callable, Iterable, Mapping +from collections.abc import Callable, Hashable, Iterable, Mapping from collections.abc import Set as AbstractSet -from typing import Generic +from typing import Generic, TypeVar import networkx as nx +from typing_extensions import ParamSpec from swiflow._impl import 
FlowValidationMessage -from swiflow.common import P, S, T, V +_V = TypeVar("_V", bound=Hashable) -def check_graph(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> None: + +def check_graph(g: nx.Graph[_V], iset: AbstractSet[_V], oset: AbstractSet[_V]) -> None: """Check if `(g, iset, oset)` is a valid open graph for MBQC. Raises @@ -46,7 +48,7 @@ def check_graph(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> N raise ValueError(msg) -def check_planelike(vset: AbstractSet[V], oset: AbstractSet[V], plike: Mapping[V, P]) -> None: +def check_planelike(vset: AbstractSet[_V], oset: AbstractSet[_V], plike: Mapping[_V, _P]) -> None: r"""Check if measurement description is valid. Parameters @@ -80,13 +82,26 @@ def check_planelike(vset: AbstractSet[V], oset: AbstractSet[V], plike: Mapping[V raise ValueError(msg) -class IndexMap(Generic[V]): +def odd_neighbors(g: nx.Graph[_V], kset: AbstractSet[_V]) -> set[_V]: + """Compute odd neighbors of `kset` in `g`.""" + ret: set[_V] = set() + for k in kset: + ret.symmetric_difference_update(g.neighbors(k)) + return ret + + +_T = TypeVar("_T") +_P = TypeVar("_P") +_S = ParamSpec("_S") + + +class IndexMap(Generic[_V]): """Map between `V` and 0-based indices.""" - __v2i: dict[V, int] - __i2v: list[V] + __v2i: dict[_V, int] + __i2v: list[_V] - def __init__(self, vset: AbstractSet[V]) -> None: + def __init__(self, vset: AbstractSet[_V]) -> None: """Initialize the map from `vset`. Parameters @@ -105,7 +120,7 @@ def __init__(self, vset: AbstractSet[V]) -> None: self.__i2v = list(vset) self.__v2i = {v: i for i, v in enumerate(self.__i2v)} - def encode(self, v: V) -> int: + def encode(self, v: _V) -> int: """Encode `v` to the index. Returns @@ -124,7 +139,7 @@ def encode(self, v: V) -> int: raise ValueError(msg) return ind - def encode_graph(self, g: nx.Graph[V]) -> list[set[int]]: + def encode_graph(self, g: nx.Graph[_V]) -> list[set[int]]: """Encode graph. Returns @@ -133,11 +148,11 @@ def encode_graph(self, g: nx.Graph[V]) -> list[set[int]]: """ return [self.encode_set(g[v].keys()) for v in self.__i2v] - def encode_set(self, vset: AbstractSet[V]) -> set[int]: + def encode_set(self, vset: AbstractSet[_V]) -> set[int]: """Encode set.""" return {self.encode(v) for v in vset} - def encode_dictkey(self, mapping: Mapping[V, P]) -> dict[int, P]: + def encode_dictkey(self, mapping: Mapping[_V, _P]) -> dict[int, _P]: """Encode dict key. Returns @@ -146,7 +161,7 @@ def encode_dictkey(self, mapping: Mapping[V, P]) -> dict[int, P]: """ return {self.encode(k): v for k, v in mapping.items()} - def encode_flow(self, f: Mapping[V, V]) -> dict[int, int]: + def encode_flow(self, f: Mapping[_V, _V]) -> dict[int, int]: """Encode flow. Returns @@ -155,7 +170,7 @@ def encode_flow(self, f: Mapping[V, V]) -> dict[int, int]: """ return {self.encode(i): self.encode(j) for i, j in f.items()} - def encode_gflow(self, f: Mapping[V, AbstractSet[V]]) -> dict[int, set[int]]: + def encode_gflow(self, f: Mapping[_V, AbstractSet[_V]]) -> dict[int, set[int]]: """Encode gflow. Returns @@ -164,24 +179,24 @@ def encode_gflow(self, f: Mapping[V, AbstractSet[V]]) -> dict[int, set[int]]: """ return {self.encode(i): self.encode_set(si) for i, si in f.items()} - def encode_layer(self, layer: Mapping[V, int]) -> list[int]: - """Encode layer. + def encode_layers(self, layers: Mapping[_V, int]) -> list[int]: + """Encode layers. Returns ------- - `layer` values transformed. + `layers` values transformed. 
Notes ----- `list` is used instead of `dict` here because no missing values are allowed here. """ try: - return [layer[v] for v in self.__i2v] + return [layers[v] for v in self.__i2v] except KeyError: msg = "Layers must be specified for all nodes." raise ValueError(msg) from None - def decode(self, i: int) -> V: + def decode(self, i: int) -> _V: """Decode the index. Returns @@ -200,11 +215,11 @@ def decode(self, i: int) -> V: raise ValueError(msg) from None return v - def decode_set(self, iset: AbstractSet[int]) -> set[V]: + def decode_set(self, iset: AbstractSet[int]) -> set[_V]: """Decode set.""" return {self.decode(i) for i in iset} - def decode_flow(self, f_: Mapping[int, int]) -> dict[V, V]: + def decode_flow(self, f_: Mapping[int, int]) -> dict[_V, _V]: """Decode MBQC flow. Returns @@ -213,7 +228,7 @@ def decode_flow(self, f_: Mapping[int, int]) -> dict[V, V]: """ return {self.decode(i): self.decode(j) for i, j in f_.items()} - def decode_gflow(self, f_: Mapping[int, AbstractSet[int]]) -> dict[V, set[V]]: + def decode_gflow(self, f_: Mapping[int, AbstractSet[int]]) -> dict[_V, set[_V]]: """Decode MBQC gflow. Returns @@ -222,18 +237,18 @@ def decode_gflow(self, f_: Mapping[int, AbstractSet[int]]) -> dict[V, set[V]]: """ return {self.decode(i): self.decode_set(si) for i, si in f_.items()} - def decode_layer(self, layer_: Iterable[int]) -> dict[V, int]: - """Decode MBQC layer. + def decode_layers(self, layers_: Iterable[int]) -> dict[_V, int]: + """Decode MBQC layers. Returns ------- - `layer_` transformed. + `layers_` transformed. Notes ----- `list` (generalized as `Iterable`) is used instead of `dict` here because no missing values are allowed here. """ - return {self.decode(i): li for i, li in enumerate(layer_)} + return {self.decode(i): li for i, li in enumerate(layers_)} def decode_err(self, err: ValueError) -> ValueError: """Decode the error message stored in the first ctor argument of ValueError.""" @@ -268,7 +283,7 @@ def decode_err(self, err: ValueError) -> ValueError: raise TypeError # pragma: no cover return ValueError(msg) - def ecatch(self, f: Callable[S, T], *args: S.args, **kwargs: S.kwargs) -> T: + def ecatch(self, f: Callable[_S, _T], *args: _S.args, **kwargs: _S.kwargs) -> _T: """Wrap binding call to decode raw error messages.""" try: return f(*args, **kwargs) diff --git a/python/swiflow/_impl/flow.pyi b/python/swiflow/_impl/flow.pyi index 1fe37a58..211ae76d 100644 --- a/python/swiflow/_impl/flow.pyi +++ b/python/swiflow/_impl/flow.pyi @@ -1,2 +1,7 @@ def find(g: list[set[int]], iset: set[int], oset: set[int]) -> tuple[dict[int, int], list[int]] | None: ... -def verify(flow: tuple[dict[int, int], list[int]], g: list[set[int]], iset: set[int], oset: set[int]) -> None: ... +def verify( + flow: tuple[dict[int, int], list[int]], + g: list[set[int]], + iset: set[int], + oset: set[int], +) -> None: ... diff --git a/python/swiflow/_impl/gflow.pyi b/python/swiflow/_impl/gflow.pyi index 2770059c..bd5d88cf 100644 --- a/python/swiflow/_impl/gflow.pyi +++ b/python/swiflow/_impl/gflow.pyi @@ -4,12 +4,12 @@ class Plane: XZ: Plane def find( - g: list[set[int]], iset: set[int], oset: set[int], plane: dict[int, Plane] + g: list[set[int]], iset: set[int], oset: set[int], planes: dict[int, Plane] ) -> tuple[dict[int, set[int]], list[int]] | None: ... def verify( gflow: tuple[dict[int, set[int]], list[int]], g: list[set[int]], iset: set[int], oset: set[int], - plane: dict[int, Plane], + planes: dict[int, Plane], ) -> None: ... 
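The `*.pyi` stubs above operate on graphs already lowered to 0-based integer indices (`list[set[int]]`); the lowering is done by `IndexMap` in the `_common.py` hunk before the Rust binding is called. Below is a minimal standalone sketch of that same conversion, mirroring `IndexMap.__init__`/`encode_graph`; the variable names are illustrative only and not part of the package API.

import networkx as nx

# Labelled open graph on the Python side.
g = nx.Graph([("a", "b"), ("b", "c")])

# index -> label and label -> index maps, as IndexMap builds them.
i2v = list(g.nodes)
v2i = {v: i for i, v in enumerate(i2v)}

# Adjacency lowered to the integer-indexed form expected by the stubs.
g_ = [{v2i[u] for u in g[v]} for v in i2v]
assert g_ == [{1}, {0, 2}, {1}]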
diff --git a/python/swiflow/_impl/pflow.pyi b/python/swiflow/_impl/pflow.pyi index 8765895b..0f6b6c30 100644 --- a/python/swiflow/_impl/pflow.pyi +++ b/python/swiflow/_impl/pflow.pyi @@ -7,12 +7,12 @@ class PPlane: Z: PPlane def find( - g: list[set[int]], iset: set[int], oset: set[int], pplane: dict[int, PPlane] + g: list[set[int]], iset: set[int], oset: set[int], pplanes: dict[int, PPlane] ) -> tuple[dict[int, set[int]], list[int]] | None: ... def verify( pflow: tuple[dict[int, set[int]], list[int]], g: list[set[int]], iset: set[int], oset: set[int], - pplane: dict[int, PPlane], + pplanes: dict[int, PPlane], ) -> None: ... diff --git a/python/swiflow/common.py b/python/swiflow/common.py index 311edf21..595c0acf 100644 --- a/python/swiflow/common.py +++ b/python/swiflow/common.py @@ -2,40 +2,129 @@ from __future__ import annotations -import dataclasses -from collections.abc import Hashable -from typing import Generic, TypeVar - -from typing_extensions import ParamSpec +import itertools +from collections.abc import Hashable, Mapping, MutableSet +from collections.abc import Set as AbstractSet +from typing import TYPE_CHECKING, TypeVar +from swiflow import _common from swiflow._impl import gflow, pflow +if TYPE_CHECKING: + import networkx as nx + Plane = gflow.Plane PPlane = pflow.PPlane -T = TypeVar("T") -V = TypeVar("V", bound=Hashable) -P = TypeVar("P", Plane, PPlane) -S = ParamSpec("S") +_V = TypeVar("_V", bound=Hashable) + +Flow = dict[_V, _V] +"""Flow map as a dictionary. :math:`f(u)` is stored in :py:obj:`f[u]`.""" + +GFlow = dict[_V, set[_V]] +"""Generalized flow map as a dictionary. :math:`f(u)` is stored in :py:obj:`f[u]`.""" + +PFlow = dict[_V, set[_V]] +"""Pauli flow map as a dictionary. :math:`f(u)` is stored in :py:obj:`f[u]`.""" + +Layers = dict[_V, int] +r"""Layer of each node representing the partial order. :math:`layers(u) > layers(v)` implies :math:`u \prec v`. +""" -@dataclasses.dataclass(frozen=True) -class FlowResult(Generic[V]): - r"""Causal flow of an open graph.""" +def _infer_layers_impl(pred: Mapping[_V, MutableSet[_V]], succ: Mapping[_V, AbstractSet[_V]]) -> Mapping[_V, int]: + """Fix flow layers one by one depending on order constraints. - f: dict[V, V] - """Flow map as a dictionary. :math:`f(u)` is stored in :py:obj:`f[u]`.""" - layer: dict[V, int] - r"""Layer of each node representing the partial order. :math:`layer(u) > layer(v)` implies :math:`u \prec v`. + Notes + ----- + :py:obj:`pred` is mutated in-place. """ + work = {u for u, pu in pred.items() if not pu} + ret: dict[_V, int] = {} + for l_now in itertools.count(): + if not work: + break + next_work: set[_V] = set() + for u in work: + ret[u] = l_now + for v in succ[u]: + ent = pred[v] + ent.discard(u) + if not ent: + next_work.add(v) + work = next_work + if len(ret) != len(succ): + msg = "Cannot satisfy all the partial order constraints." 
+ raise ValueError(msg) + return ret + + +def _is_special( + pp: PPlane | None, + in_fu: bool, # noqa: FBT001 + in_fu_odd: bool, # noqa: FBT001 +) -> bool: + if pp == PPlane.X: + return in_fu + if pp == PPlane.Y: + return in_fu and in_fu_odd + if pp == PPlane.Z: + return in_fu_odd + return False + + +def _special_edges( + g: nx.Graph[_V], + anyflow: Mapping[_V, _V | AbstractSet[_V]], + pplanes: Mapping[_V, PPlane] | None, +) -> set[tuple[_V, _V]]: + """Compute special edges that can bypass partial order constraints in Pauli flow.""" + # MEMO: Unify with Rust implementation + ret: set[tuple[_V, _V]] = set() + if pplanes is None: + return ret + for u, fu_ in anyflow.items(): + fu = fu_ if isinstance(fu_, AbstractSet) else {fu_} + fu_odd = _common.odd_neighbors(g, fu) + for v in itertools.chain(fu, fu_odd): + if u == v: + continue + if _is_special(pplanes.get(v), v in fu, v in fu_odd): + ret.add((u, v)) + return ret + +def infer_layers( + g: nx.Graph[_V], + anyflow: Mapping[_V, _V | AbstractSet[_V]], + pplanes: Mapping[_V, PPlane] | None = None, +) -> Mapping[_V, int]: + """Infer layers from flow/gflow using greedy algorithm. -@dataclasses.dataclass(frozen=True) -class GFlowResult(Generic[V]): - r"""Generalized flow of an open graph.""" + Parameters + ---------- + g : `networkx.Graph` + Simple graph representing MBQC pattern. + anyflow : `tuple` of flow-like/layer + Flow to verify. Compatible with both flow and generalized flow. + pplanes : `collections.abc.Mapping`, optional + Measurement plane or Pauli index. - f: dict[V, set[V]] - """Generalized flow map as a dictionary. :math:`f(u)` is stored in :py:obj:`f[u]`.""" - layer: dict[V, int] - r"""Layer of each node representing the partial order. :math:`layer(u) > layer(v)` implies :math:`u \prec v`. + Notes + ----- + This function operates in Pauli flow mode only when :py:obj`pplanes` is explicitly given. """ + special = _special_edges(g, anyflow, pplanes) + pred: dict[_V, set[_V]] = {u: set() for u in g.nodes} + succ: dict[_V, set[_V]] = {u: set() for u in g.nodes} + for u, fu_ in anyflow.items(): + fu = fu_ if isinstance(fu_, AbstractSet) else {fu_} + fu_odd = _common.odd_neighbors(g, fu) + for v in itertools.chain(fu, fu_odd): + if u == v or (u, v) in special: + continue + # Reversed + pred[u].add(v) + succ[v].add(u) + # MEMO: `pred` is invalidated by `_infer_layers_impl` + return _infer_layers_impl(pred, succ) diff --git a/python/swiflow/flow.py b/python/swiflow/flow.py index 4fb7189c..16058235 100644 --- a/python/swiflow/flow.py +++ b/python/swiflow/flow.py @@ -6,20 +6,24 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Hashable, Mapping +from typing import TYPE_CHECKING, TypeVar -from swiflow import _common +from swiflow import _common, common from swiflow._common import IndexMap from swiflow._impl import flow as flow_bind -from swiflow.common import FlowResult, V +from swiflow.common import Flow, Layers if TYPE_CHECKING: from collections.abc import Set as AbstractSet import networkx as nx +_V = TypeVar("_V", bound=Hashable) +FlowResult = tuple[Flow[_V], Layers[_V]] -def find(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> FlowResult[V] | None: + +def find(g: nx.Graph[_V], iset: AbstractSet[_V], oset: AbstractSet[_V]) -> FlowResult[_V] | None: """Compute causal flow. If it returns a flow, it is guaranteed to be maximally-delayed, i.e., the number of layers is minimized. 
@@ -35,7 +39,7 @@ def find(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> FlowResu Returns ------- - `FlowResult` or `None` + `tuple` of flow/layers or `None` Return the flow if any, otherwise `None`. """ _common.check_graph(g, iset, oset) @@ -45,19 +49,28 @@ def find(g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> FlowResu iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) if ret_ := flow_bind.find(g_, iset_, oset_): - f_, layer_ = ret_ + f_, layers_ = ret_ f = codec.decode_flow(f_) - layer = codec.decode_layer(layer_) - return FlowResult(f, layer) + layers = codec.decode_layers(layers_) + return f, layers return None -def verify(flow: FlowResult[V], g: nx.Graph[V], iset: AbstractSet[V], oset: AbstractSet[V]) -> None: - """Verify maximally-delayed causal flow. +_Flow = Mapping[_V, _V] +_Layer = Mapping[_V, int] + + +def verify( + flow: tuple[_Flow[_V], _Layer[_V]] | _Flow[_V], + g: nx.Graph[_V], + iset: AbstractSet[_V], + oset: AbstractSet[_V], +) -> None: + """Verify causal flow. Parameters ---------- - flow : `FlowResult` + flow : flow (required) and layers (optional) Flow to verify. g : `networkx.Graph` Simple graph representing MBQC pattern. @@ -77,6 +90,11 @@ def verify(flow: FlowResult[V], g: nx.Graph[V], iset: AbstractSet[V], oset: Abst g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) - f_ = codec.encode_flow(flow.f) - layer_ = codec.encode_layer(flow.layer) - codec.ecatch(flow_bind.verify, (f_, layer_), g_, iset_, oset_) + if isinstance(flow, tuple): + f, layers = flow + common.infer_layers(g, f) + else: + f = flow + layers = common.infer_layers(g, f) + f_ = (codec.encode_flow(f), codec.encode_layers(layers)) + codec.ecatch(flow_bind.verify, f_, g_, iset_, oset_) diff --git a/python/swiflow/gflow.py b/python/swiflow/gflow.py index 3f36f3ca..7b7ec418 100644 --- a/python/swiflow/gflow.py +++ b/python/swiflow/gflow.py @@ -6,26 +6,29 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from collections.abc import Hashable, Mapping +from collections.abc import Set as AbstractSet +from typing import TYPE_CHECKING, TypeVar -from swiflow import _common +from swiflow import _common, common from swiflow._common import IndexMap from swiflow._impl import gflow as gflow_bind -from swiflow.common import GFlowResult, Plane, V +from swiflow.common import GFlow, Layers, Plane if TYPE_CHECKING: - from collections.abc import Mapping - from collections.abc import Set as AbstractSet - import networkx as nx +_V = TypeVar("_V", bound=Hashable) +GFlowResult = tuple[GFlow[_V], Layers[_V]] + def find( - g: nx.Graph[V], - iset: AbstractSet[V], - oset: AbstractSet[V], - plane: Mapping[V, Plane] | None = None, -) -> GFlowResult[V] | None: + g: nx.Graph[_V], + iset: AbstractSet[_V], + oset: AbstractSet[_V], + *, + planes: Mapping[_V, Plane] | None = None, +) -> GFlowResult[_V] | None: r"""Compute generalized flow. If it returns a gflow, it is guaranteed to be maximally-delayed, i.e., the number of layers is minimized. @@ -38,45 +41,50 @@ def find( Input nodes. oset : `collections.abc.Set` Output nodes. - plane : `collections.abc.Mapping` + planes : `collections.abc.Mapping` Measurement plane for each node in :math:`V \setminus O`. Defaults to `Plane.XY`. Returns ------- - `GFlowResult` or `None` + `tuple` of gflow/layers or `None` Return the gflow if any, otherwise `None`. 
""" _common.check_graph(g, iset, oset) vset = g.nodes - if plane is None: - plane = dict.fromkeys(vset - oset, Plane.XY) - _common.check_planelike(vset, oset, plane) + if planes is None: + planes = dict.fromkeys(vset - oset, Plane.XY) + _common.check_planelike(vset, oset, planes) codec = IndexMap(vset) g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) - plane_ = codec.encode_dictkey(plane) - if ret_ := gflow_bind.find(g_, iset_, oset_, plane_): - f_, layer_ = ret_ + planes_ = codec.encode_dictkey(planes) + if ret_ := gflow_bind.find(g_, iset_, oset_, planes_): + f_, layers_ = ret_ f = codec.decode_gflow(f_) - layer = codec.decode_layer(layer_) - return GFlowResult(f, layer) + layers = codec.decode_layers(layers_) + return f, layers return None +_GFlow = Mapping[_V, AbstractSet[_V]] +_Layer = Mapping[_V, int] + + def verify( - gflow: GFlowResult[V], - g: nx.Graph[V], - iset: AbstractSet[V], - oset: AbstractSet[V], - plane: Mapping[V, Plane] | None = None, + gflow: tuple[_GFlow[_V], _Layer[_V]] | _GFlow[_V], + g: nx.Graph[_V], + iset: AbstractSet[_V], + oset: AbstractSet[_V], + *, + planes: Mapping[_V, Plane] | None = None, ) -> None: - r"""Verify maximally-delayed generalized flow. + r"""Verify generalized flow. Parameters ---------- - gflow : `GFlowResult` + gflow : gflow (required) and layers (optional) Generalized flow to verify. g : `networkx.Graph` Simple graph representing MBQC pattern. @@ -84,7 +92,7 @@ def verify( Input nodes. oset : `collections.abc.Set` Output nodes. - plane : `collections.abc.Mapping` + planes : `collections.abc.Mapping` Measurement plane for each node in :math:`V \setminus O`. Defaults to `Plane.XY`. @@ -95,13 +103,18 @@ def verify( """ _common.check_graph(g, iset, oset) vset = g.nodes - if plane is None: - plane = dict.fromkeys(vset - oset, Plane.XY) + if planes is None: + planes = dict.fromkeys(vset - oset, Plane.XY) codec = IndexMap(vset) g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) - plane_ = codec.encode_dictkey(plane) - f_ = codec.encode_gflow(gflow.f) - layer_ = codec.encode_layer(gflow.layer) - codec.ecatch(gflow_bind.verify, (f_, layer_), g_, iset_, oset_, plane_) + planes_ = codec.encode_dictkey(planes) + if isinstance(gflow, tuple): + f, layers = gflow + common.infer_layers(g, f) + else: + f = gflow + layers = common.infer_layers(g, f) + f_ = (codec.encode_gflow(f), codec.encode_layers(layers)) + codec.ecatch(gflow_bind.verify, f_, g_, iset_, oset_, planes_) diff --git a/python/swiflow/pflow.py b/python/swiflow/pflow.py index f6177c40..24e3976b 100644 --- a/python/swiflow/pflow.py +++ b/python/swiflow/pflow.py @@ -7,26 +7,29 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from collections.abc import Hashable, Mapping +from collections.abc import Set as AbstractSet +from typing import TYPE_CHECKING, TypeVar -from swiflow import _common +from swiflow import _common, common from swiflow._common import IndexMap from swiflow._impl import pflow as pflow_bind -from swiflow.common import GFlowResult, PPlane, V +from swiflow.common import Layers, PFlow, PPlane if TYPE_CHECKING: - from collections.abc import Mapping - from collections.abc import Set as AbstractSet - import networkx as nx +_V = TypeVar("_V", bound=Hashable) +PFlowResult = tuple[PFlow[_V], Layers[_V]] + def find( - g: nx.Graph[V], - iset: AbstractSet[V], - oset: AbstractSet[V], - pplane: Mapping[V, PPlane] | None = None, -) -> GFlowResult[V] | None: + g: nx.Graph[_V], + 
iset: AbstractSet[_V], + oset: AbstractSet[_V], + *, + pplanes: Mapping[_V, PPlane] | None = None, +) -> PFlowResult[_V] | None: r"""Compute Pauli flow. If it returns a Pauli flow, it is guaranteed to be maximally-delayed, i.e., the number of layers is minimized. @@ -39,13 +42,13 @@ def find( Input nodes. oset : `collections.abc.Set` Output nodes. - pplane : `collections.abc.Mapping` + pplanes : `collections.abc.Mapping` Measurement plane or Pauli index for each node in :math:`V \setminus O`. Defaults to `PPlane.XY`. Returns ------- - `GFlowResult` or `None` + `tuple` of Pauli flow/layers or `None` Return the Pauli flow if any, otherwise `None`. Notes @@ -54,37 +57,42 @@ def find( """ _common.check_graph(g, iset, oset) vset = g.nodes - if pplane is None: - pplane = dict.fromkeys(vset - oset, PPlane.XY) - _common.check_planelike(vset, oset, pplane) - if all(pp not in {PPlane.X, PPlane.Y, PPlane.Z} for pp in pplane.values()): + if pplanes is None: + pplanes = dict.fromkeys(vset - oset, PPlane.XY) + _common.check_planelike(vset, oset, pplanes) + if all(pp not in {PPlane.X, PPlane.Y, PPlane.Z} for pp in pplanes.values()): msg = "No Pauli measurement found. Use gflow.find instead." warnings.warn(msg, stacklevel=1) codec = IndexMap(vset) g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) - pplane_ = codec.encode_dictkey(pplane) - if ret_ := pflow_bind.find(g_, iset_, oset_, pplane_): - f_, layer_ = ret_ + pplanes_ = codec.encode_dictkey(pplanes) + if ret_ := pflow_bind.find(g_, iset_, oset_, pplanes_): + f_, layers_ = ret_ f = codec.decode_gflow(f_) - layer = codec.decode_layer(layer_) - return GFlowResult(f, layer) + layers = codec.decode_layers(layers_) + return f, layers return None +_PFlow = Mapping[_V, AbstractSet[_V]] +_Layer = Mapping[_V, int] + + def verify( - pflow: GFlowResult[V], - g: nx.Graph[V], - iset: AbstractSet[V], - oset: AbstractSet[V], - pplane: Mapping[V, PPlane] | None = None, + pflow: tuple[_PFlow[_V], _Layer[_V]] | _PFlow[_V], + g: nx.Graph[_V], + iset: AbstractSet[_V], + oset: AbstractSet[_V], + *, + pplanes: Mapping[_V, PPlane] | None = None, ) -> None: - r"""Verify maximally-delayed Pauli flow. + r"""Verify Pauli flow. Parameters ---------- - pflow : `GFlowResult` + pflow : Pauli flow (required) and layers (optional) Pauli flow to verify. g : `networkx.Graph` Simple graph representing MBQC pattern. @@ -92,7 +100,7 @@ def verify( Input nodes. oset : `collections.abc.Set` Output nodes. - pplane : `collections.abc.Mapping` + pplanes : `collections.abc.Mapping` Measurement plane or Pauli index for each node in :math:`V \setminus O`. Defaults to `PPlane.XY`. 
@@ -103,13 +111,18 @@ def verify( """ _common.check_graph(g, iset, oset) vset = g.nodes - if pplane is None: - pplane = dict.fromkeys(vset - oset, PPlane.XY) + if pplanes is None: + pplanes = dict.fromkeys(vset - oset, PPlane.XY) codec = IndexMap(vset) g_ = codec.encode_graph(g) iset_ = codec.encode_set(iset) oset_ = codec.encode_set(oset) - pplane_ = codec.encode_dictkey(pplane) - f_ = codec.encode_gflow(pflow.f) - layer_ = codec.encode_layer(pflow.layer) - codec.ecatch(pflow_bind.verify, (f_, layer_), g_, iset_, oset_, pplane_) + pplanes_ = codec.encode_dictkey(pplanes) + if isinstance(pflow, tuple): + f, layers = pflow + common.infer_layers(g, f, pplanes) + else: + f = pflow + layers = common.infer_layers(g, f, pplanes) + f_ = (codec.encode_gflow(f), codec.encode_layers(layers)) + codec.ecatch(pflow_bind.verify, f_, g_, iset_, oset_, pplanes_) diff --git a/src/common.rs b/src/common.rs index 43c84c21..9e329f9e 100644 --- a/src/common.rs +++ b/src/common.rs @@ -1,25 +1,35 @@ //! Common functionalities. -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashSet}; use pyo3::{exceptions::PyValueError, prelude::*}; use thiserror::Error; -use crate::{gflow::Plane, pflow::PPlane}; +use crate::{ + common::FlowValidationError::{ + ExcessiveNonZeroLayer, ExcessiveZeroLayer, InvalidFlowCodomain, InvalidFlowDomain, + }, + gflow::Plane, + pflow::PPlane, +}; -/// Set of nodes indexed by 0-based integers. -pub type Nodes = hashbrown::HashSet; +/// Node index. +pub type Node = usize; +/// Layer index. +pub type Layer = usize; +/// Set of nodes. +pub type Nodes = HashSet; /// Simple graph encoded as list of neighbors. pub type Graph = Vec; /// Layer representation of the flow partial order. -pub type Layer = Vec; +pub type Layers = Vec; /// Ordered set of nodes. /// /// # Note /// /// Used only when iteration order matters. -pub(crate) type OrderedNodes = BTreeSet; +pub(crate) type OrderedNodes = BTreeSet; /// Error type for flow validation. /// @@ -29,21 +39,21 @@ pub(crate) type OrderedNodes = BTreeSet; pub enum FlowValidationError { // Keep in sync with Python-side error messages #[error("layer-{layer} node {node} inside output nodes")] - ExcessiveNonZeroLayer { node: usize, layer: usize }, + ExcessiveNonZeroLayer { node: Node, layer: Layer }, #[error("zero-layer node {node} outside output nodes")] - ExcessiveZeroLayer { node: usize }, + ExcessiveZeroLayer { node: Node }, #[error("f({node}) has invalid codomain")] - InvalidFlowCodomain { node: usize }, + InvalidFlowCodomain { node: Node }, #[error("f({node}) has invalid domain")] - InvalidFlowDomain { node: usize }, + InvalidFlowDomain { node: Node }, #[error("node {node} has invalid measurement specification")] - InvalidMeasurementSpec { node: usize }, + InvalidMeasurementSpec { node: Node }, #[error("flow-order inconsistency on nodes ({}, {})",.nodes.0, .nodes.1)] - InconsistentFlowOrder { nodes: (usize, usize) }, + InconsistentFlowOrder { nodes: (Node, Node) }, #[error("broken {plane:?} measurement on node {node}")] - InconsistentFlowPlane { node: usize, plane: Plane }, + InconsistentFlowPlane { node: Node, plane: Plane }, #[error("broken {pplane:?} measurement on node {node}")] - InconsistentFlowPPlane { node: usize, pplane: PPlane }, + InconsistentFlowPPlane { node: Node, pplane: PPlane }, } impl From for PyErr { @@ -54,7 +64,7 @@ impl From for PyErr { } // TODO: Remove once stabilized -pub const FATAL_MSG: &str = "\ +pub(crate) const FATAL_MSG: &str = "\ !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ! 
POST VERIFICATION FAILED ! !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -62,12 +72,146 @@ pub const FATAL_MSG: &str = "\ Please report to the developers via GitHub: https://github.com/TeamGraphix/swiflow/issues/new"; +/// Checks if the layer-zero nodes are correctly chosen. +/// +/// This check can be skipped unless maximally-delayed flow is required. +/// +/// # Arguments +/// +/// - `layers`: The layer. +/// - `oset`: The set of output nodes. +/// - `iff`: If `true`, `layers[u] == 0` "iff" `u` is in `oset`. Otherwise "if". +pub(crate) fn check_initial( + layers: &[Layer], + oset: &Nodes, + iff: bool, +) -> Result<(), FlowValidationError> { + for (u, &lu) in layers.iter().enumerate() { + match (oset.contains(&u), lu == 0) { + (true, false) => { + Err(ExcessiveNonZeroLayer { node: u, layer: lu })?; + } + (false, true) if iff => { + Err(ExcessiveZeroLayer { node: u })?; + } + _ => {} + } + } + Ok(()) +} + +/// Checks if the domain of `f` is in `vset - oset` and the codomain is in `vset - iset`. +/// +/// # Arguments +/// +/// - `f_flatiter`: Flow, gflow, or pflow as `impl Iterator`. +/// - `vset`: All nodes. +/// - `iset`: Input nodes. +/// - `oset`: Output nodes. +pub(crate) fn check_domain<'a, 'b>( + f_flatiter: impl Iterator, + vset: &Nodes, + iset: &Nodes, + oset: &Nodes, +) -> Result<(), FlowValidationError> { + let icset = vset - iset; + let ocset = vset - oset; + let mut dom = Nodes::new(); + for (&i, &fi) in f_flatiter { + dom.insert(i); + if !icset.contains(&fi) { + Err(InvalidFlowCodomain { node: i })?; + } + } + if let Some(&i) = dom.symmetric_difference(&ocset).next() { + Err(InvalidFlowDomain { node: i })?; + } + Ok(()) +} + #[cfg(test)] mod tests { + use core::iter; + use std::collections::HashMap; + use super::*; + use crate::common::Nodes; #[test] fn test_err_from() { let _ = PyErr::from(FlowValidationError::ExcessiveNonZeroLayer { node: 1, layer: 2 }); } + + #[test] + fn test_check_initial() { + let layers = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1]); + check_initial(&layers, &oset, false).unwrap(); + } + + #[test] + fn test_check_initial_ng() { + let layers = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1, 2, 3]); + assert!(check_initial(&layers, &oset, false).is_err()); + } + + #[test] + fn test_check_initial_iff() { + let layers = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1, 2]); + check_initial(&layers, &oset, true).unwrap(); + } + + #[test] + fn test_check_initial_iff_ng() { + let layers = vec![0, 0, 0, 1, 1, 1]; + let oset = Nodes::from([0, 1]); + assert!(check_initial(&layers, &oset, true).is_err()); + } + + #[test] + fn test_check_domain_flow() { + let f = HashMap::::from([(0, 1), (1, 2)]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + check_domain(f.iter(), &vset, &iset, &oset).unwrap(); + } + + #[test] + fn test_check_domain_gflow() { + let f = HashMap::::from([(0, Nodes::from([1, 2])), (1, Nodes::from([2]))]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); + check_domain(f_flatiter, &vset, &iset, &oset).unwrap(); + } + + #[test] + fn test_check_domain_ng_iset() { + let f = HashMap::::from([(0, Nodes::from([0, 1])), (2, Nodes::from([2]))]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), 
fi.iter())); + assert!(check_domain(f_flatiter, &vset, &iset, &oset).is_err()); + } + + #[test] + fn test_check_domain_ng_oset() { + let f = HashMap::::from([(0, Nodes::from([1])), (1, Nodes::from([0]))]); + let vset = Nodes::from([0, 1, 2]); + let iset = Nodes::from([0]); + let oset = Nodes::from([2]); + let f_flatiter = f + .iter() + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); + assert!(check_domain(f_flatiter, &vset, &iset, &oset).is_err()); + } } diff --git a/src/flow.rs b/src/flow.rs index 2d2a6e34..f29f19ff 100644 --- a/src/flow.rs +++ b/src/flow.rs @@ -1,37 +1,46 @@ //! Maximally-delayed causal flow algorithm. -use hashbrown; +use std::collections::HashMap; + use pyo3::prelude::*; use crate::{ common::{ - FATAL_MSG, + self, FATAL_MSG, FlowValidationError::{self, InconsistentFlowOrder}, - Graph, Layer, Nodes, + Graph, Layer, Layers, Node, Nodes, }, - internal::{utils::InPlaceSetDiff, validate}, + internal::utils::InPlaceSetDiff, }; -type Flow = hashbrown::HashMap; +type Flow = HashMap; -/// Checks the definition of causal flow. +/// Checks the geometric constraints of flow. /// -/// 1. i -> f(i) -/// 2. j in neighbors(f(i)) => i == j or i -> j -/// 3. i in neighbors(f(i)) -fn check_definition(f: &Flow, layer: &Layer, g: &Graph) -> Result<(), FlowValidationError> { +/// - i in N(f(i)) +fn check_def_geom(f: &Flow, g: &[Nodes]) -> Result<(), FlowValidationError> { for (&i, &fi) in f { - if layer[i] <= layer[fi] { + if !g[i].contains(&fi) { + Err(InconsistentFlowOrder { nodes: (i, fi) })?; + } + } + Ok(()) +} + +/// Checks the layer constraints of flow. +/// +/// - i -> f(i) +/// - j in N(f(i)) => i == j or i -> j +fn check_def_layer(f: &Flow, layers: &[Layer], g: &[Nodes]) -> Result<(), FlowValidationError> { + for (&i, &fi) in f { + if layers[i] <= layers[fi] { Err(InconsistentFlowOrder { nodes: (i, fi) })?; } for &j in &g[fi] { - if i != j && layer[i] <= layer[j] { + if i != j && layers[i] <= layers[j] { Err(InconsistentFlowOrder { nodes: (i, j) })?; } } - if !(g[fi].contains(&i) && g[i].contains(&fi)) { - Err(InconsistentFlowOrder { nodes: (i, fi) })?; - } } Ok(()) } @@ -56,7 +65,7 @@ fn check_definition(f: &Flow, layer: &Layer, g: &Graph) -> Result<(), FlowValida #[tracing::instrument] #[expect(clippy::needless_pass_by_value)] #[inline] -pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layer)> { +pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layers)> { let n = g.len(); let vset = (0..n).collect::(); let mut cset = &oset - &iset; @@ -64,7 +73,7 @@ pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layer)> { let ocset = &vset - &oset; let oset_orig = oset.clone(); let mut f = Flow::with_capacity(ocset.len()); - let mut layer = vec![0_usize; n]; + let mut layers = vec![0_usize; n]; // check[v] = g[v] & (vset - oset) let mut check = g.iter().map(|x| x & &ocset).collect::>(); let mut oset_work = Nodes::new(); @@ -82,7 +91,7 @@ pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layer)> { tracing::debug!("f({u}) = {v}"); f.insert(u, v); tracing::debug!("layer({u}) = {l}"); - layer[u] = l; + layers[u] = l; oset_work.insert(u); cset_work.insert(v); } @@ -102,14 +111,15 @@ pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layer)> { if oset == vset { tracing::debug!("flow found"); tracing::debug!("flow : {f:?}"); - tracing::debug!("layer: {layer:?}"); + tracing::debug!("layers: {layers:?}"); // TODO: Remove this block once stabilized { - validate::check_domain(f.iter(), &vset, 
&iset, &oset_orig).expect(FATAL_MSG); - validate::check_initial(&layer, &oset_orig, true).expect(FATAL_MSG); - check_definition(&f, &layer, &g).expect(FATAL_MSG); + common::check_domain(f.iter(), &vset, &iset, &oset_orig).expect(FATAL_MSG); + common::check_initial(&layers, &oset_orig, true).expect(FATAL_MSG); + check_def_geom(&f, &g).expect(FATAL_MSG); + check_def_layer(&f, &layers, &g).expect(FATAL_MSG); } - Some((f, layer)) + Some((f, layers)) } else { tracing::debug!("flow not found"); None @@ -125,18 +135,19 @@ pub fn find(g: Graph, iset: Nodes, mut oset: Nodes) -> Option<(Flow, Layer)> { #[pyfunction] #[expect(clippy::needless_pass_by_value)] #[inline] -pub fn verify(flow: (Flow, Layer), g: Graph, iset: Nodes, oset: Nodes) -> PyResult<()> { - let (f, layer) = flow; +pub fn verify(flow: (Flow, Layers), g: Graph, iset: Nodes, oset: Nodes) -> PyResult<()> { + let (f, layers) = flow; let n = g.len(); let vset = (0..n).collect::(); - validate::check_domain(f.iter(), &vset, &iset, &oset)?; - validate::check_initial(&layer, &oset, true)?; - check_definition(&f, &layer, &g)?; + common::check_domain(f.iter(), &vset, &iset, &oset)?; + check_def_geom(&f, &g)?; + check_def_layer(&f, &layers, &g)?; Ok(()) } #[cfg(test)] mod tests { + use maplit::hashmap; use test_log; use super::*; @@ -146,25 +157,21 @@ mod tests { fn test_check_definition_ng() { // Violate 0 -> f(0) = 1 assert_eq!( - check_definition(&map! { 0: 1 }, &vec![0, 0], &test_utils::graph(&[(0, 1)])), + check_def_layer(&hashmap! { 0 => 1 }, &[0, 0], &test_utils::graph(&[(0, 1)])), Err(InconsistentFlowOrder { nodes: (0, 1) }) ); // Violate 1 in nb(f(0)) = nb(2) => 0 == 1 or 0 -> 1 assert_eq!( - check_definition( - &map! { 0: 2 }, - &vec![1, 1, 0], + check_def_layer( + &hashmap! { 0 => 2 }, + &[1, 1, 0], &test_utils::graph(&[(0, 1), (1, 2)]) ), Err(InconsistentFlowOrder { nodes: (0, 1) }) ); // Violate 0 in nb(f(0)) = nb(2) assert_eq!( - check_definition( - &map! { 0: 2 }, - &vec![2, 1, 0], - &test_utils::graph(&[(0, 1), (1, 2)]) - ), + check_def_geom(&hashmap! 
{ 0 => 2 }, &test_utils::graph(&[(0, 1), (1, 2)])), Err(InconsistentFlowOrder { nodes: (0, 2) }) ); } @@ -173,38 +180,38 @@ mod tests { fn test_find_case0() { let TestCase { g, iset, oset } = test_utils::CASE0.clone(); let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); assert_eq!(f.len(), flen); - assert_eq!(layer, vec![0, 0]); - verify((f, layer), g, iset, oset).unwrap(); + assert_eq!(layers, vec![0, 0]); + verify((f, layers), g, iset, oset).unwrap(); } #[test_log::test] fn test_find_case1() { let TestCase { g, iset, oset } = test_utils::CASE1.clone(); let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], 1); assert_eq!(f[&1], 2); assert_eq!(f[&2], 3); assert_eq!(f[&3], 4); - assert_eq!(layer, vec![4, 3, 2, 1, 0]); - verify((f, layer), g, iset, oset).unwrap(); + assert_eq!(layers, vec![4, 3, 2, 1, 0]); + verify((f, layers), g, iset, oset).unwrap(); } #[test_log::test] fn test_find_case2() { let TestCase { g, iset, oset } = test_utils::CASE2.clone(); let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], 2); assert_eq!(f[&1], 3); assert_eq!(f[&2], 4); assert_eq!(f[&3], 5); - assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); - verify((f, layer), g, iset, oset).unwrap(); + assert_eq!(layers, vec![2, 2, 1, 1, 0, 0]); + verify((f, layers), g, iset, oset).unwrap(); } #[test_log::test] diff --git a/src/gflow.rs b/src/gflow.rs index 9df46c54..aa05a101 100644 --- a/src/gflow.rs +++ b/src/gflow.rs @@ -1,23 +1,22 @@ //! Maximally-delayed generalized flow algorithm. use core::iter; +use std::collections::HashMap; use fixedbitset::FixedBitSet; -use hashbrown; use pyo3::prelude::*; use crate::{ common::{ - FATAL_MSG, + self, FATAL_MSG, FlowValidationError::{ self, InconsistentFlowOrder, InconsistentFlowPlane, InvalidMeasurementSpec, }, - Graph, Layer, Nodes, OrderedNodes, + Graph, Layer, Layers, Node, Nodes, OrderedNodes, }, internal::{ gf2_linalg::GF2Solver, utils::{self, InPlaceSetDiff}, - validate, }, }; @@ -30,40 +29,23 @@ pub enum Plane { XZ, } -type Planes = hashbrown::HashMap; -type GFlow = hashbrown::HashMap; +type Planes = HashMap; +type GFlow = HashMap; -/// Checks the definition of gflow. +/// Checks the geometric constraints of gflow. /// -/// 1. i -> g(i) -/// 2. j in Odd(g(i)) => i == j or i -> j -/// 3. i not in g(i) and in Odd(g(i)) if plane(i) == XY -/// 4. i in g(i) and in Odd(g(i)) if plane(i) == YZ -/// 5. 
i in g(i) and not in Odd(g(i)) if plane(i) == XZ -fn check_definition( - f: &GFlow, - layer: &Layer, - g: &Graph, - planes: &Planes, -) -> Result<(), FlowValidationError> { - for &i in itertools::chain(f.keys(), planes.keys()) { +/// - XY: i not in g(i) and in Odd(g(i)) +/// - YZ: i in g(i) and in Odd(g(i)) +/// - XZ: i in g(i) and not in Odd(g(i)) +fn check_def_geom(f: &GFlow, g: &[Nodes], planes: &Planes) -> Result<(), FlowValidationError> { + for &i in Iterator::chain(f.keys(), planes.keys()) { if f.contains_key(&i) != planes.contains_key(&i) { Err(InvalidMeasurementSpec { node: i })?; } } for (&i, fi) in f { let pi = planes[&i]; - for &fij in fi { - if i != fij && layer[i] <= layer[fij] { - Err(InconsistentFlowOrder { nodes: (i, fij) })?; - } - } let odd_fi = utils::odd_neighbors(g, fi); - for &j in &odd_fi { - if i != j && layer[i] <= layer[j] { - Err(InconsistentFlowOrder { nodes: (i, j) })?; - } - } let in_info = (fi.contains(&i), odd_fi.contains(&i)); match pi { Plane::XY if in_info != (false, true) => { @@ -90,18 +72,39 @@ fn check_definition( Ok(()) } +/// Checks the layer constraints of gflow. +/// +/// - i -> g(i) +/// - j in Odd(g(i)) => i == j or i -> j +fn check_def_layer(f: &GFlow, layers: &[Layer], g: &[Nodes]) -> Result<(), FlowValidationError> { + for (&i, fi) in f { + for &fij in fi { + if i != fij && layers[i] <= layers[fij] { + Err(InconsistentFlowOrder { nodes: (i, fij) })?; + } + } + let odd_fi = utils::odd_neighbors(g, fi); + for &j in &odd_fi { + if i != j && layers[i] <= layers[j] { + Err(InconsistentFlowOrder { nodes: (i, j) })?; + } + } + } + Ok(()) +} + /// Initializes the working matrix. fn init_work( work: &mut [FixedBitSet], - g: &Graph, + g: &[Nodes], planes: &Planes, ocset: &OrderedNodes, omiset: &OrderedNodes, ) { let ncols = omiset.len(); // Set-to-index maps - let oc2i = utils::indexmap::>(ocset); - let omi2i = utils::indexmap::>(omiset); + let oc2i = utils::indexmap::>(ocset); + let omi2i = utils::indexmap::>(omiset); // Encode node as one-hot vector for (i, &u) in ocset.iter().enumerate() { let gu = &g[u]; @@ -152,7 +155,7 @@ fn init_work( #[tracing::instrument] #[expect(clippy::needless_pass_by_value)] #[inline] -pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow, Layer)> { +pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow, Layers)> { let n = g.len(); let vset = (0..n).collect::(); let mut cset = Nodes::new(); @@ -160,7 +163,7 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow let mut ocset = vset.difference(&oset).copied().collect::(); let mut omiset = oset.difference(&iset).copied().collect::(); let mut f = GFlow::with_capacity(ocset.len()); - let mut layer = vec![0_usize; n]; + let mut layers = vec![0_usize; n]; let mut nrows = ocset.len(); let mut ncols = omiset.len(); let mut neqs = ocset.len(); @@ -210,7 +213,7 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow tracing::debug!("f({u}) = {fu:?}"); f.insert(u, fu); tracing::debug!("layer({u}) = {l}"); - layer[u] = l; + layers[u] = l; } if cset.is_empty() { break; @@ -221,17 +224,18 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow if ocset.is_empty() { tracing::debug!("gflow found"); tracing::debug!("gflow: {f:?}"); - tracing::debug!("layer: {layer:?}"); + tracing::debug!("layers: {layers:?}"); // TODO: Remove this block once stabilized { let f_flatiter = f .iter() .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); - 
validate::check_domain(f_flatiter, &vset, &iset, &oset).expect(FATAL_MSG); - validate::check_initial(&layer, &oset, true).expect(FATAL_MSG); - check_definition(&f, &layer, &g, &planes).expect(FATAL_MSG); + common::check_domain(f_flatiter, &vset, &iset, &oset).expect(FATAL_MSG); + common::check_initial(&layers, &oset, true).expect(FATAL_MSG); + check_def_geom(&f, &g, &planes).expect(FATAL_MSG); + check_def_layer(&f, &layers, &g).expect(FATAL_MSG); } - Some((f, layer)) + Some((f, layers)) } else { tracing::debug!("gflow not found"); None @@ -248,26 +252,27 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, planes: Planes) -> Option<(GFlow #[expect(clippy::needless_pass_by_value)] #[inline] pub fn verify( - gflow: (GFlow, Layer), + gflow: (GFlow, Layers), g: Graph, iset: Nodes, oset: Nodes, planes: Planes, ) -> PyResult<()> { - let (f, layer) = gflow; + let (f, layers) = gflow; let n = g.len(); let vset = (0..n).collect::(); let f_flatiter = f .iter() .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); - validate::check_domain(f_flatiter, &vset, &iset, &oset)?; - validate::check_initial(&layer, &oset, true)?; - check_definition(&f, &layer, &g, &planes)?; + common::check_domain(f_flatiter, &vset, &iset, &oset)?; + check_def_geom(&f, &g, &planes)?; + check_def_layer(&f, &layers, &g)?; Ok(()) } #[cfg(test)] mod tests { + use maplit::{hashmap, hashset}; use test_log; use super::*; @@ -277,44 +282,37 @@ mod tests { fn test_check_definition_ng() { // Missing Plane specification assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{1} }, &test_utils::graph(&[(0, 1)]), - &map! {}, + &hashmap! {}, ), Err(InvalidMeasurementSpec { node: 0 }) ); // Violate 0 -> f(0) = 1 assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![0, 0], + check_def_layer( + &hashmap! { 0 => hashset!{1} }, + &[0, 0], &test_utils::graph(&[(0, 1)]), - &map! { 0: Plane::XY }, ), Err(InconsistentFlowOrder { nodes: (0, 1) }) ); // Violate 1 in nb(f(0)) = nb(2) => 0 == 1 or 0 -> 1 assert_eq!( - check_definition( - &map! { 0: set!{2}, 1: set!{2} }, - &vec![1, 1, 0], + check_def_layer( + &hashmap! { 0 => hashset!{2}, 1 => hashset!{2} }, + &[1, 1, 0], &test_utils::graph(&[(0, 1), (1, 2)]), - &map! { - 0: Plane::XY, - 1: Plane::XY - }, ), Err(InconsistentFlowOrder { nodes: (0, 1) }) ); // Violate XY: 0 in f(0) assert_eq!( - check_definition( - &map! { 0: set!{0} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{0} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: Plane::XY }, + &hashmap! { 0 => Plane::XY }, ), Err(InconsistentFlowPlane { node: 0, @@ -323,11 +321,10 @@ mod tests { ); // Violate YZ: 0 in Odd(f(0)) assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{1} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: Plane::YZ }, + &hashmap! { 0 => Plane::YZ }, ), Err(InconsistentFlowPlane { node: 0, @@ -336,11 +333,10 @@ mod tests { ); // Violate XZ: 0 not in Odd(f(0)) and in f(0) assert_eq!( - check_definition( - &map! { 0: set!{0} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{0} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: Plane::XZ }, + &hashmap! { 0 => Plane::XZ }, ), Err(InconsistentFlowPlane { node: 0, @@ -349,11 +345,10 @@ mod tests { ); // Violate XZ: 0 in Odd(f(0)) and not in f(0) assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{1} }, &test_utils::graph(&[(0, 1)]), - &map! 
{ 0: Plane::XZ }, + &hashmap! { 0 => Plane::XZ }, ), Err(InconsistentFlowPlane { node: 0, @@ -365,98 +360,98 @@ mod tests { #[test_log::test] fn test_find_case0() { let TestCase { g, iset, oset } = test_utils::CASE0.clone(); - let planes = map! {}; + let planes = hashmap! {}; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); - assert_eq!(layer, vec![0, 0]); - verify((f, layer), g, iset, oset, planes).unwrap(); + assert_eq!(layers, vec![0, 0]); + verify((f, layers), g, iset, oset, planes).unwrap(); } #[test_log::test] fn test_find_case1() { let TestCase { g, iset, oset } = test_utils::CASE1.clone(); - let planes = map! { - 0: Plane::XY, - 1: Plane::XY, - 2: Plane::XY, - 3: Plane::XY + let planes = hashmap! { + 0 => Plane::XY, + 1 => Plane::XY, + 2 => Plane::XY, + 3 => Plane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([1])); assert_eq!(f[&1], Nodes::from([2])); assert_eq!(f[&2], Nodes::from([3])); assert_eq!(f[&3], Nodes::from([4])); - assert_eq!(layer, vec![4, 3, 2, 1, 0]); - verify((f, layer), g, iset, oset, planes).unwrap(); + assert_eq!(layers, vec![4, 3, 2, 1, 0]); + verify((f, layers), g, iset, oset, planes).unwrap(); } #[test_log::test] fn test_find_case2() { let TestCase { g, iset, oset } = test_utils::CASE2.clone(); - let planes = map! { - 0: Plane::XY, - 1: Plane::XY, - 2: Plane::XY, - 3: Plane::XY + let planes = hashmap! { + 0 => Plane::XY, + 1 => Plane::XY, + 2 => Plane::XY, + 3 => Plane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([3])); assert_eq!(f[&2], Nodes::from([4])); assert_eq!(f[&3], Nodes::from([5])); - assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); - verify((f, layer), g, iset, oset, planes).unwrap(); + assert_eq!(layers, vec![2, 2, 1, 1, 0, 0]); + verify((f, layers), g, iset, oset, planes).unwrap(); } #[test_log::test] fn test_find_case3() { let TestCase { g, iset, oset } = test_utils::CASE3.clone(); - let planes = map! { - 0: Plane::XY, - 1: Plane::XY, - 2: Plane::XY + let planes = hashmap! { + 0 => Plane::XY, + 1 => Plane::XY, + 2 => Plane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([4, 5])); assert_eq!(f[&1], Nodes::from([3, 4, 5])); assert_eq!(f[&2], Nodes::from([3, 5])); - assert_eq!(layer, vec![1, 1, 1, 0, 0, 0]); - verify((f, layer), g, iset, oset, planes).unwrap(); + assert_eq!(layers, vec![1, 1, 1, 0, 0, 0]); + verify((f, layers), g, iset, oset, planes).unwrap(); } #[test_log::test] fn test_find_case4() { let TestCase { g, iset, oset } = test_utils::CASE4.clone(); - let planes = map! { - 0: Plane::XY, - 1: Plane::XY, - 2: Plane::XZ, - 3: Plane::YZ + let planes = hashmap! 
{ + 0 => Plane::XY, + 1 => Plane::XY, + 2 => Plane::XZ, + 3 => Plane::YZ }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), planes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([5])); assert_eq!(f[&2], Nodes::from([2, 4])); assert_eq!(f[&3], Nodes::from([3])); - assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); - verify((f, layer), g, iset, oset, planes).unwrap(); + assert_eq!(layers, vec![2, 2, 1, 1, 0, 0]); + verify((f, layers), g, iset, oset, planes).unwrap(); } #[test_log::test] fn test_find_case5() { let TestCase { g, iset, oset } = test_utils::CASE5.clone(); - let planes = map! { - 0: Plane::XY, - 1: Plane::XY + let planes = hashmap! { + 0 => Plane::XY, + 1 => Plane::XY }; assert!(find(g, iset, oset, planes).is_none()); } @@ -464,11 +459,11 @@ mod tests { #[test_log::test] fn test_find_case6() { let TestCase { g, iset, oset } = test_utils::CASE6.clone(); - let planes = map! { - 0: Plane::XY, - 1: Plane::XY, - 2: Plane::XY, - 3: Plane::XY + let planes = hashmap! { + 0 => Plane::XY, + 1 => Plane::XY, + 2 => Plane::XY, + 3 => Plane::XY }; assert!(find(g, iset, oset, planes).is_none()); } @@ -476,11 +471,11 @@ mod tests { #[test_log::test] fn test_find_case7() { let TestCase { g, iset, oset } = test_utils::CASE7.clone(); - let planes = map! { - 0: Plane::YZ, - 1: Plane::XZ, - 2: Plane::XY, - 3: Plane::YZ + let planes = hashmap! { + 0 => Plane::YZ, + 1 => Plane::XZ, + 2 => Plane::XY, + 3 => Plane::YZ }; assert!(find(g, iset, oset, planes).is_none()); } @@ -488,10 +483,10 @@ mod tests { #[test_log::test] fn test_find_case8() { let TestCase { g, iset, oset } = test_utils::CASE8.clone(); - let planes = map! { - 0: Plane::YZ, - 1: Plane::XZ, - 2: Plane::XY + let planes = hashmap! { + 0 => Plane::YZ, + 1 => Plane::XZ, + 2 => Plane::XY }; assert!(find(g, iset, oset, planes).is_none()); } diff --git a/src/internal.rs b/src/internal.rs index 1df283f0..46f4aef4 100644 --- a/src/internal.rs +++ b/src/internal.rs @@ -1,7 +1,5 @@ #[cfg(test)] -#[macro_use] pub mod test_utils; pub mod gf2_linalg; pub mod utils; -pub mod validate; diff --git a/src/internal/gf2_linalg.rs b/src/internal/gf2_linalg.rs index 8589e140..7e7724a6 100644 --- a/src/internal/gf2_linalg.rs +++ b/src/internal/gf2_linalg.rs @@ -1,16 +1,12 @@ //! GF(2) linear solver for gflow algorithm. -use core::{ - fmt::{self, Debug, Formatter}, - ops::DerefMut, -}; -use std::collections::BTreeMap; +use core::{fmt::Debug, ops::DerefMut}; use fixedbitset::FixedBitSet; use itertools::Itertools; /// Solver for GF(2) linear equations. -#[derive(PartialEq, Eq)] +#[derive(PartialEq, Eq, Debug)] pub struct GF2Solver { /// Number of rows in the coefficient matrix. 
rows: usize, @@ -248,37 +244,6 @@ impl> GF2Solver { } } -impl> Debug for GF2Solver { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let mut ret = f.debug_struct("GF2Solver"); - ret.field("rows", &self.rows) - .field("cols", &self.cols) - .field("neqs", &self.neqs) - .field("rank", &self.rank) - .field("perm", &self.perm); - let mut work = BTreeMap::new(); - for (r, row) in self.work.iter().enumerate() { - let mut s = String::with_capacity(self.cols); - for c in 0..self.cols { - s.push(if row[c] { '1' } else { '0' }); - } - work.insert(r, s); - } - ret.field("co", &work); - let mut work = BTreeMap::new(); - for (r, row) in self.work.iter().enumerate() { - let mut s = String::with_capacity(self.neqs); - for ieq in 0..self.neqs { - let c = self.cols + ieq; - s.push(if row[c] { '1' } else { '0' }); - } - work.insert(r, s); - } - ret.field("rhs", &work); - ret.finish() - } -} - #[cfg(test)] mod tests { use rand::{self, Rng}; diff --git a/src/internal/test_utils.rs b/src/internal/test_utils.rs index 82dfcca5..8862e0eb 100644 --- a/src/internal/test_utils.rs +++ b/src/internal/test_utils.rs @@ -2,29 +2,10 @@ use std::sync::LazyLock; -use crate::common::{Graph, Nodes}; - -pub mod exports { - pub use hashbrown::{HashMap, HashSet}; -} - -macro_rules! map { - ($($u:literal: $v:expr),*) => { - // Dirty .expect to handle i32 -> usize conversion - $crate::internal::test_utils::exports::HashMap::from_iter([$(($u, ($v).try_into().expect("dynamic coersion"))),*].into_iter()) - }; - ($($u:literal: $v:expr),*,) => {map! { $($u: $v),* }}; -} - -macro_rules! set { - ($($u:literal),*) => { - $crate::internal::test_utils::exports::HashSet::from_iter([$($u),*].into_iter()) - }; - ($($u:literal),*,) => {set! { $($u),* }}; -} +use crate::common::{Graph, Node, Nodes}; /// Creates a undirected graph from edges. -pub fn graph(edges: &[(usize, usize); N]) -> Graph { +pub fn graph(edges: &[(Node, Node); N]) -> Graph { let n = edges .iter() .map(|&(u, v)| u.max(v) + 1) @@ -174,22 +155,22 @@ mod tests { /// Checks if the graph is valid. /// /// In production code, this check should be done in the Python layer. - fn check_graph(g: &Graph, iset: &Nodes, oset: &Nodes) { + fn check_graph(g: &[Nodes], iset: &Nodes, oset: &Nodes) { let n = g.len(); assert_ne!(n, 0, "empty graph"); for (u, gu) in g.iter().enumerate() { assert!(!gu.contains(&u), "self-loop detected: {u}"); - gu.iter().for_each(|&v| { + for &v in gu { assert!(v < n, "node index out of range: {v}"); assert!(g[v].contains(&u), "g must be undirected: {u} -> {v}"); - }); + } } - iset.iter().for_each(|&u| { + for &u in iset { assert!((0..n).contains(&u), "unknown node in iset: {u}"); - }); - oset.iter().for_each(|&u| { + } + for &u in oset { assert!((0..n).contains(&u), "unknown node in oset: {u}"); - }); + } } #[apply(template_tests)] diff --git a/src/internal/utils.rs b/src/internal/utils.rs index d80faa1b..a7a0dcdc 100644 --- a/src/internal/utils.rs +++ b/src/internal/utils.rs @@ -4,22 +4,25 @@ use core::{ hash::Hash, ops::{Deref, DerefMut}, }; -use std::collections::BTreeSet; +use std::collections::{BTreeSet, HashSet}; use fixedbitset::FixedBitSet; -use crate::common::{Graph, Nodes, OrderedNodes}; +use crate::common::{Node, Nodes, OrderedNodes}; /// Computes the odd neighbors of the nodes in `kset`. -/// -/// # Note -/// -/// - Naive implementation only for post-verification. 
-pub fn odd_neighbors(g: &Graph, kset: &Nodes) -> Nodes { +pub fn odd_neighbors(g: &[Nodes], kset: &Nodes) -> Nodes { assert!(kset.iter().all(|&ki| ki < g.len()), "kset out of range"); - let mut work = kset.clone(); - work.extend(kset.iter().flat_map(|&ki| g[ki].iter().copied())); - work.retain(|&u| kset.intersection(&g[u]).count() % 2 == 1); + let mut work = Nodes::default(); + for &k in kset { + for &u in &g[k] { + if work.contains(&u) { + work.remove(&u); + } else { + work.insert(u); + } + } + } work } @@ -39,7 +42,7 @@ pub trait InPlaceSetDiff { U: Deref; } -impl InPlaceSetDiff for hashbrown::HashSet +impl InPlaceSetDiff for HashSet where T: Eq + Hash, { @@ -78,11 +81,11 @@ pub fn indexmap>(set: &OrderedNodes) -> T { /// Inserts `u` on construction and reverts on drop. pub struct ScopedInclude<'a> { target: &'a mut OrderedNodes, - u: Option, + u: Option, } impl<'a> ScopedInclude<'a> { - pub fn new(target: &'a mut OrderedNodes, u: usize) -> Self { + pub fn new(target: &'a mut OrderedNodes, u: Node) -> Self { let u = if target.insert(u) { Some(u) } else { None }; Self { target, u } } @@ -116,11 +119,11 @@ impl Drop for ScopedInclude<'_> { /// Removes `u` on construction and reverts on drop. pub struct ScopedExclude<'a> { target: &'a mut OrderedNodes, - u: Option, + u: Option, } impl<'a> ScopedExclude<'a> { - pub fn new(target: &'a mut OrderedNodes, u: usize) -> Self { + pub fn new(target: &'a mut OrderedNodes, u: Node) -> Self { let u = if target.remove(&u) { Some(u) } else { None }; Self { target, u } } @@ -193,9 +196,9 @@ mod tests { #[test] fn test_difference_with_hashset() { - let mut set = hashbrown::HashSet::from([1, 2, 3]); + let mut set = HashSet::from([1, 2, 3]); set.difference_with(&[2, 3, 4]); - assert_eq!(set, hashbrown::HashSet::from([1])); + assert_eq!(set, HashSet::from([1])); } #[test] diff --git a/src/internal/validate.rs b/src/internal/validate.rs deleted file mode 100644 index 4c1e7a8f..00000000 --- a/src/internal/validate.rs +++ /dev/null @@ -1,158 +0,0 @@ -//! Rust-side input validations. -//! -//! # Note -//! -//! - Internal module for testing. - -use crate::common::{ - FlowValidationError::{ - self, ExcessiveNonZeroLayer, ExcessiveZeroLayer, InvalidFlowCodomain, InvalidFlowDomain, - }, - Layer, Nodes, -}; - -/// Checks if the layer-zero nodes are correctly chosen. -/// -/// # Arguments -/// -/// - `layer`: The layer. -/// - `oset`: The set of output nodes. -/// - `iff`: If `true`, `layer[u] == 0` "iff" `u` is in `oset`. Otherwise "if". -pub fn check_initial(layer: &Layer, oset: &Nodes, iff: bool) -> Result<(), FlowValidationError> { - for (u, &lu) in layer.iter().enumerate() { - match (oset.contains(&u), lu == 0) { - (true, false) => { - Err(ExcessiveNonZeroLayer { node: u, layer: lu })?; - } - (false, true) if iff => { - Err(ExcessiveZeroLayer { node: u })?; - } - _ => {} - } - } - Ok(()) -} - -/// Checks if the domain of `f` is in `vset - oset` and the codomain is in `vset - iset`. -/// -/// # Arguments -/// -/// - `f_flatiter`: Flow, gflow, or pflow as `impl Iterator`. -/// - `vset`: All nodes. -/// - `iset`: Input nodes. -/// - `oset`: Output nodes. -/// -/// # Note -/// -/// It is allowed for `f[i]` to contain `i`, even if `i` is in `iset`. 
-pub fn check_domain<'a, 'b>( - f_flatiter: impl Iterator, - vset: &Nodes, - iset: &Nodes, - oset: &Nodes, -) -> Result<(), FlowValidationError> { - let icset = vset - iset; - let ocset = vset - oset; - let mut dom = Nodes::new(); - for (&i, &fi) in f_flatiter { - dom.insert(i); - if i != fi && !icset.contains(&fi) { - Err(InvalidFlowCodomain { node: i })?; - } - } - if let Some(&i) = dom.symmetric_difference(&ocset).next() { - Err(InvalidFlowDomain { node: i })?; - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use core::iter; - - use super::*; - use crate::common::Nodes; - - #[test] - fn test_check_initial() { - let layer = vec![0, 0, 0, 1, 1, 1]; - let oset = Nodes::from([0, 1]); - check_initial(&layer, &oset, false).unwrap(); - } - - #[test] - fn test_check_initial_ng() { - let layer = vec![0, 0, 0, 1, 1, 1]; - let oset = Nodes::from([0, 1, 2, 3]); - assert!(check_initial(&layer, &oset, false).is_err()); - } - - #[test] - fn test_check_initial_iff() { - let layer = vec![0, 0, 0, 1, 1, 1]; - let oset = Nodes::from([0, 1, 2]); - check_initial(&layer, &oset, true).unwrap(); - } - - #[test] - fn test_check_initial_iff_ng() { - let layer = vec![0, 0, 0, 1, 1, 1]; - let oset = Nodes::from([0, 1]); - assert!(check_initial(&layer, &oset, true).is_err()); - } - - #[test] - fn test_check_domain_flow() { - let f = hashbrown::HashMap::::from([(0, 1), (1, 2)]); - let vset = Nodes::from([0, 1, 2]); - let iset = Nodes::from([0]); - let oset = Nodes::from([2]); - check_domain(f.iter(), &vset, &iset, &oset).unwrap(); - } - - #[test] - fn test_check_domain_gflow() { - let f = hashbrown::HashMap::::from([ - // OK: 0 in f(0) - (0, Nodes::from([0, 1])), - (1, Nodes::from([2])), - ]); - let vset = Nodes::from([0, 1, 2]); - let iset = Nodes::from([0]); - let oset = Nodes::from([2]); - let f_flatiter = f - .iter() - .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); - check_domain(f_flatiter, &vset, &iset, &oset).unwrap(); - } - - #[test] - fn test_check_domain_ng_iset() { - let f = hashbrown::HashMap::::from([ - (0, Nodes::from([0, 1])), - (2, Nodes::from([2])), - ]); - let vset = Nodes::from([0, 1, 2]); - let iset = Nodes::from([0]); - let oset = Nodes::from([2]); - let f_flatiter = f - .iter() - .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); - assert!(check_domain(f_flatiter, &vset, &iset, &oset).is_err()); - } - - #[test] - fn test_check_domain_ng_oset() { - let f = hashbrown::HashMap::::from([ - (0, Nodes::from([1])), - (1, Nodes::from([0])), - ]); - let vset = Nodes::from([0, 1, 2]); - let iset = Nodes::from([0]); - let oset = Nodes::from([2]); - let f_flatiter = f - .iter() - .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); - assert!(check_domain(f_flatiter, &vset, &iset, &oset).is_err()); - } -} diff --git a/src/lib.rs b/src/lib.rs index 44f43580..5cc678b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,12 +12,10 @@ rust_2024_compatibility )] -#[macro_use] -mod internal; - -pub mod common; +mod common; pub mod flow; pub mod gflow; +mod internal; pub mod pflow; use common::FlowValidationError; diff --git a/src/pflow.rs b/src/pflow.rs index 70409bd4..b3e36d8d 100644 --- a/src/pflow.rs +++ b/src/pflow.rs @@ -1,23 +1,22 @@ //! Maximally-delayed Pauli flow algorithm. 
use core::iter; +use std::collections::{HashMap, HashSet}; use fixedbitset::FixedBitSet; -use hashbrown; use pyo3::prelude::*; use crate::{ common::{ - FATAL_MSG, + self, FATAL_MSG, FlowValidationError::{ self, InconsistentFlowOrder, InconsistentFlowPPlane, InvalidMeasurementSpec, }, - Graph, Layer, Nodes, OrderedNodes, + Graph, Layer, Layers, Node, Nodes, OrderedNodes, }, internal::{ gf2_linalg::GF2Solver, utils::{self, InPlaceSetDiff, ScopedExclude, ScopedInclude}, - validate, }, }; @@ -33,50 +32,19 @@ pub enum PPlane { Z, } -type PPlanes = hashbrown::HashMap; -type PFlow = hashbrown::HashMap; +type PPlanes = HashMap; +type PFlow = HashMap; -/// Checks the definition of Pauli flow. -fn check_definition( - f: &PFlow, - layer: &Layer, - g: &Graph, - pplanes: &PPlanes, -) -> Result<(), FlowValidationError> { - for &i in itertools::chain(f.keys(), pplanes.keys()) { +/// Checks the geometric constraints of pflow. +fn check_def_geom(f: &PFlow, g: &[Nodes], pplanes: &PPlanes) -> Result<(), FlowValidationError> { + for &i in Iterator::chain(f.keys(), pplanes.keys()) { if f.contains_key(&i) != pplanes.contains_key(&i) { Err(InvalidMeasurementSpec { node: i })?; } } for (&i, fi) in f { let pi = pplanes[&i]; - for &fij in fi { - match (i != fij, layer[i] <= layer[fij]) { - (true, true) if !matches!(pplanes.get(&fij), Some(&PPlane::X | &PPlane::Y)) => { - Err(InconsistentFlowOrder { nodes: (i, fij) })?; - } - (false, false) => unreachable!("layer[i] == layer[i]"), - _ => {} - } - } let odd_fi = utils::odd_neighbors(g, fi); - for &j in &odd_fi { - match (i != j, layer[i] <= layer[j]) { - (true, true) if !matches!(pplanes.get(&j), Some(&PPlane::Y | &PPlane::Z)) => { - Err(InconsistentFlowOrder { nodes: (i, j) })?; - } - (false, false) => unreachable!("layer[i] == layer[i]"), - _ => {} - } - } - for &j in fi.symmetric_difference(&odd_fi) { - if pplanes.get(&j) == Some(&PPlane::Y) && i != j && layer[i] <= layer[j] { - Err(InconsistentFlowPPlane { - node: i, - pplane: PPlane::Y, - })?; - } - } let in_info = (fi.contains(&i), odd_fi.contains(&i)); match pi { PPlane::XY if in_info != (false, true) => { @@ -121,6 +89,45 @@ fn check_definition( Ok(()) } +/// Checks the layer constraints of pflow. +fn check_def_layer( + f: &PFlow, + layers: &[Layer], + g: &[Nodes], + pplanes: &PPlanes, +) -> Result<(), FlowValidationError> { + for (&i, fi) in f { + for &fij in fi { + match (i != fij, layers[i] <= layers[fij]) { + (true, true) if !matches!(pplanes.get(&fij), Some(&PPlane::X | &PPlane::Y)) => { + Err(InconsistentFlowOrder { nodes: (i, fij) })?; + } + (false, false) => unreachable!("layers[i] == layers[i]"), + _ => {} + } + } + let odd_fi = utils::odd_neighbors(g, fi); + for &j in &odd_fi { + match (i != j, layers[i] <= layers[j]) { + (true, true) if !matches!(pplanes.get(&j), Some(&PPlane::Y | &PPlane::Z)) => { + Err(InconsistentFlowOrder { nodes: (i, j) })?; + } + (false, false) => unreachable!("layers[i] == layers[i]"), + _ => {} + } + } + for &j in fi.symmetric_difference(&odd_fi) { + if pplanes.get(&j) == Some(&PPlane::Y) && i != j && layers[i] <= layers[j] { + Err(InconsistentFlowPPlane { + node: i, + pplane: PPlane::Y, + })?; + } + } + } + Ok(()) +} + /// Sellects nodes from `src` with `pred`. fn matching_nodes(src: &PPlanes, mut pred: impl FnMut(&PPlane) -> bool) -> Nodes { src.iter() @@ -131,11 +138,11 @@ fn matching_nodes(src: &PPlanes, mut pred: impl FnMut(&PPlane) -> bool) -> Nodes /// Initializes the upper block of working storage. 
fn init_work_upper_co( work: &mut [FixedBitSet], - g: &Graph, + g: &[Nodes], rowset: &OrderedNodes, colset: &OrderedNodes, ) { - let colset2i = utils::indexmap::>(colset); + let colset2i = utils::indexmap::>(colset); for (r, &v) in rowset.iter().enumerate() { let gv = &g[v]; for &w in gv { @@ -149,11 +156,11 @@ fn init_work_upper_co( /// Initializes the lower block of working storage. fn init_work_lower_co( work: &mut [FixedBitSet], - g: &Graph, + g: &[Nodes], rowset: &OrderedNodes, colset: &OrderedNodes, ) { - let colset2i = utils::indexmap::>(colset); + let colset2i = utils::indexmap::>(colset); for (r, &v) in rowset.iter().enumerate() { // need to introduce self-loops if let Some(&c) = colset2i.get(&v) { @@ -176,8 +183,8 @@ const BRANCH_XZ: BranchKind = 2; /// Initializes the right-hand side of working storage for the upper block. fn init_work_upper_rhs( work: &mut [FixedBitSet], - u: usize, - g: &Graph, + u: Node, + g: &[Nodes], rowset: &OrderedNodes, colset: &OrderedNodes, ) { @@ -185,7 +192,7 @@ fn init_work_upper_rhs( assert!(K == BRANCH_XY || K == BRANCH_YZ || K == BRANCH_XZ); }; debug_assert!(rowset.contains(&u)); - let rowset2i = utils::indexmap::>(rowset); + let rowset2i = utils::indexmap::>(rowset); let c = colset.len(); let gu = &g[u]; if K != BRANCH_YZ { @@ -206,15 +213,15 @@ fn init_work_upper_rhs( /// Initializes the right-hand side of working storage for the lower block. fn init_work_lower_rhs( work: &mut [FixedBitSet], - u: usize, - g: &Graph, + u: Node, + g: &[Nodes], rowset: &OrderedNodes, colset: &OrderedNodes, ) { const { assert!(K == BRANCH_XY || K == BRANCH_YZ || K == BRANCH_XZ); }; - let rowset2i = utils::indexmap::>(rowset); + let rowset2i = utils::indexmap::>(rowset); let c = colset.len(); let gu = &g[u]; if K == BRANCH_XY { @@ -230,8 +237,8 @@ fn init_work_lower_rhs( /// Initializes working storage for the given branch kind. fn init_work( work: &mut [FixedBitSet], - u: usize, - g: &Graph, + u: Node, + g: &[Nodes], rowset_upper: &OrderedNodes, rowset_lower: &OrderedNodes, colset: &OrderedNodes, @@ -247,7 +254,7 @@ fn init_work( } /// Decodes the solution returned by `GF2Solver`. 
-fn decode_solution(u: usize, x: &FixedBitSet, colset: &OrderedNodes) -> Nodes { +fn decode_solution(u: Node, x: &FixedBitSet, colset: &OrderedNodes) -> Nodes { const { assert!(K == BRANCH_XY || K == BRANCH_YZ || K == BRANCH_XZ); }; @@ -265,11 +272,11 @@ fn decode_solution(u: usize, x: &FixedBitSet, colset: &Orde #[derive(Debug)] struct PFlowContext<'a> { work: &'a mut Vec, - g: &'a Graph, - u: usize, - rowset_upper: &'a OrderedNodes, - rowset_lower: &'a OrderedNodes, - colset: &'a OrderedNodes, + g: &'a [Nodes], + u: Node, + rowset_upper: &'a ScopedInclude<'a>, + rowset_lower: &'a ScopedExclude<'a>, + colset: &'a ScopedExclude<'a>, x: &'a mut FixedBitSet, f: &'a mut PFlow, } @@ -325,7 +332,7 @@ fn find_impl(ctx: &mut PFlowContext<'_>) -> bool { #[tracing::instrument] #[expect(clippy::needless_pass_by_value)] #[inline] -pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFlow, Layer)> { +pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFlow, Layers)> { let yset = matching_nodes(&pplanes, |pp| matches!(pp, PPlane::Y)); let xyset = matching_nodes(&pplanes, |pp| matches!(pp, PPlane::X | PPlane::Y)); let yzset = matching_nodes(&pplanes, |pp| matches!(pp, PPlane::Y | PPlane::Z)); @@ -337,7 +344,7 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFl let mut rowset_lower = yset.iter().copied().collect::(); let mut colset = xyset.difference(&iset).copied().collect::(); let mut f = PFlow::with_capacity(ocset.len()); - let mut layer = vec![0_usize; n]; + let mut layers = vec![0_usize; n]; let mut work = vec![FixedBitSet::new(); rowset_upper.len() + rowset_lower.len()]; for l in 0_usize.. { tracing::debug!("=====layer {l}====="); @@ -388,7 +395,7 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFl if done { tracing::debug!("f({}) = {:?}", u, &f[&u]); tracing::debug!("layer({u}) = {l}"); - layer[u] = l; + layers[u] = l; cset.insert(u); } else { tracing::debug!("solution not found: {u} (all branches)"); @@ -409,23 +416,49 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFl if ocset.is_empty() { tracing::debug!("pflow found"); tracing::debug!("pflow: {f:?}"); - tracing::debug!("layer: {layer:?}"); + tracing::debug!("layers: {layers:?}"); // TODO: Remove this block once stabilized { let f_flatiter = f .iter() .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); - validate::check_domain(f_flatiter, &vset, &iset, &oset).expect(FATAL_MSG); - validate::check_initial(&layer, &oset, false).expect(FATAL_MSG); - check_definition(&f, &layer, &g, &pplanes).expect(FATAL_MSG); + common::check_domain(f_flatiter, &vset, &iset, &oset).expect(FATAL_MSG); + common::check_initial(&layers, &oset, false).expect(FATAL_MSG); + check_def_geom(&f, &g, &pplanes).expect(FATAL_MSG); + check_def_layer(&f, &layers, &g, &pplanes).expect(FATAL_MSG); } - Some((f, layer)) + Some((f, layers)) } else { tracing::debug!("pflow not found"); None } } +/// Compute special edges that can bypass partial order constraints in Pauli flow. 
+fn special_edges(g: &[Nodes], pflow: &PFlow, pplanes: &PPlanes) -> HashSet<(Node, Node)> { + let mut ret = HashSet::new(); + for (&i, fi) in pflow { + for &j in fi.iter().filter(|&&j| i != j) { + if let Some(PPlane::X) = pplanes.get(&i) { + ret.insert((i, j)); + } + } + let mut odd_fi = utils::odd_neighbors(g, fi); + odd_fi.remove(&i); + for &j in &odd_fi { + if let Some(PPlane::Z) = pplanes.get(&i) { + ret.insert((i, j)); + } + } + for &j in fi.intersection(&odd_fi) { + if let Some(PPlane::Y) = pplanes.get(&i) { + ret.insert((i, j)); + } + } + } + ret +} + /// Validates Pauli flow. /// /// # Errors @@ -436,26 +469,29 @@ pub fn find(g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes) -> Option<(PFl #[expect(clippy::needless_pass_by_value)] #[inline] pub fn verify( - pflow: (PFlow, Layer), + pflow: (PFlow, Layers), g: Graph, iset: Nodes, oset: Nodes, pplanes: PPlanes, ) -> PyResult<()> { - let (f, layer) = pflow; + let (f, layers) = pflow; let n = g.len(); let vset = (0..n).collect::(); + let special = special_edges(&g, &f, &pplanes); let f_flatiter = f .iter() - .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())); - validate::check_domain(f_flatiter, &vset, &iset, &oset)?; - validate::check_initial(&layer, &oset, false)?; - check_definition(&f, &layer, &g, &pplanes)?; + .flat_map(|(i, fi)| Iterator::zip(iter::repeat(i), fi.iter())) + .filter(|&(&i, &j)| !special.contains(&(i, j))); + common::check_domain(f_flatiter, &vset, &iset, &oset)?; + check_def_geom(&f, &g, &pplanes)?; + check_def_layer(&f, &layers, &g, &pplanes)?; Ok(()) } #[cfg(test)] mod tests { + use maplit::{hashmap, hashset}; use test_log; use super::*; @@ -465,44 +501,43 @@ mod tests { fn test_check_definition_ng() { // Missing Plane specification assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{1} }, &test_utils::graph(&[(0, 1)]), - &map! {}, + &hashmap! {}, ), Err(InvalidMeasurementSpec { node: 0 }) ); // Violate 0 -> f(0) = 1 assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![0, 0], + check_def_layer( + &hashmap! { 0 => hashset!{1} }, + &[0, 0], &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::XY }, + &hashmap! { 0 => PPlane::XY }, ), Err(InconsistentFlowOrder { nodes: (0, 1) }) ); // Violate 1 in nb(f(0)) = nb(2) => 0 == 1 or 0 -> 1 assert_eq!( - check_definition( - &map! { 0: set!{2}, 1: set!{2} }, - &vec![1, 1, 0], + check_def_layer( + &hashmap! { 0 => hashset!{2}, 1 => hashset!{2} }, + &[1, 1, 0], &test_utils::graph(&[(0, 1), (1, 2)]), - &map! { - 0: PPlane::XY, - 1: PPlane::XY + &hashmap! { + 0 => PPlane::XY, + 1 => PPlane::XY }, ), Err(InconsistentFlowOrder { nodes: (0, 1) }) ); // Violate Y: 0 != 1 and not 0 -> 1 and 1 in f(0) ^ Odd(f(0)) assert_eq!( - check_definition( - &map! { 0: set!{1}, 1: set!{2} }, - &vec![1, 1, 0], + check_def_layer( + &hashmap! { 0 => hashset!{1}, 1 => hashset!{2} }, + &[1, 1, 0], &test_utils::graph(&[(0, 1), (1, 2)]), - &map! { 0: PPlane::XY, 1: PPlane::Y }, + &hashmap! { 0 => PPlane::XY, 1 => PPlane::Y }, ), Err(InconsistentFlowPPlane { node: 0, @@ -511,11 +546,10 @@ mod tests { ); // Violate XY: 0 in f(0) assert_eq!( - check_definition( - &map! { 0: set!{0} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{0} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::XY }, + &hashmap! { 0 => PPlane::XY }, ), Err(InconsistentFlowPPlane { node: 0, @@ -524,11 +558,10 @@ mod tests { ); // Violate YZ: 0 in Odd(f(0)) assert_eq!( - check_definition( - &map! 
{ 0: set!{1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{1} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::YZ }, + &hashmap! { 0 => PPlane::YZ }, ), Err(InconsistentFlowPPlane { node: 0, @@ -537,11 +570,10 @@ mod tests { ); // Violate XZ: 0 not in Odd(f(0)) and in f(0) assert_eq!( - check_definition( - &map! { 0: set!{0} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{0} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::XZ }, + &hashmap! { 0 => PPlane::XZ }, ), Err(InconsistentFlowPPlane { node: 0, @@ -550,11 +582,10 @@ mod tests { ); // Violate XZ: 0 in Odd(f(0)) and not in f(0) assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{1} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::XZ }, + &hashmap! { 0 => PPlane::XZ }, ), Err(InconsistentFlowPPlane { node: 0, @@ -563,11 +594,10 @@ mod tests { ); // Violate X: 0 not in Odd(f(0)) assert_eq!( - check_definition( - &map! { 0: set!{0} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{0} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::X }, + &hashmap! { 0 => PPlane::X }, ), Err(InconsistentFlowPPlane { node: 0, @@ -576,11 +606,10 @@ mod tests { ); // Violate Z: 0 not in f(0) assert_eq!( - check_definition( - &map! { 0: set!{1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{1} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::Z }, + &hashmap! { 0 => PPlane::Z }, ), Err(InconsistentFlowPPlane { node: 0, @@ -589,11 +618,10 @@ mod tests { ); // Violate Y: 0 in f(0) and 0 in Odd(f(0)) assert_eq!( - check_definition( - &map! { 0: set!{0, 1} }, - &vec![1, 0], + check_def_geom( + &hashmap! { 0 => hashset!{0, 1} }, &test_utils::graph(&[(0, 1)]), - &map! { 0: PPlane::Y }, + &hashmap! { 0 => PPlane::Y }, ), Err(InconsistentFlowPPlane { node: 0, @@ -605,98 +633,98 @@ mod tests { #[test_log::test] fn test_find_case0() { let TestCase { g, iset, oset } = test_utils::CASE0.clone(); - let pplanes = map! {}; + let pplanes = hashmap! {}; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); - assert_eq!(layer, vec![0, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![0, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } #[test_log::test] fn test_find_case1() { let TestCase { g, iset, oset } = test_utils::CASE1.clone(); - let pplanes = map! { - 0: PPlane::XY, - 1: PPlane::XY, - 2: PPlane::XY, - 3: PPlane::XY + let pplanes = hashmap! { + 0 => PPlane::XY, + 1 => PPlane::XY, + 2 => PPlane::XY, + 3 => PPlane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([1])); assert_eq!(f[&1], Nodes::from([2])); assert_eq!(f[&2], Nodes::from([3])); assert_eq!(f[&3], Nodes::from([4])); - assert_eq!(layer, vec![4, 3, 2, 1, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![4, 3, 2, 1, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } #[test_log::test] fn test_find_case2() { let TestCase { g, iset, oset } = test_utils::CASE2.clone(); - let pplanes = map! 
{ - 0: PPlane::XY, - 1: PPlane::XY, - 2: PPlane::XY, - 3: PPlane::XY + let pplanes = hashmap! { + 0 => PPlane::XY, + 1 => PPlane::XY, + 2 => PPlane::XY, + 3 => PPlane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([3])); assert_eq!(f[&2], Nodes::from([4])); assert_eq!(f[&3], Nodes::from([5])); - assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![2, 2, 1, 1, 0, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } #[test_log::test] fn test_find_case3() { let TestCase { g, iset, oset } = test_utils::CASE3.clone(); - let pplanes = map! { - 0: PPlane::XY, - 1: PPlane::XY, - 2: PPlane::XY + let pplanes = hashmap! { + 0 => PPlane::XY, + 1 => PPlane::XY, + 2 => PPlane::XY }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([4, 5])); assert_eq!(f[&1], Nodes::from([3, 4, 5])); assert_eq!(f[&2], Nodes::from([3, 5])); - assert_eq!(layer, vec![1, 1, 1, 0, 0, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![1, 1, 1, 0, 0, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } #[test_log::test] fn test_find_case4() { let TestCase { g, iset, oset } = test_utils::CASE4.clone(); - let pplanes = map! { - 0: PPlane::XY, - 1: PPlane::XY, - 2: PPlane::XZ, - 3: PPlane::YZ + let pplanes = hashmap! { + 0 => PPlane::XY, + 1 => PPlane::XY, + 2 => PPlane::XZ, + 3 => PPlane::YZ }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([2])); assert_eq!(f[&1], Nodes::from([5])); assert_eq!(f[&2], Nodes::from([2, 4])); assert_eq!(f[&3], Nodes::from([3])); - assert_eq!(layer, vec![2, 2, 1, 1, 0, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![2, 2, 1, 1, 0, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } #[test_log::test] fn test_find_case5() { let TestCase { g, iset, oset } = test_utils::CASE5.clone(); - let pplanes = map! { - 0: PPlane::XY, - 1: PPlane::XY + let pplanes = hashmap! { + 0 => PPlane::XY, + 1 => PPlane::XY }; assert!(find(g, iset, oset, pplanes).is_none()); } @@ -704,34 +732,34 @@ mod tests { #[test_log::test] fn test_find_case6() { let TestCase { g, iset, oset } = test_utils::CASE6.clone(); - let pplanes = map! { - 0: PPlane::XY, - 1: PPlane::X, - 2: PPlane::XY, - 3: PPlane::X + let pplanes = hashmap! 
{ + 0 => PPlane::XY, + 1 => PPlane::X, + 2 => PPlane::XY, + 3 => PPlane::X }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); assert_eq!(f[&0], Nodes::from([1])); assert_eq!(f[&1], Nodes::from([4])); assert_eq!(f[&2], Nodes::from([3])); assert_eq!(f[&3], Nodes::from([2, 4])); - assert_eq!(layer, vec![1, 1, 0, 1, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![1, 1, 0, 1, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } #[test_log::test] fn test_find_case7() { let TestCase { g, iset, oset } = test_utils::CASE7.clone(); - let pplanes = map! { - 0: PPlane::Z, - 1: PPlane::Z, - 2: PPlane::Y, - 3: PPlane::Y + let pplanes = hashmap! { + 0 => PPlane::Z, + 1 => PPlane::Z, + 2 => PPlane::Y, + 3 => PPlane::Y }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); // Graphix // assert_eq!(f[&0], Nodes::from([0, 1])); @@ -739,27 +767,27 @@ mod tests { assert_eq!(f[&1], Nodes::from([1])); assert_eq!(f[&2], Nodes::from([2])); assert_eq!(f[&3], Nodes::from([4])); - assert_eq!(layer, vec![1, 0, 0, 1, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![1, 0, 0, 1, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } #[test_log::test] fn test_find_case8() { let TestCase { g, iset, oset } = test_utils::CASE8.clone(); - let pplanes = map! { - 0: PPlane::Z, - 1: PPlane::XZ, - 2: PPlane::Y + let pplanes = hashmap! 
{ + 0 => PPlane::Z, + 1 => PPlane::XZ, + 2 => PPlane::Y }; let flen = g.len() - oset.len(); - let (f, layer) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); + let (f, layers) = find(g.clone(), iset.clone(), oset.clone(), pplanes.clone()).unwrap(); assert_eq!(f.len(), flen); // Graphix // assert_eq!(f[&0], Nodes::from([0, 3, 4])); assert_eq!(f[&0], Nodes::from([0, 2, 4])); assert_eq!(f[&1], Nodes::from([1, 2])); assert_eq!(f[&2], Nodes::from([4])); - assert_eq!(layer, vec![1, 1, 1, 0, 0]); - verify((f, layer), g, iset, oset, pplanes).unwrap(); + assert_eq!(layers, vec![1, 1, 1, 0, 0]); + verify((f, layers), g, iset, oset, pplanes).unwrap(); } } diff --git a/tests/assets.py b/tests/assets.py index 416b33d1..66808297 100644 --- a/tests/assets.py +++ b/tests/assets.py @@ -3,9 +3,15 @@ from __future__ import annotations import dataclasses +from typing import TYPE_CHECKING import networkx as nx -from swiflow.common import FlowResult, GFlowResult, Plane, PPlane +from swiflow.common import Plane, PPlane + +if TYPE_CHECKING: + from swiflow.flow import FlowResult + from swiflow.gflow import GFlowResult + from swiflow.pflow import PFlowResult @dataclasses.dataclass(frozen=True) @@ -13,11 +19,11 @@ class FlowTestCase: g: nx.Graph[int] iset: set[int] oset: set[int] - plane: dict[int, Plane] | None - pplane: dict[int, PPlane] | None + planes: dict[int, Plane] | None + pplanes: dict[int, PPlane] | None flow: FlowResult[int] | None gflow: GFlowResult[int] | None - pflow: GFlowResult[int] | None + pflow: PFlowResult[int] | None # MEMO: DO NOT modify while testing @@ -30,9 +36,9 @@ class FlowTestCase: {1, 2}, None, None, - FlowResult({}, {1: 0, 2: 0}), - GFlowResult({}, {1: 0, 2: 0}), - GFlowResult({}, {1: 0, 2: 0}), + ({}, {1: 0, 2: 0}), + ({}, {1: 0, 2: 0}), + ({}, {1: 0, 2: 0}), ) # 1 - 2 - 3 - 4 - 5 @@ -42,9 +48,9 @@ class FlowTestCase: {5}, None, None, - FlowResult({1: 2, 2: 3, 3: 4, 4: 5}, {1: 4, 2: 3, 3: 2, 4: 1, 5: 0}), - GFlowResult({1: {2}, 2: {3}, 3: {4}, 4: {5}}, {1: 4, 2: 3, 3: 2, 4: 1, 5: 0}), - GFlowResult({1: {2}, 2: {3}, 3: {4}, 4: {5}}, {1: 4, 2: 3, 3: 2, 4: 1, 5: 0}), + ({1: 2, 2: 3, 3: 4, 4: 5}, {1: 4, 2: 3, 3: 2, 4: 1, 5: 0}), + ({1: {2}, 2: {3}, 3: {4}, 4: {5}}, {1: 4, 2: 3, 3: 2, 4: 1, 5: 0}), + ({1: {2}, 2: {3}, 3: {4}, 4: {5}}, {1: 4, 2: 3, 3: 2, 4: 1, 5: 0}), ) @@ -52,14 +58,14 @@ class FlowTestCase: # | # 2 - 4 - 6 CASE2 = FlowTestCase( - nx.Graph([(1, 3), (2, 4), (3, 5), (4, 6)]), + nx.Graph([(1, 3), (2, 4), (3, 5), (4, 6), (3, 4)]), {1, 2}, {5, 6}, None, None, - FlowResult({3: 5, 4: 6, 1: 3, 2: 4}, {1: 2, 2: 2, 3: 1, 4: 1, 5: 0, 6: 0}), - GFlowResult({3: {5}, 4: {6}, 1: {3}, 2: {4}}, {1: 2, 2: 2, 3: 1, 4: 1, 5: 0, 6: 0}), - GFlowResult({3: {5}, 4: {6}, 1: {3}, 2: {4}}, {1: 2, 2: 2, 3: 1, 4: 1, 5: 0, 6: 0}), + ({3: 5, 4: 6, 1: 3, 2: 4}, {1: 2, 2: 2, 3: 1, 4: 1, 5: 0, 6: 0}), + ({3: {5}, 4: {6}, 1: {3}, 2: {4}}, {1: 2, 2: 2, 3: 1, 4: 1, 5: 0, 6: 0}), + ({3: {5}, 4: {6}, 1: {3}, 2: {4}}, {1: 2, 2: 2, 3: 1, 4: 1, 5: 0, 6: 0}), ) # ______ @@ -80,8 +86,8 @@ class FlowTestCase: None, None, None, - GFlowResult({1: {5, 6}, 2: {4, 5, 6}, 3: {4, 6}}, {1: 1, 2: 1, 3: 1, 4: 0, 5: 0, 6: 0}), - GFlowResult({1: {5, 6}, 2: {4, 5, 6}, 3: {4, 6}}, {1: 1, 2: 1, 3: 1, 4: 0, 5: 0, 6: 0}), + ({1: {5, 6}, 2: {4, 5, 6}, 3: {4, 6}}, {1: 1, 2: 1, 3: 1, 4: 0, 5: 0, 6: 0}), + ({1: {5, 6}, 2: {4, 5, 6}, 3: {4, 6}}, {1: 1, 2: 1, 3: 1, 4: 0, 5: 0, 6: 0}), ) # 0 - 1 @@ -96,8 +102,8 @@ class FlowTestCase: {0: Plane.XY, 1: Plane.XY, 2: Plane.XZ, 3: Plane.YZ}, {0: PPlane.XY, 1: PPlane.XY, 
2: PPlane.XZ, 3: PPlane.YZ}, None, - GFlowResult({0: {2}, 1: {5}, 2: {2, 4}, 3: {3}}, {0: 2, 1: 2, 2: 1, 3: 1, 4: 0, 5: 0}), - GFlowResult({0: {2}, 1: {5}, 2: {2, 4}, 3: {3}}, {0: 2, 1: 2, 2: 1, 3: 1, 4: 0, 5: 0}), + ({0: {2}, 1: {5}, 2: {2, 4}, 3: {3}}, {0: 2, 1: 2, 2: 1, 3: 1, 4: 0, 5: 0}), + ({0: {2}, 1: {5}, 2: {2, 4}, 3: {3}}, {0: 2, 1: 2, 2: 1, 3: 1, 4: 0, 5: 0}), ) @@ -130,7 +136,7 @@ class FlowTestCase: {0: PPlane.XY, 1: PPlane.X, 2: PPlane.XY, 3: PPlane.X}, None, None, - GFlowResult({0: {1}, 1: {4}, 2: {3}, 3: {2, 4}}, {0: 1, 1: 1, 2: 0, 3: 1, 4: 0}), + ({0: {1}, 1: {4}, 2: {3}, 3: {2, 4}}, {0: 1, 1: 1, 2: 0, 3: 1, 4: 0}), ) # 1 2 3 @@ -144,7 +150,7 @@ class FlowTestCase: {0: PPlane.Z, 1: PPlane.Z, 2: PPlane.Y, 3: PPlane.Y}, None, None, - GFlowResult({0: {0}, 1: {1}, 2: {2}, 3: {4}}, {0: 1, 1: 0, 2: 0, 3: 1, 4: 0}), + ({0: {0}, 1: {1}, 2: {2}, 3: {4}}, {0: 1, 1: 0, 2: 0, 3: 1, 4: 0}), ) # 0 - 1 -- 3 @@ -160,7 +166,47 @@ class FlowTestCase: {0: PPlane.Z, 1: PPlane.XZ, 2: PPlane.Y}, None, None, - GFlowResult({0: {0, 2, 4}, 1: {1, 2}, 2: {4}}, {0: 1, 1: 1, 2: 1, 3: 0, 4: 0}), + ({0: {0, 2, 4}, 1: {1, 2}, 2: {4}}, {0: 1, 1: 1, 2: 1, 3: 0, 4: 0}), +) + +# 1 - 2 +# +# 3 - 4 +CASE9 = FlowTestCase( + nx.Graph([(1, 2), (3, 4)]), + {1}, + {2}, + None, + None, + # None exists as 3 - 4 is isolated + None, + None, + None, ) -CASES: tuple[FlowTestCase, ...] = (CASE0, CASE1, CASE2, CASE3, CASE4, CASE5, CASE6, CASE7, CASE8) +# 1 - 2 - 3 +CASE10 = FlowTestCase( + nx.Graph([(1, 2), (2, 3)]), + {1}, + set(), + None, + None, + # None exists as oset is empty + None, + None, + None, +) + +CASES: tuple[FlowTestCase, ...] = ( + CASE0, + CASE1, + CASE2, + CASE3, + CASE4, + CASE5, + CASE6, + CASE7, + CASE8, + CASE9, + CASE10, +) diff --git a/tests/test_common.py b/tests/test_common.py index 96475a38..2b7023e7 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -2,12 +2,14 @@ import networkx as nx import pytest -from swiflow import _common +from swiflow import _common, common from swiflow._common import IndexMap from swiflow._impl import FlowValidationMessage from swiflow.common import Plane, PPlane from typing_extensions import Never +from tests.assets import CASE3 + def test_check_graph_ng_g() -> None: with pytest.raises(TypeError): @@ -98,7 +100,7 @@ def test_decode_err(self, fx_indexmap: IndexMap[str], emsg: ValueError) -> None: def test_encode_layer_missing(self, fx_indexmap: IndexMap[str]) -> None: with pytest.raises(ValueError, match=r"Layers must be specified for all nodes\."): - fx_indexmap.encode_layer({"a": 0, "b": 1}) + fx_indexmap.encode_layers({"a": 0, "b": 1}) def test_ecatch(self, fx_indexmap: IndexMap[str]) -> None: def dummy_ok(x: int) -> int: @@ -110,3 +112,35 @@ def dummy_ng(_: int) -> Never: assert fx_indexmap.ecatch(dummy_ok, 1) == 1 with pytest.raises(ValueError, match=r"Zero-layer node a outside output nodes\."): fx_indexmap.ecatch(dummy_ng, 1) + + +def test_odd_neighbors() -> None: + g = CASE3.g + for u in g.nodes: + assert _common.odd_neighbors(g, {u}) == set(g.neighbors(u)) + assert _common.odd_neighbors(g, {1, 4}) == {1, 2, 4, 6} + assert _common.odd_neighbors(g, {2, 5}) == {2, 3, 4, 5, 6} + assert _common.odd_neighbors(g, {3, 6}) == {1, 2, 3, 5, 6} + assert _common.odd_neighbors(g, {1, 2, 3}) == {6} + assert _common.odd_neighbors(g, {4, 5, 6}) == {2} + assert _common.odd_neighbors(g, {1, 2, 3, 4, 5, 6}) == {2, 6} + + +class TestInferLayer: + def test_line(self) -> None: + g: nx.Graph[int] = nx.Graph([(0, 1), (1, 2), (2, 3)]) + flow = {0: {1}, 1: {2}, 2: {3}} + layers 
= common.infer_layers(g, flow) + assert layers == {0: 3, 1: 2, 2: 1, 3: 0} + + def test_dag(self) -> None: + g: nx.Graph[int] = nx.Graph([(0, 2), (0, 3), (1, 2), (1, 3)]) + flow = {0: {2, 3}, 1: {2, 3}} + layers = common.infer_layers(g, flow) + assert layers == {0: 1, 1: 1, 2: 0, 3: 0} + + def test_cycle(self) -> None: + g: nx.Graph[int] = nx.Graph([(0, 1), (1, 2), (2, 0)]) + flow = {0: {1}, 1: {2}, 2: {0}} + with pytest.raises(ValueError, match=r".*constraints.*"): + common.infer_layers(g, flow) diff --git a/tests/test_flow.py b/tests/test_flow.py index d536c93b..3c6cbb74 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -1,14 +1,24 @@ from __future__ import annotations import pytest -from swiflow import flow +from swiflow import common, flow from tests.assets import CASES, FlowTestCase @pytest.mark.parametrize("c", CASES) -def test_flow_graphix(c: FlowTestCase) -> None: +def test_flow(c: FlowTestCase) -> None: result = flow.find(c.g, c.iset, c.oset) assert result == c.flow if result is not None: flow.verify(result, c.g, c.iset, c.oset) + + +@pytest.mark.parametrize("c", CASES) +def test_infer_verify(c: FlowTestCase) -> None: + if c.flow is None: + pytest.skip() + f, _ = c.flow + flow.verify(f, c.g, c.iset, c.oset) + layers = common.infer_layers(c.g, f) + flow.verify((f, layers), c.g, c.iset, c.oset) diff --git a/tests/test_gflow.py b/tests/test_gflow.py index 56c9194b..5da414e4 100644 --- a/tests/test_gflow.py +++ b/tests/test_gflow.py @@ -2,18 +2,18 @@ import networkx as nx import pytest -from swiflow import gflow +from swiflow import common, gflow from swiflow.common import Plane from tests.assets import CASES, FlowTestCase @pytest.mark.parametrize("c", CASES) -def test_gflow_graphix(c: FlowTestCase) -> None: - result = gflow.find(c.g, c.iset, c.oset, c.plane) +def test_gflow(c: FlowTestCase) -> None: + result = gflow.find(c.g, c.iset, c.oset, planes=c.planes) assert result == c.gflow if result is not None: - gflow.verify(result, c.g, c.iset, c.oset, c.plane) + gflow.verify(result, c.g, c.iset, c.oset, planes=c.planes) def test_gflow_redundant() -> None: @@ -22,4 +22,14 @@ def test_gflow_redundant() -> None: oset = {1} planes = {0: Plane.XY, 1: Plane.XY} with pytest.raises(ValueError, match=r".*Excessive measurement planes specified.*"): - gflow.find(g, iset, oset, planes) + gflow.find(g, iset, oset, planes=planes) + + +@pytest.mark.parametrize("c", CASES) +def test_infer_verify(c: FlowTestCase) -> None: + if c.gflow is None: + pytest.skip() + f, _ = c.gflow + gflow.verify(f, c.g, c.iset, c.oset, planes=c.planes) + layers = common.infer_layers(c.g, f) + gflow.verify((f, layers), c.g, c.iset, c.oset, planes=c.planes) diff --git a/tests/test_pflow.py b/tests/test_pflow.py index 92ab352f..53a9e2ef 100644 --- a/tests/test_pflow.py +++ b/tests/test_pflow.py @@ -2,7 +2,7 @@ import networkx as nx import pytest -from swiflow import pflow +from swiflow import common, pflow from swiflow.common import PPlane from tests.assets import CASES, FlowTestCase @@ -10,11 +10,11 @@ @pytest.mark.filterwarnings("ignore:No Pauli measurement found") @pytest.mark.parametrize("c", CASES) -def test_pflow_graphix(c: FlowTestCase) -> None: - result = pflow.find(c.g, c.iset, c.oset, c.pplane) +def test_pflow(c: FlowTestCase) -> None: + result = pflow.find(c.g, c.iset, c.oset, pplanes=c.pplanes) assert result == c.pflow if result is not None: - pflow.verify(result, c.g, c.iset, c.oset, c.pplane) + pflow.verify(result, c.g, c.iset, c.oset, pplanes=c.pplanes) def test_pflow_nopauli() -> None: @@ -23,7 +23,7 
@@ def test_pflow_nopauli() -> None: oset = {1} planes = {0: PPlane.XY} with pytest.warns(UserWarning, match=r".*No Pauli measurement found\. Use gflow\.find instead\..*"): - pflow.find(g, iset, oset, planes) + pflow.find(g, iset, oset, pplanes=planes) def test_pflow_redundant() -> None: @@ -32,4 +32,14 @@ def test_pflow_redundant() -> None: oset = {1} planes = {0: PPlane.X, 1: PPlane.Y} with pytest.raises(ValueError, match=r".*Excessive measurement planes specified.*"): - pflow.find(g, iset, oset, planes) + pflow.find(g, iset, oset, pplanes=planes) + + +@pytest.mark.parametrize("c", CASES) +def test_infer_verify(c: FlowTestCase) -> None: + if c.pflow is None: + pytest.skip() + f, _ = c.pflow + pflow.verify(f, c.g, c.iset, c.oset, pplanes=c.pplanes) + layers = common.infer_layers(c.g, f, pplanes=c.pplanes) + pflow.verify((f, layers), c.g, c.iset, c.oset, pplanes=c.pplanes)
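
A minimal end-to-end sketch of the updated Python API, mirroring the new `test_infer_verify` tests above. It only uses calls that appear in this diff (`flow.find`, `flow.verify` with either the mapping or the `(f, layers)` pair, and `common.infer_layers`); the example data is CASE2 from `tests/assets.py`. The exact result shape (a plain `(f, layers)` pair that unpacks like the tuples now stored in the test assets) and the asserted values are assumptions taken from those assets, not a definitive specification.

    from __future__ import annotations

    import networkx as nx

    from swiflow import common, flow

    # CASE2 from tests/assets.py: 1 - 3 - 5 over 2 - 4 - 6, bridged by the newly added (3, 4) edge.
    g = nx.Graph([(1, 3), (2, 4), (3, 5), (4, 6), (3, 4)])
    iset = {1, 2}
    oset = {5, 6}

    result = flow.find(g, iset, oset)
    assert result is not None
    # Assumed to unpack like the plain tuples stored in tests/assets.py.
    f, layers = result
    assert f == {1: 3, 2: 4, 3: 5, 4: 6}
    assert layers == {1: 2, 2: 2, 3: 1, 4: 1, 5: 0, 6: 0}

    # verify accepts either the mapping alone or the (f, layers) pair,
    # and layers can be re-derived from the mapping via common.infer_layers.
    flow.verify(f, g, iset, oset)
    flow.verify((f, layers), g, iset, oset)
    flow.verify((f, common.infer_layers(g, f)), g, iset, oset)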