ai-content-maker/.venv/Lib/site-packages/networkx/algorithms/d_separation.py

143 lines
4.1 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
"""
Algorithm for testing d-separation in DAGs.
*d-separation* is a test for conditional independence in probability
distributions that can be factorized using DAGs. It is a purely
graphical test that uses the underlying graph and makes no reference
to the actual distribution parameters. See [1]_ for a formal
definition.
The implementation is based on the conceptually simple linear time
algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of
alternative algorithms.
Examples
--------
>>>
>>> # HMM graph with five states and observation nodes
... g = nx.DiGraph()
>>> g.add_edges_from(
... [
... ("S1", "S2"),
... ("S2", "S3"),
... ("S3", "S4"),
... ("S4", "S5"),
... ("S1", "O1"),
... ("S2", "O2"),
... ("S3", "O3"),
... ("S4", "O4"),
... ("S5", "O5"),
... ]
... )
>>>
>>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
... nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
True
References
----------
.. [1] Pearl, J. (2009). Causality. Cambridge: Cambridge University Press.
.. [2] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks.
Cambridge: Cambridge University Press.
.. [3] Shachter, R. D. (1998).
Bayes-ball: rational pastime (for determining irrelevance and requisite
information in belief networks and influence diagrams).
In Proceedings of the Fourteenth Conference on Uncertainty in Artificial
Intelligence (pp. 480-487).
San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.
.. [4] Koller, D., & Friedman, N. (2009).
Probabilistic graphical models: principles and techniques. The MIT Press.
"""
from collections import deque
import networkx as nx
from networkx.utils import UnionFind, not_implemented_for
__all__ = ["d_separated"]
@not_implemented_for("undirected")
def d_separated(G, x, y, z):
    """
    Return whether node sets ``x`` and ``y`` are d-separated by ``z``.

    The test works on a copy of ``G``: it repeatedly removes leaves (nodes
    without outgoing edges) that are not in ``x | y | z``, then removes every
    edge leaving a node of ``z``, and finally checks whether any node of
    ``x`` ends up in the same weakly connected component as a node of ``y``.
    ``G`` itself is never modified.

    Parameters
    ----------
    G : graph
        A NetworkX DAG.

    x : set
        First set of nodes in ``G``.

    y : set
        Second set of nodes in ``G``.

    z : set
        Set of conditioning nodes in ``G``. Can be empty set.

    Returns
    -------
    b : bool
        A boolean that is true if ``x`` is d-separated from ``y`` given ``z``
        in ``G``. If ``x`` or ``y`` is empty the result is ``True``.

    Raises
    ------
    NetworkXError
        The *d-separation* test is commonly used with directed
        graphical models which are acyclic. Accordingly, the algorithm
        raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    union_xyz = x.union(y).union(z)

    if any(n not in G.nodes for n in union_xyz):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    # Nothing can connect to an empty set.  Answering here (after validation)
    # also keeps ``UnionFind.union`` from being called with no arguments.
    if not x or not y:
        return True

    G_copy = G.copy()

    # transform the graph by removing leaves that are not in x | y | z
    # until no more leaves can be removed.
    leaves = deque(n for n in G_copy.nodes if G_copy.out_degree[n] == 0)
    while leaves:
        leaf = leaves.popleft()
        if leaf in union_xyz:
            continue
        parents = list(G_copy.predecessors(leaf))
        G_copy.remove_node(leaf)
        # Check the out-degree *after* the removal: in a multigraph a parent
        # may reach ``leaf`` through several parallel edges, so testing for
        # ``out_degree == 1`` beforehand would miss parents that have just
        # become leaves.
        leaves.extend(p for p in parents if G_copy.out_degree[p] == 0)

    # transform the graph by removing outgoing edges from the
    # conditioning set.
    edges_to_remove = list(G_copy.out_edges(z))
    G_copy.remove_edges_from(edges_to_remove)

    # use disjoint-set data structure to check if any node in `x`
    # occurs in the same weakly connected component as a node in `y`.
    disjoint_set = UnionFind(G_copy.nodes())
    for component in nx.weakly_connected_components(G_copy):
        disjoint_set.union(*component)
    # Merging all of `x` (resp. `y`) into one class means a single
    # representative comparison covers every x/y pair.
    disjoint_set.union(*x)
    disjoint_set.union(*y)

    return disjoint_set[next(iter(x))] != disjoint_set[next(iter(y))]