ai-content-maker/.venv/Lib/site-packages/pynndescent/optimal_transport.py

1195 lines
33 KiB
Python

#############################################################################
# This code draws from the Python Optimal Transport version of the
# network simplex algorithm, which in turn was adapted from the LEMON
# library. The copyrights/comment blocks for those are preserved below.
# The Python/Numba implementation was adapted by Leland McInnes (2020).
#
# * This file has been adapted by Nicolas Bonneel (2013),
# * from network_simplex.h from LEMON, a generic C++ optimization library,
# * to implement a lightweight network simplex for mass transport, more
# * memory efficient that the original file. A previous version of this file
# * is used as part of the Displacement Interpolation project,
# * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
# *
# *
# **** Original file Copyright Notice :
# *
# * Copyright (C) 2003-2010
# * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
# * (Egervary Research Group on Combinatorial Optimization, EGRES).
# *
# * Permission to use, modify and distribute this software is granted
# * provided that this copyright notice appears in all copies. For
# * precise terms see the accompanying LICENSE file.
# *
# * This software is provided "AS IS" with no warranty of any kind,
# * express or implied, and with no claim as to its suitability for any
# * purpose.
import numpy as np
import numba
from collections import namedtuple
from enum import Enum, IntEnum
# Small placeholder arrays; ``_dummy_cost`` is the default ``cost`` argument of
# the Sinkhorn functions in this module.
_mock_identity = np.eye(2, dtype=np.float32)
_mock_ones = np.ones(2, dtype=np.float32)
_dummy_cost = np.zeros((2, 2), dtype=np.float64)

# Accuracy tolerance, and the largest |sum(supply)| accepted as "balanced".
EPSILON = 2.2204460492503131e-15
NET_SUPPLY_ERROR_TOLERANCE = 1e-8

# As in POT everything is a double; "infinity" is the largest finite float64.
INFINITY = MAX = np.finfo(np.float64).max

# Public 2x2 zero cost matrix (a distinct array from ``_dummy_cost``).
dummy_cost = np.zeros_like(_dummy_cost)

# Sentinel meaning "no arc".
INVALID = -1
class ProblemStatus(Enum):
    """Termination status returned by ``network_simplex_core``."""

    OPTIMAL = 0  # the pivot loop ran out of entering arcs
    MAX_ITER_REACHED = 1  # the iteration budget was exhausted first
    UNBOUNDED = 2  # a pivot found no bounding leaving arc (delta >= MAX)
    INFEASIBLE = 3  # an artificial arc still carries flow above EPSILON
class ArcState(IntEnum):
    """State of an arc with respect to the current spanning tree.

    The integer values double as sign multipliers: reduced costs and
    augmentation amounts are multiplied by an arc's state, and a non-tree
    arc's state is negated when it switches bound.
    """

    STATE_UPPER = -1
    STATE_TREE = 0
    STATE_LOWER = 1
# Spanning-tree bookkeeping for the network simplex:
#   parent, pred, thread, rev_thread, succ_num, last_succ -- int arrays
#   forward -- bool array; state -- per-arc ArcState values; root -- int
SpanningTree = namedtuple(
    "SpanningTree",
    "parent pred thread rev_thread succ_num last_succ forward state root",
)
# Sizes of the complete bipartite graph (n_nodes == n + m, n_arcs == n * m)
# plus the parameters of the optional interleaved ("mixed") arc order.
DiGraph = namedtuple(
    "DiGraph",
    (
        "n_nodes n_arcs n m use_arc_mixing "
        "num_total_big_subsequence_numbers subsequence_length "
        "num_big_subsequences mixing_coeff"
    ),
)
# Problem data: cost/flow are per-arc float arrays, supply/pi per-node float
# arrays, and source/target per-arc node-index arrays.
NodeArcData = namedtuple("NodeArcData", "cost supply flow pi source target")
# Result of find_leaving_arc: reordered entering-arc endpoints (u_in, v_in),
# the node whose pred arc leaves (u_out), the augmentation amount (delta) and
# whether the tree changes (change).
LeavingArcData = namedtuple("LeavingArcData", "u_in u_out v_in delta change")
# Simplified stand-in for np.isclose, kept because numba did not support it
# when this was written.  Note the default absolute tolerance is EPSILON.
@numba.njit()
def isclose(a, b, rtol=1.0e-5, atol=EPSILON):
    """Return ``|a - b| <= atol + rtol * |b|`` (asymmetric in ``b``, like numpy)."""
    tolerance = atol + rtol * np.abs(b)
    return np.abs(a - b) <= tolerance
@numba.njit()
def _entering_arc_threshold(in_arc, cost, pi, source, target):
    """Return the acceptance threshold ``-EPSILON * a`` for ``in_arc``.

    ``a`` is the largest magnitude among ``cost[in_arc]`` and the potentials of
    the arc's two endpoints, so the test is relative to the problem's scale.
    """
    # NOTE: ``a`` must stay a float.  It used to be forced to a uint32 local,
    # which truncated these magnitudes (to 0 for anything below 1) and
    # effectively disabled the relative tolerance.
    a = max(np.fabs(pi[source[in_arc]]), np.fabs(pi[target[in_arc]]))
    a = max(a, np.fabs(cost[in_arc]))
    return -(EPSILON * a)


# Block-search pivot rule (LEMON/POT): scan the arcs cyclically in blocks of
# ``pivot_block_size`` and stop at the end of the first block in which the
# best reduced cost seen so far is clearly negative.
@numba.njit(locals={"e": numba.uint32})
def find_entering_arc(
    pivot_block_size,
    pivot_next_arc,
    search_arc_num,
    state_vector,
    node_arc_data,
    in_arc,
):
    """Select the next arc to enter the spanning tree.

    The reduced cost of arc ``e`` is
    ``state_vector[e] * (cost[e] + pi[source[e]] - pi[target[e]])``.

    Parameters
    ----------
    pivot_block_size : int
        Number of arcs examined between acceptance tests.
    pivot_next_arc : int
        Arc index at which the cyclic scan starts.
    search_arc_num : int
        Arcs ``0 .. search_arc_num - 1`` are candidates.
    state_vector : array
        Per-arc ``ArcState`` values.
    node_arc_data : NodeArcData
        Provides ``cost``, ``pi``, ``source`` and ``target``.
    in_arc : int
        Initial value only; any returned arc is one found during this scan.

    Returns
    -------
    (int, int)
        ``(in_arc, pivot_next_arc)``: the chosen arc and the index where the
        scan stopped (the next call resumes there), or ``(-1, 0)`` when no arc
        has a reduced cost below its threshold, i.e. the flow is optimal.
    """
    min_reduced_cost = 0.0
    cnt = pivot_block_size

    # Pull from tuple for quick reference
    cost = node_arc_data.cost
    pi = node_arc_data.pi
    source = node_arc_data.source
    target = node_arc_data.target

    # First pass: from pivot_next_arc to the end of the arc range.
    for e in range(pivot_next_arc, search_arc_num):
        c = state_vector[e] * (cost[e] + pi[source[e]] - pi[target[e]])
        if c < min_reduced_cost:
            min_reduced_cost = c
            in_arc = e
        cnt -= 1
        if cnt == 0:
            # A negative minimum implies in_arc was set during this scan, so
            # the threshold is only evaluated on a valid arc index.
            if min_reduced_cost < 0 and min_reduced_cost < _entering_arc_threshold(
                in_arc, cost, pi, source, target
            ):
                pivot_next_arc = e
                return in_arc, pivot_next_arc
            cnt = pivot_block_size

    # Second pass: wrap around to the arcs before pivot_next_arc.
    for e in range(pivot_next_arc):
        c = state_vector[e] * (cost[e] + pi[source[e]] - pi[target[e]])
        if c < min_reduced_cost:
            min_reduced_cost = c
            in_arc = e
        cnt -= 1
        if cnt == 0:
            if min_reduced_cost < 0 and min_reduced_cost < _entering_arc_threshold(
                in_arc, cost, pi, source, target
            ):
                pivot_next_arc = e
                return in_arc, pivot_next_arc
            cnt = pivot_block_size

    # Full sweep done: test the best candidate found in the trailing partial block.
    if min_reduced_cost >= 0 or min_reduced_cost >= _entering_arc_threshold(
        in_arc, cost, pi, source, target
    ):
        return -1, 0
    return in_arc, pivot_next_arc
# Find the join node: the common ancestor where the tree paths from the two
# endpoints of the entering arc meet.
@numba.njit(locals={"u": numba.types.uint16, "v": numba.types.uint16})
def find_join_node(source, target, succ_num, parent, in_arc):
    """Return the lowest common ancestor of ``source[in_arc]`` and ``target[in_arc]``.

    ``succ_num`` holds subtree sizes.  A node with no more successors than the
    other cannot be its proper ancestor, so it is the one moved up a level;
    ties move the target-side node.
    """
    u = source[in_arc]
    v = target[in_arc]
    while u != v:
        if succ_num[u] >= succ_num[v]:
            v = parent[v]
        else:
            u = parent[u]
    return u
# Find the leaving arc of the cycle closed by the entering arc (port of
# LEMON/POT ``findLeavingArc``).  Instead of mutating solver state, the results
# are returned in a LeavingArcData tuple.
@numba.njit(
    locals={
        "u": numba.uint16,
        "u_in": numba.uint16,
        "u_out": numba.uint16,
        "v_in": numba.uint16,
        "first": numba.uint16,
        "second": numba.uint16,
        "result": numba.uint8,
        "in_arc": numba.uint32,
    }
)
def find_leaving_arc(join, in_arc, node_arc_data, spanning_tree):
    """Find the tree arc that leaves when ``in_arc`` enters the spanning tree.

    Walks both tree paths from the endpoints of ``in_arc`` up to ``join`` and
    selects the path arc whose flow limits the augmentation amount.

    Parameters
    ----------
    join : int
        Common ancestor of the endpoints of ``in_arc`` (from find_join_node).
    in_arc : int
        Index of the entering arc.
    node_arc_data : NodeArcData
        Provides ``source``, ``target`` and ``flow``.
    spanning_tree : SpanningTree
        Provides ``state``, ``forward``, ``pred`` and ``parent``.

    Returns
    -------
    LeavingArcData
        ``u_in``/``v_in``: the endpoints of ``in_arc``, ordered so that
        ``u_in`` is on the path that contains ``u_out`` when ``change`` is
        True. ``u_out``: the node whose ``pred`` arc leaves the tree.
        ``delta``: the augmentation amount (``INFINITY`` if nothing bounds it).
        ``change``: True when some path arc was selected. When it is False,
        update_flow only flips the state of the entering arc.
    """
    source = node_arc_data.source
    target = node_arc_data.target
    flow = node_arc_data.flow
    state = spanning_tree.state
    forward = spanning_tree.forward
    pred = spanning_tree.pred
    parent = spanning_tree.parent
    # u_out is a uint16 local, so -1 is stored as 65535.  Its value only
    # matters when a leaving arc is found (change is True).
    u_out = -1  # May not be set, but we need to return something?
    # Initialize first and second nodes according to the direction
    # of the cycle
    if state[in_arc] == ArcState.STATE_LOWER:
        first = source[in_arc]
        second = target[in_arc]
    else:
        first = target[in_arc]
        second = source[in_arc]
    delta = INFINITY
    result = 0
    # Search the cycle along the path from the first node to the join node.
    # Here a tree arc bounds delta by its flow only when forward[u] is set.
    u = first
    while u != join:
        e = pred[u]
        if forward[u]:
            d = flow[e]
        else:
            d = INFINITY
        if d < delta:
            delta = d
            u_out = u
            result = 1
        u = parent[u]
    # Search the cycle along the path from the second node to the join node.
    # The bounding direction is mirrored here.  '<=' (vs '<' above) makes ties
    # go to the second path, as in LEMON.
    u = second
    while u != join:
        e = pred[u]
        if forward[u]:
            d = INFINITY
        else:
            d = flow[e]
        if d <= delta:
            delta = d
            u_out = u
            result = 2
        u = parent[u]
    # Order the entering arc's endpoints so u_in is on u_out's side.
    if result == 1:
        u_in = first
        v_in = second
    else:
        u_in = second
        v_in = first
    return LeavingArcData(u_in, u_out, v_in, delta, result != 0)
# Push ``delta`` units of flow around the pivot cycle and update the states of
# the entering and leaving arcs.  Mutates node_arc_data.flow and
# spanning_tree.state in place.
@numba.njit(locals={"u": numba.uint16, "in_arc": numba.uint32, "val": numba.float64})
def update_flow(join, leaving_arc_data, node_arc_data, spanning_tree, in_arc):
    """Augment along the cycle through ``in_arc`` and fix the arc states."""
    source = node_arc_data.source
    target = node_arc_data.target
    flow = node_arc_data.flow
    state = spanning_tree.state
    pred = spanning_tree.pred
    parent = spanning_tree.parent
    forward = spanning_tree.forward

    delta = leaving_arc_data.delta
    if delta > 0:
        # Signed amount pushed through the entering arc.
        val = state[in_arc] * delta
        flow[in_arc] += val
        # Tree arcs from source[in_arc] up to join: -val if forward, else +val.
        u = source[in_arc]
        while u != join:
            flow[pred[u]] += -val if forward[u] else val
            u = parent[u]
        # Tree arcs from target[in_arc] up to join: the signs are mirrored.
        u = target[in_arc]
        while u != join:
            flow[pred[u]] += val if forward[u] else -val
            u = parent[u]

    if not leaving_arc_data.change:
        # No tree arc blocks the cycle: the entering arc just switches bound.
        state[in_arc] = -state[in_arc]
        return

    # The entering arc joins the tree; the leaving arc drops to a bound.
    state[in_arc] = ArcState.STATE_TREE
    leaving = pred[leaving_arc_data.u_out]
    state[leaving] = (
        ArcState.STATE_LOWER if flow[leaving] == 0 else ArcState.STATE_UPPER
    )
# Update the tree structure after a pivot (port of LEMON/POT
# ``updateTreeStructure``).  The subtree that hung below the leaving arc is
# re-rooted at u_in and attached below v_in through the entering arc.  Along
# the way the "stem" path between u_in and u_out has its parent links
# reversed.
# Mutates spanning_tree.parent, thread, rev_thread, succ_num, last_succ, pred
# and forward in place.
@numba.njit(
    locals={
        "u": numba.int32,
        "w": numba.int32,
        "u_in": numba.uint16,
        "u_out": numba.uint16,
        "v_in": numba.uint16,
        "right": numba.uint16,
        "stem": numba.uint16,
        "new_stem": numba.uint16,
        "par_stem": numba.uint16,
        "in_arc": numba.uint32,
    }
)
def update_spanning_tree(spanning_tree, leaving_arc_data, join, in_arc, source):
    """Rewire the spanning tree so that ``in_arc`` replaces arc ``pred[u_out]``.

    Parameters
    ----------
    spanning_tree : SpanningTree
        Tree arrays, updated in place.
    leaving_arc_data : LeavingArcData
        ``u_in``, ``u_out`` and ``v_in`` from find_leaving_arc.
    join : int
        Common ancestor of the entering arc's endpoints (find_join_node).
    in_arc : int
        Index of the entering arc. It becomes ``pred[u_in]``.
    source : array
        Per-arc source nodes, used to set ``forward[u_in]``.
    """
    parent = spanning_tree.parent
    thread = spanning_tree.thread
    rev_thread = spanning_tree.rev_thread
    succ_num = spanning_tree.succ_num
    last_succ = spanning_tree.last_succ
    forward = spanning_tree.forward
    pred = spanning_tree.pred
    u_out = leaving_arc_data.u_out
    u_in = leaving_arc_data.u_in
    v_in = leaving_arc_data.v_in
    # Snapshot u_out's thread/subtree data before anything is rewritten.
    old_rev_thread = rev_thread[u_out]
    old_succ_num = succ_num[u_out]
    old_last_succ = last_succ[u_out]
    v_out = parent[u_out]
    u = last_succ[u_in]  # the last successor of u_in
    right = thread[u]  # the node after it
    # Handle the case when old_rev_thread equals to v_in
    # (it also means that join and v_out coincide)
    if old_rev_thread == v_in:
        last = thread[last_succ[u_out]]
    else:
        last = thread[v_in]
    # Update _thread and _parent along the stem nodes (i.e. the nodes
    # between u_in and u_out, whose parent have to be changed)
    thread[v_in] = stem = u_in
    # Nodes whose thread successor changes.  Their rev_thread entries are
    # repaired in one pass further below.
    dirty_revs = []
    dirty_revs.append(v_in)
    par_stem = v_in
    while stem != u_out:
        # Insert the next stem node into the thread list
        new_stem = parent[stem]
        thread[u] = new_stem
        dirty_revs.append(u)
        # Remove the subtree of stem from the thread list
        w = rev_thread[stem]
        thread[w] = right
        rev_thread[right] = w
        # Change the parent node and shift stem nodes
        parent[stem] = par_stem
        par_stem = stem
        stem = new_stem
        # Update u and right
        if last_succ[stem] == last_succ[par_stem]:
            u = rev_thread[par_stem]
        else:
            u = last_succ[stem]
        right = thread[u]
    parent[u_out] = par_stem
    thread[u] = last
    rev_thread[last] = u
    last_succ[u_out] = u
    # Remove the subtree of u_out from the thread list except for
    # the case when old_rev_thread equals to v_in
    # (it also means that join and v_out coincide)
    if old_rev_thread != v_in:
        thread[old_rev_thread] = right
        rev_thread[right] = old_rev_thread
    # Update _rev_thread using the new _thread values
    for i in range(len(dirty_revs)):
        u = dirty_revs[i]
        rev_thread[thread[u]] = u
    # Update _pred, _forward, _last_succ and _succ_num for the
    # stem nodes from u_out to u_in
    tmp_sc = 0
    tmp_ls = last_succ[u_out]
    u = u_out
    while u != u_in:
        w = parent[u]
        pred[u] = pred[w]
        forward[u] = not forward[w]
        tmp_sc += succ_num[u] - succ_num[w]
        succ_num[u] = tmp_sc
        last_succ[w] = tmp_ls
        u = w
    # The entering arc now connects u_in to its new parent v_in.
    pred[u_in] = in_arc
    forward[u_in] = u_in == source[in_arc]
    succ_num[u_in] = old_succ_num
    # Set limits for updating _last_succ from v_in and v_out
    # towards the root
    up_limit_in = -1
    up_limit_out = -1
    if last_succ[join] == v_in:
        up_limit_out = join
    else:
        up_limit_in = join
    # Update _last_succ from v_in towards the root
    u = v_in
    while u != up_limit_in and last_succ[u] == v_in:
        last_succ[u] = last_succ[u_out]
        u = parent[u]
    # Update _last_succ from v_out towards the root
    if join != old_rev_thread and v_in != old_rev_thread:
        u = v_out
        while u != up_limit_out and last_succ[u] == old_last_succ:
            last_succ[u] = old_rev_thread
            u = parent[u]
    else:
        u = v_out
        while u != up_limit_out and last_succ[u] == old_last_succ:
            last_succ[u] = last_succ[u_out]
            u = parent[u]
    # Update _succ_num from v_in to join (the moved subtree was added there)
    u = v_in
    while u != join:
        succ_num[u] += old_succ_num
        u = parent[u]
    # Update _succ_num from v_out to join (the moved subtree was removed there)
    u = v_out
    while u != join:
        succ_num[u] -= old_succ_num
        u = parent[u]
# Shift the potentials of the re-hung subtree so that the tree arc pred[u_in]
# (the entering arc, after update_spanning_tree) has zero reduced cost.
# Mutates ``pi`` in place.
@numba.njit(
    fastmath=True,
    inline="always",
    locals={"u": numba.uint16, "u_in": numba.uint16, "v_in": numba.uint16},
)
def update_potential(leaving_arc_data, pi, cost, spanning_tree):
    """Add ``sigma`` to ``pi`` for every node in the subtree rooted at ``u_in``."""
    u_in = leaving_arc_data.u_in
    v_in = leaving_arc_data.v_in
    thread = spanning_tree.thread
    last_succ = spanning_tree.last_succ

    potential_gap = pi[v_in] - pi[u_in]
    arc_cost = cost[spanning_tree.pred[u_in]]
    sigma = (
        potential_gap - arc_cost
        if spanning_tree.forward[u_in]
        else potential_gap + arc_cost
    )

    # Walk the thread order from u_in up to (excluding) the node that follows
    # its last successor.
    stop = thread[last_succ[u_in]]
    u = u_in
    while u != stop:
        pi[u] += sigma
        u = thread[u]
# With arc mixing enabled, arcs are stored in an interleaved order (see
# allocate_graph_structures) for better random access, so mapping a graph arc
# to its storage slot needs the arithmetic below.
@numba.njit()
def arc_id(arc, graph):
    """Return the storage index of graph arc ``arc``."""
    k = graph.n_arcs - arc - 1
    if not graph.use_arc_mixing:
        return k
    big_total = graph.num_total_big_subsequence_numbers
    # 1 once k is past the "big" (one element longer) subsequences, else 0.
    is_small = (k > big_total) & 1
    k -= big_total * is_small
    seq_len = graph.subsequence_length - is_small
    seq_index = (k // seq_len) + graph.num_big_subsequences * is_small
    seq_offset = (k % seq_len) * graph.mixing_coeff
    return seq_offset + seq_index
# Heuristic initial pivots (port of LEMON/POT ``initialPivots``).
@numba.njit(locals={"i": numba.uint16})
def construct_initial_pivots(graph, node_arc_data, spanning_tree):
    """Pivot a few promising arcs into the initial artificial spanning tree.

    Candidate arcs (as storage indices, via ``arc_id``) are collected as follows:

    * exactly one supply node and one demand node: every arc discovered by a
      stack-based (depth-first) search that follows incoming arcs backwards
      from the demand node, stopping once the supply node is popped;
    * otherwise: the cheapest incoming arc of every demand node.

    Each candidate whose reduced cost is negative is then pivoted in exactly
    like one iteration of the main simplex loop.

    Returns
    -------
    (bool, int)
        ``(False, in_arc)`` if a pivot found no bounding leaving arc
        (``delta >= MAX``), otherwise ``(True, in_arc)`` where ``in_arc`` is
        the last candidate examined (-1 if there were none).
    """
    cost = node_arc_data.cost
    pi = node_arc_data.pi
    source = node_arc_data.source
    target = node_arc_data.target
    supply = node_arc_data.supply
    n1 = graph.n
    n2 = graph.m
    n_nodes = graph.n_nodes
    n_arcs = graph.n_arcs
    state = spanning_tree.state
    total = 0
    supply_nodes = []
    demand_nodes = []
    # Node ids here are graph ids; graph node u is stored at n_nodes - u - 1.
    for u in range(n_nodes):
        curr = supply[n_nodes - u - 1]  # _node_id(u)
        if curr > 0:
            total += curr
            supply_nodes.append(u)
        elif curr < 0:
            demand_nodes.append(u)
    arc_vector = []
    if len(supply_nodes) == 1 and len(demand_nodes) == 1:
        # Perform a reverse graph search from the sink to the source
        reached = np.zeros(n_nodes, dtype=np.bool_)
        s = supply_nodes[0]
        t = demand_nodes[0]
        stack = []
        reached[t] = True
        stack.append(t)
        while len(stack) > 0:
            u = stack[-1]  # dead store: u is reassigned before it is read
            v = stack[-1]
            stack.pop(-1)
            if v == s:
                break
            # Only right-hand nodes (v >= n1) have incoming arcs; the last one
            # is n_arcs + v - n_nodes, and earlier ones are n2 apart.
            first_arc = n_arcs + v - n_nodes if v >= n1 else -1
            for a in range(first_arc, -1, -n2):
                u = a // n2  # source node of graph arc a
                if reached[u]:
                    continue
                j = arc_id(a, graph)
                # Always true here (there are no arc capacities); kept from
                # LEMON's capacity check.
                if INFINITY >= total:
                    arc_vector.append(j)
                    reached[u] = True
                    stack.append(u)
    else:
        # Find the min. cost incoming arc for each demand node
        for i in range(len(demand_nodes)):
            v = demand_nodes[i]
            c = MAX
            min_cost = MAX
            min_arc = INVALID
            first_arc = n_arcs + v - n_nodes if v >= n1 else -1
            for a in range(first_arc, -1, -n2):
                c = cost[arc_id(a, graph)]
                if c < min_cost:
                    min_cost = c
                    min_arc = a
            if min_arc != INVALID:
                arc_vector.append(arc_id(min_arc, graph))
    # Perform heuristic initial pivots
    in_arc = -1
    for i in range(len(arc_vector)):
        in_arc = arc_vector[i]
        # Bad arcs: skip candidates whose reduced cost is not negative
        if (
            state[in_arc] * (cost[in_arc] + pi[source[in_arc]] - pi[target[in_arc]])
            >= 0
        ):
            continue
        join = find_join_node(
            source, target, spanning_tree.succ_num, spanning_tree.parent, in_arc
        )
        leaving_arc_data = find_leaving_arc(join, in_arc, node_arc_data, spanning_tree)
        if leaving_arc_data.delta >= MAX:
            return False, in_arc
        update_flow(join, leaving_arc_data, node_arc_data, spanning_tree, in_arc)
        if leaving_arc_data.change:
            update_spanning_tree(spanning_tree, leaving_arc_data, join, in_arc, source)
            update_potential(leaving_arc_data, pi, cost, spanning_tree)
    return True, in_arc
@numba.njit()
def allocate_graph_structures(n, m, use_arc_mixing=True):
    """Allocate the solver arrays for the complete bipartite graph K_{n,m}.

    Graph arc ``a`` (``0 <= a < n * m``) runs from left node ``a // m`` to
    right node ``n + a % m``. Graph node ``u`` is stored at index
    ``n_nodes - u - 1``, and arcs are written in reverse order. When
    ``use_arc_mixing`` is True they are also interleaved with stride
    ``mixing_coeff`` (``arc_id`` inverts this). Node index ``n_nodes`` is
    reserved for the artificial root, and the arc arrays leave room for
    ``2 * n_nodes`` extra (artificial) arcs. ``cost`` starts at 1.0; every
    other array starts at zero.

    Returns
    -------
    (NodeArcData, SpanningTree, DiGraph)

    Raises
    ------
    ValueError
        If ``n + m > 65535``. Node indices, including the root ``n + m``, are
        stored in uint16 arrays (and uint16 locals in the simplex routines),
        so larger problems would silently wrap around and corrupt the result.
    """
    # Size bipartite graph
    n_nodes = n + m
    n_arcs = n * m
    if n_nodes > 65535:
        raise ValueError(
            "allocate_graph_structures: n + m must be at most 65535 "
            "(node indices are stored as uint16)"
        )

    # Resize vectors
    all_node_num = n_nodes + 1
    max_arc_num = n_arcs + 2 * n_nodes
    root = n_nodes

    source = np.zeros(max_arc_num, dtype=np.uint16)
    target = np.zeros(max_arc_num, dtype=np.uint16)
    cost = np.ones(max_arc_num, dtype=np.float64)
    supply = np.zeros(all_node_num, dtype=np.float64)
    flow = np.zeros(max_arc_num, dtype=np.float64)
    pi = np.zeros(all_node_num, dtype=np.float64)
    parent = np.zeros(all_node_num, dtype=np.int32)
    pred = np.zeros(all_node_num, dtype=np.int32)
    forward = np.zeros(all_node_num, dtype=np.bool_)
    thread = np.zeros(all_node_num, dtype=np.int32)
    rev_thread = np.zeros(all_node_num, dtype=np.int32)
    succ_num = np.zeros(all_node_num, dtype=np.int32)
    last_succ = np.zeros(all_node_num, dtype=np.int32)
    state = np.zeros(max_arc_num, dtype=np.int8)

    if use_arc_mixing:
        # Store the arcs in a mixed order: slots 0, k, 2k, ..., then 1, k+1, ...
        k = max(np.int32(np.sqrt(n_arcs)), 10)
        mixing_coeff = k
        subsequence_length = (n_arcs // mixing_coeff) + 1
        num_big_subsequences = n_arcs % mixing_coeff
        num_total_big_subsequence_numbers = subsequence_length * num_big_subsequences
        i = 0
        j = 0
        for a in range(n_arcs - 1, -1, -1):
            source[i] = n_nodes - (a // m) - 1
            target[i] = n_nodes - ((a % m) + n) - 1
            i += k
            if i >= n_arcs:
                j += 1
                i = j
    else:
        # dummy values
        subsequence_length = 0
        mixing_coeff = 0
        num_big_subsequences = 0
        num_total_big_subsequence_numbers = 0
        # Store the arcs in the original order
        i = 0
        for a in range(n_arcs - 1, -1, -1):
            source[i] = n_nodes - (a // m) - 1
            target[i] = n_nodes - ((a % m) + n) - 1
            i += 1

    node_arc_data = NodeArcData(cost, supply, flow, pi, source, target)
    spanning_tree = SpanningTree(
        parent, pred, thread, rev_thread, succ_num, last_succ, forward, state, root
    )
    graph = DiGraph(
        n_nodes,
        n_arcs,
        n,
        m,
        use_arc_mixing,
        num_total_big_subsequence_numbers,
        subsequence_length,
        num_big_subsequences,
        mixing_coeff,
    )
    return node_arc_data, spanning_tree, graph
@numba.njit(locals={"u": numba.uint16, "e": numba.uint32})
def initialize_graph_structures(graph, node_arc_data, spanning_tree):
    """Reset flows/states and build the initial artificial spanning tree.

    Every node ``u`` hangs directly from the artificial root ``n_nodes``
    through artificial arc ``n_arcs + u``. That arc carries the node's whole
    supply, with zero cost when ``supply[u] >= 0`` and an artificial cost
    otherwise.

    Returns
    -------
    bool
        False, with nothing modified, if the graph has no nodes or the supplies
        do not sum to zero within NET_SUPPLY_ERROR_TOLERANCE. Otherwise True.
    """
    n_nodes = graph.n_nodes
    n_arcs = graph.n_arcs
    if n_nodes == 0:
        return False

    supply = node_arc_data.supply
    net_supply = 0.0
    for i in range(n_nodes):
        net_supply += supply[i]
    if np.fabs(net_supply) > NET_SUPPLY_ERROR_TOLERANCE:
        return False

    cost = node_arc_data.cost
    flow = node_arc_data.flow
    pi = node_arc_data.pi
    source = node_arc_data.source
    target = node_arc_data.target
    state = spanning_tree.state

    # Artificial arc cost: (largest real arc cost, floored at 0, + 1) * n_nodes.
    max_cost = 0.0
    for i in range(n_arcs):
        if cost[i] > max_cost:
            max_cost = cost[i]
    artificial_cost = (max_cost + 1) * n_nodes

    # Every real arc starts empty at its lower bound.
    for i in range(n_arcs):
        if flow[i] != 0:
            flow[i] = 0
        state[i] = ArcState.STATE_LOWER

    parent = spanning_tree.parent
    pred = spanning_tree.pred
    thread = spanning_tree.thread
    rev_thread = spanning_tree.rev_thread
    succ_num = spanning_tree.succ_num
    last_succ = spanning_tree.last_succ
    forward = spanning_tree.forward

    # The artificial root sits just past the real nodes.
    root = n_nodes
    parent[root] = -1
    pred[root] = -1
    thread[root] = 0
    rev_thread[0] = root
    succ_num[root] = n_nodes + 1
    last_succ[root] = root - 1
    supply[root] = -net_supply
    pi[root] = 0

    # One artificial tree arc per node (EQ supply constraints).
    for u in range(n_nodes):
        e = n_arcs + u
        parent[u] = root
        pred[u] = e
        thread[u] = u + 1
        rev_thread[u + 1] = u
        succ_num[u] = 1
        last_succ[u] = u
        state[e] = ArcState.STATE_TREE
        to_root = supply[u] >= 0
        forward[u] = to_root
        if to_root:
            pi[u] = 0
            source[e] = u
            target[e] = root
            flow[e] = supply[u]
            cost[e] = 0
        else:
            pi[u] = artificial_cost
            source[e] = root
            target[e] = u
            flow[e] = -supply[u]
            cost[e] = artificial_cost
    return True
@numba.njit()
def initialize_supply(left_node_supply, right_node_supply, graph, supply):
    """Write the left then right node supplies into ``supply``.

    Graph node ``u`` is stored at index ``graph.n_nodes - u - 1``. Left nodes
    are ``0 .. graph.n - 1`` and right nodes follow them.
    """
    n_left = graph.n
    n_nodes = graph.n_nodes
    for u in range(n_left):
        supply[n_nodes - u - 1] = left_node_supply[u]
    for u in range(n_left, n_nodes):
        supply[n_nodes - u - 1] = right_node_supply[u - n_left]
@numba.njit(inline="always")
def set_cost(arc, cost_val, cost, graph):
    """Store ``cost_val`` at the storage slot of graph arc ``arc``."""
    slot = arc_id(arc, graph)
    cost[slot] = cost_val
@numba.njit(locals={"i": numba.uint16, "j": numba.uint16})
def initialize_cost(cost_matrix, graph, cost):
    """Copy ``cost_matrix[i, j]`` into ``cost`` as graph arc ``i * n_cols + j``."""
    n_cols = cost_matrix.shape[1]
    for i in range(cost_matrix.shape[0]):
        row_start = i * n_cols
        for j in range(n_cols):
            set_cost(row_start + j, cost_matrix[i, j], cost, graph)
@numba.njit(fastmath=True, locals={"i": numba.uint32})
def total_cost(flow, cost):
    """Return the sum of ``flow[i] * cost[i]`` over every entry of ``flow``."""
    acc = 0.0
    for i in range(flow.shape[0]):
        acc += cost[i] * flow[i]
    return acc
@numba.njit(nogil=True)
def network_simplex_core(node_arc_data, spanning_tree, graph, max_iter):
    """Run the network simplex method on an initialized problem.

    Parameters
    ----------
    node_arc_data, spanning_tree, graph
        Structures from allocate/initialize_graph_structures, updated in place.
    max_iter : int
        When positive, stop with MAX_ITER_REACHED once the iteration counter
        reaches it. Values <= 0 mean no limit.

    Returns
    -------
    ProblemStatus
        When OPTIMAL, leftover artificial-arc flows of at most EPSILON are
        zeroed (larger ones give INFEASIBLE). In every non-UNBOUNDED/INFEASIBLE
        case the node potentials are shifted so the largest is at most 0.
    """
    block_size = max(np.int32(np.sqrt(graph.n_arcs)), 10)
    n_search = graph.n_arcs
    status = ProblemStatus.OPTIMAL

    bounded, in_arc = construct_initial_pivots(graph, node_arc_data, spanning_tree)
    if not bounded:
        return ProblemStatus.UNBOUNDED

    iterations = 0
    in_arc, next_arc = find_entering_arc(
        block_size, 0, n_search, spanning_tree.state, node_arc_data, in_arc
    )
    while in_arc >= 0:
        iterations += 1
        if max_iter > 0 and iterations >= max_iter:
            status = ProblemStatus.MAX_ITER_REACHED
            break
        join = find_join_node(
            node_arc_data.source,
            node_arc_data.target,
            spanning_tree.succ_num,
            spanning_tree.parent,
            in_arc,
        )
        leaving = find_leaving_arc(join, in_arc, node_arc_data, spanning_tree)
        if leaving.delta >= MAX:
            return ProblemStatus.UNBOUNDED
        update_flow(join, leaving, node_arc_data, spanning_tree, in_arc)
        if leaving.change:
            update_spanning_tree(
                spanning_tree, leaving, join, in_arc, node_arc_data.source
            )
            update_potential(
                leaving, node_arc_data.pi, node_arc_data.cost, spanning_tree
            )
        in_arc, next_arc = find_entering_arc(
            block_size,
            next_arc,
            n_search,
            spanning_tree.state,
            node_arc_data,
            in_arc,
        )

    flow = node_arc_data.flow
    pi = node_arc_data.pi

    # Feasibility: artificial arcs (n_arcs .. n_arcs + n_nodes - 1) must be empty.
    if status == ProblemStatus.OPTIMAL:
        for e in range(graph.n_arcs, graph.n_arcs + graph.n_nodes):
            leftover = flow[e]
            if leftover == 0:
                continue
            if np.abs(leftover) > EPSILON:
                return ProblemStatus.INFEASIBLE
            flow[e] = 0

    # Shift potentials to meet the requirements of the GEQ/LEQ type
    # optimality conditions (largest potential becomes <= 0).
    max_pot = -INFINITY
    for i in range(graph.n_nodes):
        if pi[i] > max_pot:
            max_pot = pi[i]
    if max_pot > 0:
        for i in range(graph.n_nodes):
            pi[i] -= max_pot
    return status
#######################################################
# SINKHORN distances in various variations
#######################################################
@numba.njit(
    fastmath=True,
    parallel=True,
    locals={"residual": numba.float32, "sq_norm": numba.float32},
    cache=False,
)
def right_marginal_error(u, K, v, y):
    """Return ``||y - (u @ K) * v||_2``, accumulated in float32.

    ``(u @ K) * v`` is the column sum of the plan ``u[i] * K[i, j] * v[j]``.
    """
    col_marginal = u @ K
    sq_norm = 0.0
    for i in numba.prange(col_marginal.shape[0]):
        residual = y[i] - col_marginal[i] * v[i]
        sq_norm += residual * residual
    return np.sqrt(sq_norm)
@numba.njit(
    fastmath=True,
    parallel=True,
    locals={"residual": numba.float32, "sq_norm": numba.float32},
    cache=False,
)
def right_marginal_error_batch(u, K, v, y):
    """Batched right_marginal_error: Frobenius norm over all bins and batch members.

    ``K.T @ u`` and ``v`` are indexed ``[bin, batch]``, while ``y`` is indexed
    ``[batch, bin]``.
    """
    col_marginal = K.T @ u
    sq_norm = 0.0
    for i in numba.prange(col_marginal.shape[0]):
        for j in range(col_marginal.shape[1]):
            residual = y[j, i] - col_marginal[i, j] * v[i, j]
            sq_norm += residual * residual
    return np.sqrt(sq_norm)
@numba.njit(fastmath=True, parallel=True, cache=False)
def transport_plan(K, u, v):
    """Return the Sinkhorn transport plan ``P[i, j] = u[i] * K[i, j] * v[j]``."""
    n_rows, n_cols = K.shape
    plan = np.empty_like(K)
    for row in numba.prange(n_rows):
        u_row = u[row]
        for col in range(n_cols):
            plan[row, col] = u_row * K[row, col] * v[col]
    return plan
@numba.njit(fastmath=True, parallel=True, locals={"total": numba.float32}, cache=False)
def relative_change_in_plan(old_u, old_v, new_u, new_v):
    """Mean relative change of the outer product ``u v^T`` between two iterates.

    Entries where ``old_u[i] * old_v[j] == 0`` divide by zero (no guard).
    """
    n_i = old_u.shape[0]
    n_j = old_v.shape[0]
    total = 0.0
    for i in numba.prange(n_i):
        for j in range(n_j):
            before = old_u[i] * old_v[j]
            after = new_u[i] * new_v[j]
            total += np.float32(np.abs(before - after) / before)
    return total / (n_i * n_j)
@numba.njit(fastmath=True, parallel=True, cache=False)
def precompute_K_prime(K, x):
    """Return ``K`` with row ``i`` scaled by ``1 / x[i]``.

    Rows with ``x[i] <= 0`` are scaled by INFINITY (the largest float64) instead.
    """
    n_rows = K.shape[0]
    n_cols = K.shape[1]
    scaled = np.empty_like(K)
    for i in numba.prange(n_rows):
        inv_x = 1.0 / x[i] if x[i] > 0.0 else INFINITY
        for j in range(n_cols):
            scaled[i, j] = inv_x * K[i, j]
    return scaled
@numba.njit(fastmath=True, parallel=True, cache=False)
def K_from_cost(cost, regularization):
    """Return the Gibbs kernel ``exp(-cost / regularization)``, elementwise."""
    n_rows = cost.shape[0]
    n_cols = cost.shape[1]
    kernel = np.empty_like(cost)
    for i in numba.prange(n_rows):
        for j in range(n_cols):
            kernel[i, j] = np.exp(-(cost[i, j] / regularization))
    return kernel
# fastmath is given as an explicit flag set that omits "nnan" and "ninf".
# With fastmath=True, LLVM may assume values are never NaN/Inf and fold the
# np.isfinite() guards below to constants, which silently disables them.
@numba.njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}, cache=True)
def sinkhorn_iterations(
    x, y, u, v, K, max_iter=1000, error_tolerance=1e-9, change_tolerance=1e-9
):
    """Run Sinkhorn scaling iterations and return the scaling vectors ``(u, v)``.

    Each iteration sets ``v = y / (K.T @ u)`` and then
    ``u = 1 / (K_prime @ v)``, where ``K_prime`` is ``K`` with row ``i``
    divided by ``x[i]``. Iteration stops when any of these happens:

    * ``max_iter`` iterations are done;
    * an update contains a non-finite value (the last finite ``u``, ``v``
      are returned);
    * every 20 iterations, the mean relative change of ``u v^T`` since the
      previous check is ``<= change_tolerance``;
    * every 10 iterations, the right marginal error is ``<= error_tolerance``.
    """
    K_prime = precompute_K_prime(K, x)
    prev_u = u
    prev_v = v
    for iteration in range(max_iter):
        next_v = y / (K.T @ u)
        if np.any(~np.isfinite(next_v)):
            break
        next_u = 1.0 / (K_prime @ next_v)
        if np.any(~np.isfinite(next_u)):
            break
        u = next_u
        v = next_v
        if iteration % 20 == 0:
            # Check if values in plan have changed significantly since last 20 iterations
            relative_change = relative_change_in_plan(prev_u, prev_v, next_u, next_v)
            if relative_change <= change_tolerance:
                break
            prev_u = u
            prev_v = v
        if iteration % 10 == 0:
            # Check if right marginal error is less than tolerance every 10 iterations
            err = right_marginal_error(u, K, v, y)
            if err <= error_tolerance:
                break
    return u, v
# fastmath is given as an explicit flag set that omits "nnan" and "ninf", so
# LLVM cannot fold the np.isfinite() guards below away (as fastmath=True allows).
@numba.njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}, cache=True)
def sinkhorn_iterations_batch(x, y, u, v, K, max_iter=1000, error_tolerance=1e-9):
    """Batched Sinkhorn iterations: one source ``x`` against many targets.

    ``u`` and ``v`` are scaling matrices indexed ``[bin, batch]``. ``y`` is
    indexed ``[batch, bin]`` (it is transposed before dividing by
    ``K.T @ u``). Stops after ``max_iter`` iterations, on a non-finite update
    (returning the last finite ``u``, ``v``), or when the right marginal
    error, checked every 10 iterations, is ``<= error_tolerance``.
    """
    K_prime = precompute_K_prime(K, x)
    for iteration in range(max_iter):
        next_v = y.T / (K.T @ u)
        if np.any(~np.isfinite(next_v)):
            break
        next_u = 1.0 / (K_prime @ next_v)
        if np.any(~np.isfinite(next_u)):
            break
        u = next_u
        v = next_v
        if iteration % 10 == 0:
            # Check if right marginal error is less than tolerance every 10 iterations
            err = right_marginal_error_batch(u, K, v, y)
            if err <= error_tolerance:
                break
    return u, v
@numba.njit(fastmath=True, cache=True)
def sinkhorn_transport_plan(
    x,
    y,
    cost=_dummy_cost,
    regularization=1.0,
    max_iter=1000,
    error_tolerance=1e-9,
    change_tolerance=1e-9,
):
    """Return the entropically regularized transport plan between ``x`` and ``y``.

    The scaling vectors start uniform (``1 / len(x)`` and ``1 / len(y)``) in
    the dtype of ``cost``. The remaining arguments go to sinkhorn_iterations.
    """
    K = K_from_cost(cost, regularization)
    n_x = x.shape[0]
    n_y = y.shape[0]
    u_init = np.full(n_x, 1.0 / n_x, dtype=cost.dtype)
    v_init = np.full(n_y, 1.0 / n_y, dtype=cost.dtype)
    u, v = sinkhorn_iterations(
        x,
        y,
        u_init,
        v_init,
        K,
        max_iter=max_iter,
        error_tolerance=error_tolerance,
        change_tolerance=change_tolerance,
    )
    return transport_plan(K, u, v)
@numba.njit(fastmath=True, cache=True)
def sinkhorn_distance(
    x,
    y,
    cost=_dummy_cost,
    regularization=1.0,
    max_iter=1000,
    error_tolerance=1e-9,
    change_tolerance=1e-9,
):
    """Return the Sinkhorn transport cost ``sum_ij P[i, j] * cost[i, j]``.

    ``P`` is the plan from sinkhorn_transport_plan. The iteration limit and
    tolerances are forwarded to it. Their defaults equal that function's
    defaults, so existing calls behave exactly as before.
    """
    # Named ``plan`` so the module-level ``transport_plan`` is not shadowed.
    plan = sinkhorn_transport_plan(
        x,
        y,
        cost=cost,
        regularization=regularization,
        max_iter=max_iter,
        error_tolerance=error_tolerance,
        change_tolerance=change_tolerance,
    )
    result = 0.0
    for i in range(plan.shape[0]):
        for j in range(plan.shape[1]):
            result += plan[i, j] * cost[i, j]
    return result
@numba.njit(fastmath=True, parallel=True, cache=False)
def sinkhorn_distance_batch(x, y, cost=_dummy_cost, regularization=1.0):
    """Return Sinkhorn transport costs from ``x`` to each distribution in ``y``.

    Parameters
    ----------
    x : 1-D array, length ``dim_x``
        Source histogram.
    y : 2-D array, shape ``(batch_size, dim_y)``
        One target histogram per row.  This matches the layout used by the
        helpers: sinkhorn_iterations_batch divides ``y.T`` by ``K.T @ u`` and
        right_marginal_error_batch reads ``y[batch, bin]``.
    cost : 2-D array, shape ``(dim_x, dim_y)``
    regularization : float
        The kernel is ``exp(-cost / regularization)``.

    Returns
    -------
    1-D float64 array of length ``batch_size``.
    """
    dim_x = x.shape[0]
    # y is (batch_size, dim_y).  Reading it as (dim_y, batch_size) breaks the
    # y.T / (K.T @ u) broadcast in sinkhorn_iterations_batch for any
    # non-square y (or silently broadcasts when one side is 1).
    batch_size = y.shape[0]
    dim_y = y.shape[1]

    u = np.full((dim_x, batch_size), 1.0 / dim_x, dtype=cost.dtype)
    v = np.full((dim_y, batch_size), 1.0 / dim_y, dtype=cost.dtype)
    K = K_from_cost(cost, regularization)
    u, v = sinkhorn_iterations_batch(
        x,
        y,
        u,
        v,
        K,
    )

    i_dim = K.shape[0]
    j_dim = K.shape[1]
    result = np.zeros(batch_size)
    for i in range(i_dim):
        for j in range(j_dim):
            K_times_cost = K[i, j] * cost[i, j]
            for batch in range(batch_size):
                result[batch] += u[i, batch] * K_times_cost * v[j, batch]
    return result
def make_fixed_cost_sinkhorn_distance(cost, regularization=1.0):
    """Build a numba-compiled Sinkhorn distance specialised to one cost matrix.

    The Gibbs kernel ``K = exp(-cost / regularization)`` is computed once here
    and captured by the returned closure, together with ``cost``. Previously
    the kernel was recomputed, with one ``exp`` per entry, on every call.

    Returns
    -------
    callable
        ``closure(x, y) -> float``: the Sinkhorn transport cost between
        histograms ``x`` (length ``cost.shape[0]``) and ``y`` (length
        ``cost.shape[1]``).
    """
    fixed_K = K_from_cost(cost, regularization)
    dim_x = fixed_K.shape[0]
    dim_y = fixed_K.shape[1]

    @numba.njit(fastmath=True)
    def closure(x, y):
        u = np.full(dim_x, 1.0 / dim_x, dtype=cost.dtype)
        v = np.full(dim_y, 1.0 / dim_y, dtype=cost.dtype)
        u, v = sinkhorn_iterations(
            x,
            y,
            u,
            v,
            fixed_K,
        )
        current_plan = transport_plan(fixed_K, u, v)
        result = 0.0
        for i in range(dim_x):
            for j in range(dim_y):
                result += current_plan[i, j] * cost[i, j]
        return result

    return closure