Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pywhy_graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._version import __version__ # noqa: F401
from .classes import (
ADMG,
CG,
CPDAG,
PAG,
AugmentedGraph,
Expand Down
1 change: 1 addition & 0 deletions pywhy_graphs/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .cyclic import * # noqa: F403
from .generic import * # noqa: F403
from .pag import * # noqa: F403
from .cg import * # noqa: F403
77 changes: 77 additions & 0 deletions pywhy_graphs/algorithms/cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from collections import deque

import networkx as nx
import numpy as np

from pywhy_graphs import CG

__all__ = ["is_valid_cg"]


def is_valid_cg(graph: CG):
"""
Checks if a supplied chain graph is valid.

This implements the original defintion of a (Lauritzen Wermuth Frydenberg) chain graph as
presented in [1]_.

Define a cycle as a series of nodes X_1 -o X_2 ... X_n -o X_1 where the edges may be directed or
undirected. Note that directed edges in a cycle must all be aligned in the same direction. A
chain graph may only contain cycles consisting of only undirected edges. Equivalently, a chain
graph does not contain any cycles with one or more directed edges.

Parameters
__________
graph : CG
The graph.

Returns
_______
is_valid : bool
Whether supplied `graph` is a valid chain graph.

References
----------
.. [1] Frydenberg, Morten. “The Chain Graph Markov Property.” Scandinavian Journal of
Statistics, vol. 17, no. 4, 1990, pp. 333–53. JSTOR, http://www.jstor.org/stable/4616181.
Accessed 15 Apr. 2023.


"""

# Check if directed edges are acyclic
undirected_edge_name = graph.undirected_edge_name
directed_edge_name = graph.directed_edge_name
visited = set()
all_nodes = graph.nodes()
G_undirected = graph.get_graphs(edge_type=undirected_edge_name)
G_directed = graph.get_graphs(edge_type=directed_edge_name)
# TODO: keep track of paths as first class in queue
for v in all_nodes:
print("v:", v)
seen = {v}
queue = deque([z for _, z in G_directed.out_edges(nbunch=v)])
if v in visited:

continue
while queue:
print(queue)
x = queue.popleft()
print("pop", x)
print("seen", seen)
if x in seen:
print("appeared in seen", x)
return False

seen.add(x)

for _, node in G_directed.out_edges(nbunch=x):
print("add out edge", node)
queue.append(node)
for nbr in G_undirected.neighbors(x):
print("add nbr edge", nbr)
queue.append(nbr)

visited.add(v)

return True
108 changes: 108 additions & 0 deletions pywhy_graphs/algorithms/tests/test_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pywhy_graphs import CG
from pywhy_graphs.algorithms import is_valid_cg
import pytest


@pytest.fixture
def cg_simple_partially_directed_cycle():
graph = CG()
graph.add_nodes_from(["A", "B", "C", "D"])
graph.add_edge("A", "B", graph.directed_edge_name)
graph.add_edge("D", "C", graph.directed_edge_name)
graph.add_edge("B", "D", graph.undirected_edge_name)
graph.add_edge("A", "C", graph.undirected_edge_name)

return graph


@pytest.fixture
def cg_multiple_blocks_partially_directed_cycle():

graph = CG()
graph.add_nodes_from(["A", "B", "C", "D", "E", "F", "G"])
graph.add_edge("A", "B", graph.directed_edge_name)
graph.add_edge("D", "C", graph.directed_edge_name)
graph.add_edge("B", "D", graph.undirected_edge_name)
graph.add_edge("A", "C", graph.undirected_edge_name)
graph.add_edge("E", "F", graph.undirected_edge_name)
graph.add_edge("F", "G", graph.undirected_edge_name)
graph.add_edge("G", "E", graph.undirected_edge_name)

return graph


@pytest.fixture
def square_graph():
graph = CG()
graph.add_nodes_from(["A", "B", "C", "D"])
graph.add_edge("A", "B", graph.undirected_edge_name)
graph.add_edge("B", "C", graph.undirected_edge_name)
graph.add_edge("C", "D", graph.undirected_edge_name)
graph.add_edge("C", "A", graph.undirected_edge_name)

return graph


@pytest.fixture
def fig_g1_frydenberg():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps add a docstring and a link to the reference for all the fixtures that come from some paper, so its back traceable?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes that's a good idea, i have the reference in the function but not the test

graph = CG()
graph.add_nodes_from(["a", "b", "g", "m", "d"])
graph.add_edge("a", "b", graph.undirected_edge_name)
graph.add_edge("b", "g", graph.directed_edge_name)
graph.add_edge("g", "d", graph.undirected_edge_name)
graph.add_edge("d", "m", graph.undirected_edge_name)
graph.add_edge("a", "m", graph.directed_edge_name)

return graph


@pytest.fixture
def fig_g2_frydenberg():
graph = CG()
graph.add_nodes_from(["b", "g", "d", "m", "a"])
graph.add_edge("a", "m", graph.directed_edge_name)
graph.add_edge("m", "g", graph.undirected_edge_name)
graph.add_edge("m", "d", graph.directed_edge_name)
graph.add_edge("g", "d", graph.directed_edge_name)
graph.add_edge("b", "g", graph.directed_edge_name)

return graph


@pytest.fixture
def fig_g3_frydenberg():
graph = CG()
graph.add_nodes_from(["a", "b", "g"])
graph.add_edge("b", "a", graph.undirected_edge_name)
graph.add_edge("a", "g", graph.undirected_edge_name)
graph.add_edge("b", "g", graph.directed_edge_name)

return graph


@pytest.mark.parametrize(
"G",
[
"cg_simple_partially_directed_cycle",
"cg_multiple_blocks_partially_directed_cycle",
"fig_g3_frydenberg",
],
)
def test_graphs_are_not_valid_cg(G, request):
graph = request.getfixturevalue(G)

assert not is_valid_cg(graph)


@pytest.mark.parametrize(
"G",
[
"square_graph",
"fig_g1_frydenberg",
"fig_g2_frydenberg",
],
)
def test_graphs_are_valid_cg(G, request):
graph = request.getfixturevalue(G)

assert is_valid_cg(graph)
2 changes: 1 addition & 1 deletion pywhy_graphs/classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import timeseries
from .admg import ADMG
from .cpdag import CPDAG
from .diungraph import CG, CPDAG
from .intervention import IPAG, AugmentedGraph, PsiPAG
from .pag import PAG
from .timeseries import (
Expand Down
Loading