ai-content-maker/.venv/Lib/site-packages/networkx/algorithms/d_separation.py

143 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Algorithm for testing d-separation in DAGs.

*d-separation* is a test for conditional independence in probability
distributions that can be factorized using DAGs. It is a purely
graphical test that uses the underlying graph and makes no reference
to the actual distribution parameters. See [1]_ for a formal
definition.

The implementation is based on the conceptually simple linear time
algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of
alternative algorithms.

Examples
--------
>>>
>>> # HMM graph with five states and observation nodes
... g = nx.DiGraph()
>>> g.add_edges_from(
... [
... ("S1", "S2"),
... ("S2", "S3"),
... ("S3", "S4"),
... ("S4", "S5"),
... ("S1", "O1"),
... ("S2", "O2"),
... ("S3", "O3"),
... ("S4", "O4"),
... ("S5", "O5"),
... ]
... )
>>>
>>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
... nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
True

References
----------

.. [1] Pearl, J. (2009). Causality. Cambridge: Cambridge University Press.

.. [2] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks.
   Cambridge: Cambridge University Press.

.. [3] Shachter, R. D. (1998).
   Bayes-ball: rational pastime (for determining irrelevance and requisite
   information in belief networks and influence diagrams).
   In Proceedings of the Fourteenth Conference on Uncertainty in Artificial
   Intelligence (pp. 480-487).
   San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.

.. [4] Koller, D., & Friedman, N. (2009).
   Probabilistic graphical models: principles and techniques. The MIT Press.
"""
from collections import deque
import networkx as nx
from networkx.utils import UnionFind, not_implemented_for
__all__ = ["d_separated"]
@not_implemented_for("undirected")
def d_separated(G, x, y, z):
    """
    Return whether node sets ``x`` and ``y`` are d-separated by ``z``.

    Parameters
    ----------
    G : graph
        A NetworkX DAG (``DiGraph`` or ``MultiDiGraph``).

    x : set
        First set of nodes in ``G``. Any iterable of nodes is accepted.

    y : set
        Second set of nodes in ``G``. Any iterable of nodes is accepted.

    z : set
        Set of conditioning nodes in ``G``. Can be empty set. Any iterable
        of nodes is accepted.

    Returns
    -------
    b : bool
        A boolean that is true if ``x`` is d-separated from ``y`` given ``z`` in ``G``.
        If ``x`` and ``y`` share a node the result is ``False``; if either
        ``x`` or ``y`` is empty the result is ``True``.

    Raises
    ------
    NetworkXError
        The *d-separation* test is commonly used with directed
        graphical models which are acyclic. Accordingly, the algorithm
        raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.

    Notes
    -----
    The graph is pruned in three steps (Darwiche, 2009):

    1. Repeatedly delete leaves (nodes without outgoing edges) that are
       not in ``x | y | z``.
    2. Delete every edge leaving a node in ``z``.
    3. ``x`` and ``y`` are d-separated iff no weakly connected component
       of the pruned graph contains nodes of both sets.

    The input graph is not modified; the pruning works on a copy.
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    # Normalise to sets so lists/tuples/other iterables of nodes also work.
    x, y, z = set(x), set(y), set(z)
    union_xyz = x | y | z

    if any(n not in G.nodes for n in union_xyz):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    G_copy = G.copy()

    # Step 1: remove leaves that are not in x | y | z until none remain.
    # A parent becomes a new leaf once its out-degree drops to zero. The
    # degree is checked *after* removing the leaf (instead of "== 1" before
    # it) so that parallel edges in a MultiDiGraph are handled correctly.
    leaves = deque(n for n in G_copy.nodes if G_copy.out_degree[n] == 0)
    while leaves:
        leaf = leaves.popleft()
        if leaf in union_xyz:
            continue
        # predecessors() yields each parent once, even with parallel edges.
        parents = list(G_copy.predecessors(leaf))
        G_copy.remove_node(leaf)
        leaves.extend(p for p in parents if G_copy.out_degree[p] == 0)

    # Step 2: remove outgoing edges of the conditioning set. The list() is
    # needed because the edge view must not be iterated while mutating.
    G_copy.remove_edges_from(list(G_copy.out_edges(z)))

    # Step 3: any weakly connected component touching both x and y means
    # an active path exists, i.e. x and y are NOT d-separated.
    return not any(
        not component.isdisjoint(x) and not component.isdisjoint(y)
        for component in nx.weakly_connected_components(G_copy)
    )