1195 lines
33 KiB
Python
1195 lines
33 KiB
Python
|
#############################################################################
|
||
|
# This code draws from the Python Optimal Transport version of the
|
||
|
# network simplex algorithm, which in turn was adapted from the LEMON
|
||
|
# library. The copyrights/comment blocks for those are preserved below.
|
||
|
# The Python/Numba implementation was adapted by Leland McInnes (2020).
|
||
|
#
|
||
|
# * This file has been adapted by Nicolas Bonneel (2013),
|
||
|
# * from network_simplex.h from LEMON, a generic C++ optimization library,
|
||
|
# * to implement a lightweight network simplex for mass transport, more
|
||
|
# * memory efficient that the original file. A previous version of this file
|
||
|
# * is used as part of the Displacement Interpolation project,
|
||
|
# * Web: http://www.cs.ubc.ca/labs/imager/tr/2011/DisplacementInterpolation/
|
||
|
# *
|
||
|
# *
|
||
|
# **** Original file Copyright Notice :
|
||
|
# *
|
||
|
# * Copyright (C) 2003-2010
|
||
|
# * Egervary Jeno Kombinatorikus Optimalizalasi Kutatocsoport
|
||
|
# * (Egervary Research Group on Combinatorial Optimization, EGRES).
|
||
|
# *
|
||
|
# * Permission to use, modify and distribute this software is granted
|
||
|
# * provided that this copyright notice appears in all copies. For
|
||
|
# * precise terms see the accompanying LICENSE file.
|
||
|
# *
|
||
|
# * This software is provided "AS IS" with no warranty of any kind,
|
||
|
# * express or implied, and with no claim as to its suitability for any
|
||
|
# * purpose.
|
||
|
|
||
|
import numpy as np
|
||
|
import numba
|
||
|
from collections import namedtuple
|
||
|
from enum import Enum, IntEnum
|
||
|
|
||
|
_mock_identity = np.eye(2, dtype=np.float32)
|
||
|
_mock_ones = np.ones(2, dtype=np.float32)
|
||
|
_dummy_cost = np.zeros((2, 2), dtype=np.float64)
|
||
|
|
||
|
# Accuracy tolerance and net supply tolerance
|
||
|
EPSILON = 2.2204460492503131e-15
|
||
|
NET_SUPPLY_ERROR_TOLERANCE = 1e-8
|
||
|
|
||
|
## Defaults to double for everything in POT
|
||
|
INFINITY = np.finfo(np.float64).max
|
||
|
MAX = np.finfo(np.float64).max
|
||
|
|
||
|
dummy_cost = np.zeros((2, 2), dtype=np.float64)
|
||
|
|
||
|
# Invalid Arc num
|
||
|
INVALID = -1
|
||
|
|
||
|
# Problem Status
|
||
|
class ProblemStatus(Enum):
|
||
|
OPTIMAL = 0
|
||
|
MAX_ITER_REACHED = 1
|
||
|
UNBOUNDED = 2
|
||
|
INFEASIBLE = 3
|
||
|
|
||
|
|
||
|
# Arc States
|
||
|
class ArcState(IntEnum):
|
||
|
STATE_UPPER = -1
|
||
|
STATE_TREE = 0
|
||
|
STATE_LOWER = 1
|
||
|
|
||
|
|
||
|
SpanningTree = namedtuple(
|
||
|
"SpanningTree",
|
||
|
[
|
||
|
"parent", # int array
|
||
|
"pred", # int array
|
||
|
"thread", # int array
|
||
|
"rev_thread", # int array
|
||
|
"succ_num", # int array
|
||
|
"last_succ", # int array
|
||
|
"forward", # bool array
|
||
|
"state", # state array
|
||
|
"root", # int
|
||
|
],
|
||
|
)
|
||
|
DiGraph = namedtuple(
|
||
|
"DiGraph",
|
||
|
[
|
||
|
"n_nodes", # int
|
||
|
"n_arcs", # int
|
||
|
"n", # int
|
||
|
"m", # int
|
||
|
"use_arc_mixing", # bool
|
||
|
"num_total_big_subsequence_numbers", # int
|
||
|
"subsequence_length", # int
|
||
|
"num_big_subsequences", # int
|
||
|
"mixing_coeff",
|
||
|
],
|
||
|
)
|
||
|
NodeArcData = namedtuple(
|
||
|
"NodeArcData",
|
||
|
[
|
||
|
"cost", # double array
|
||
|
"supply", # double array
|
||
|
"flow", # double array
|
||
|
"pi", # double array
|
||
|
"source", # unsigned int array
|
||
|
"target", # unsigned int array
|
||
|
],
|
||
|
)
|
||
|
LeavingArcData = namedtuple(
|
||
|
"LeavingArcData", ["u_in", "u_out", "v_in", "delta", "change"]
|
||
|
)
|
||
|
|
||
|
# Just reproduce a simpler version of numpy isclose (not numba supported yet)
|
||
|
# Lightweight stand-in for numpy's isclose (not yet supported by numba).
@numba.njit()
def isclose(a, b, rtol=1.0e-5, atol=EPSILON):
    """Return True when ``a`` and ``b`` agree within the given tolerances.

    Mirrors the semantics of ``np.isclose``: the comparison is
    ``|a - b| <= atol + rtol * |b|`` (asymmetric in ``b``).
    """
    tolerance = atol + rtol * np.abs(b)
    return np.abs(a - b) <= tolerance
|
||
|
|
||
|
|
||
|
# Block-search pivot rule: scan arcs in blocks of ``pivot_block_size``;
# once a full block has been examined, accept the best candidate found so
# far if its reduced cost is sufficiently negative.
# locals: c, min_value, e, cnt, a
# modifies: in_arc, pivot_next_arc
@numba.njit(locals={"a": numba.uint32, "e": numba.uint32})
def find_entering_arc(
    pivot_block_size,
    pivot_next_arc,
    search_arc_num,
    state_vector,
    node_arc_data,
    in_arc,
):
    """Select the next entering arc for the network simplex pivot.

    Scans arcs starting at ``pivot_next_arc`` (wrapping around to the
    beginning), tracking the arc with the most negative reduced cost
    ``state * (cost + pi[source] - pi[target])``.  Returns a pair
    ``(in_arc, pivot_next_arc)``; ``in_arc`` is -1 when no improving arc
    exists (the current solution is optimal for the search set).

    Fix over the original: the running minimum was held in a local named
    ``min``, shadowing the Python builtin; renamed to ``min_value``.
    """
    min_value = 0
    cnt = pivot_block_size

    # Pull from tuple for quick reference
    cost = node_arc_data.cost
    pi = node_arc_data.pi
    source = node_arc_data.source
    target = node_arc_data.target

    # First pass: from the saved position to the end of the arc list.
    for e in range(pivot_next_arc, search_arc_num):
        c = state_vector[e] * (cost[e] + pi[source[e]] - pi[target[e]])
        if c < min_value:
            min_value = c
            in_arc = e

        cnt -= 1
        if cnt == 0:
            # Scale the negativity threshold by the largest magnitude
            # among the candidate's potentials and cost (relative test).
            if np.fabs(pi[source[in_arc]]) > np.fabs(pi[target[in_arc]]):
                a = np.fabs(pi[source[in_arc]])
            else:
                a = np.fabs(pi[target[in_arc]])

            if a <= np.fabs(cost[in_arc]):
                a = np.fabs(cost[in_arc])

            if min_value < -(EPSILON * a):
                pivot_next_arc = e
                return in_arc, pivot_next_arc
            else:
                cnt = pivot_block_size

    # Second pass: wrap around and scan the arcs before the saved position.
    for e in range(pivot_next_arc):
        c = state_vector[e] * (cost[e] + pi[source[e]] - pi[target[e]])
        if c < min_value:
            min_value = c
            in_arc = e

        cnt -= 1
        if cnt == 0:
            if np.fabs(pi[source[in_arc]]) > np.fabs(pi[target[in_arc]]):
                a = np.fabs(pi[source[in_arc]])
            else:
                a = np.fabs(pi[target[in_arc]])

            if a <= np.fabs(cost[in_arc]):
                a = np.fabs(cost[in_arc])

            if min_value < -(EPSILON * a):
                pivot_next_arc = e
                return in_arc, pivot_next_arc
            else:
                cnt = pivot_block_size

    # assert(pivot_block.next_arc[0] == 0 or e == pivot_block.next_arc[0] - 1)

    # Final check on the best arc seen across the whole sweep.
    if np.fabs(pi[source[in_arc]]) > np.fabs(pi[target[in_arc]]):
        a = np.fabs(pi[source[in_arc]])
    else:
        a = np.fabs(pi[target[in_arc]])

    if a <= np.fabs(cost[in_arc]):
        a = np.fabs(cost[in_arc])

    if min_value >= -(EPSILON * a):
        # No sufficiently negative reduced cost: signal optimality.
        return -1, 0

    return in_arc, pivot_next_arc
|
||
|
|
||
|
|
||
|
# Find the join node
|
||
|
# Operates with graph (_source, _target) and MST (_succ_num, _parent, in_arc) data
|
||
|
# locals: u, v
|
||
|
# modifies: join
|
||
|
@numba.njit(locals={"u": numba.types.uint16, "v": numba.types.uint16})
def find_join_node(source, target, succ_num, parent, in_arc):
    """Return the apex (join) node of the cycle closed by the entering arc.

    Walks both endpoints of ``in_arc`` upward through the spanning tree
    until the two walks meet at their lowest common ancestor.
    """
    u = source[in_arc]
    v = target[in_arc]
    while u != v:
        # Climb from whichever side currently roots the smaller subtree.
        if succ_num[u] < succ_num[v]:
            u = parent[u]
        else:
            v = parent[v]
    return u
|
||
|
|
||
|
|
||
|
# Find the leaving arc of the cycle and returns true if the
|
||
|
# leaving arc is not the same as the entering arc
|
||
|
# locals: first, second, result, d, e
|
||
|
# modifies: u_in, v_in, u_out, delta
|
||
|
@numba.njit(
    locals={
        "u": numba.uint16,
        "u_in": numba.uint16,
        "u_out": numba.uint16,
        "v_in": numba.uint16,
        "first": numba.uint16,
        "second": numba.uint16,
        "result": numba.uint8,
        "in_arc": numba.uint32,
    }
)
def find_leaving_arc(join, in_arc, node_arc_data, spanning_tree):
    """Find the leaving arc of the cycle formed by the entering arc.

    Walks both sides of the cycle (from each endpoint of ``in_arc`` up to
    ``join``) looking for the tree arc whose flow most tightly limits the
    augmentation.  Returns a ``LeavingArcData`` tuple; its ``change`` field
    is True when the leaving arc differs from the entering arc.
    """
    source = node_arc_data.source
    target = node_arc_data.target
    flow = node_arc_data.flow

    state = spanning_tree.state
    forward = spanning_tree.forward
    pred = spanning_tree.pred
    parent = spanning_tree.parent

    # May not be set below (when delta stays INFINITY), but we need to
    # return something; callers only use u_out when change is True.
    u_out = -1

    # Initialize first and second nodes according to the direction
    # of the cycle
    if state[in_arc] == ArcState.STATE_LOWER:
        first = source[in_arc]
        second = target[in_arc]
    else:
        first = target[in_arc]
        second = source[in_arc]

    delta = INFINITY
    result = 0

    # Search the cycle along the path from the first node to the root.
    # Arcs pointing with the cycle direction impose no bound (INFINITY);
    # arcs against it are bounded by their current flow.
    u = first
    while u != join:
        e = pred[u]
        if forward[u]:
            d = flow[e]
        else:
            d = INFINITY

        # Strict < here: on ties, prefer the arc found on the second path
        # (which uses <= below) — matches the LEMON tie-breaking rule.
        if d < delta:
            delta = d
            u_out = u
            result = 1

        u = parent[u]

    # Search the cycle along the path from the second node to the root
    u = second
    while u != join:
        e = pred[u]
        if forward[u]:
            d = INFINITY
        else:
            d = flow[e]

        if d <= delta:
            delta = d
            u_out = u
            result = 2

        u = parent[u]

    # Orient (u_in, v_in) so that u_in lies on the same side of the cycle
    # as the leaving arc.
    if result == 1:
        u_in = first
        v_in = second
    else:
        u_in = second
        v_in = first

    return LeavingArcData(u_in, u_out, v_in, delta, result != 0)
|
||
|
|
||
|
|
||
|
# Change _flow and _state vectors
|
||
|
# locals: val, u
|
||
|
# modifies: _state, _flow
|
||
|
@numba.njit(locals={"u": numba.uint16, "in_arc": numba.uint32, "val": numba.float64})
def update_flow(join, leaving_arc_data, node_arc_data, spanning_tree, in_arc):
    """Augment flow around the pivot cycle and update arc states.

    Pushes ``delta`` units of flow around the cycle closed by ``in_arc``
    (both cycle halves meet at ``join``), then marks the entering arc as a
    tree arc and the leaving arc as a lower/upper bound arc.
    Modifies ``flow`` and ``state`` in place.
    """
    source = node_arc_data.source
    target = node_arc_data.target
    flow = node_arc_data.flow

    state = spanning_tree.state
    pred = spanning_tree.pred
    parent = spanning_tree.parent
    forward = spanning_tree.forward

    # Augment along the cycle (skip the walk entirely for a degenerate
    # pivot with delta == 0)
    if leaving_arc_data.delta > 0:
        # state[in_arc] is +/-1, giving the direction of augmentation.
        val = state[in_arc] * leaving_arc_data.delta
        flow[in_arc] += val
        # Walk from the source endpoint up to the join node.
        u = source[in_arc]
        while u != join:
            if forward[u]:
                flow[pred[u]] -= val
            else:
                flow[pred[u]] += val

            u = parent[u]

        # Walk from the target endpoint up to the join node (signs flipped).
        u = target[in_arc]
        while u != join:
            if forward[u]:
                flow[pred[u]] += val
            else:
                flow[pred[u]] -= val

            u = parent[u]

    # Update the state of the entering and leaving arcs
    if leaving_arc_data.change:
        state[in_arc] = ArcState.STATE_TREE
        if flow[pred[leaving_arc_data.u_out]] == 0:
            state[pred[leaving_arc_data.u_out]] = ArcState.STATE_LOWER
        else:
            state[pred[leaving_arc_data.u_out]] = ArcState.STATE_UPPER
    else:
        # Entering arc equals leaving arc: it just flips bound side.
        state[in_arc] = -state[in_arc]
|
||
|
|
||
|
|
||
|
# Update the tree structure
|
||
|
# locals: u, w, old_rev_thread, old_succ_num, old_last_succ, tmp_sc, tmp_ls
|
||
|
# more locals: up_limit_in, up_limit_out, _dirty_revs
|
||
|
# modifies: v_out, _thread, _rev_thread, _parent, _last_succ,
|
||
|
# modifies: _pred, _forward, _succ_num
|
||
|
@numba.njit(
    locals={
        "u": numba.int32,
        "w": numba.int32,
        "u_in": numba.uint16,
        "u_out": numba.uint16,
        "v_in": numba.uint16,
        "right": numba.uint16,
        "stem": numba.uint16,
        "new_stem": numba.uint16,
        "par_stem": numba.uint16,
        "in_arc": numba.uint32,
    }
)
def update_spanning_tree(spanning_tree, leaving_arc_data, join, in_arc, source):
    """Re-root the subtree cut off by the leaving arc onto the entering arc.

    Updates every spanning-tree index array in place: ``parent``, ``pred``,
    ``forward``, ``thread`` / ``rev_thread`` (the preorder traversal and its
    inverse), ``succ_num`` (subtree sizes) and ``last_succ``.  This is a
    direct port of LEMON's ``updateTreeStructure``; statement order is
    load-bearing throughout.
    """

    parent = spanning_tree.parent
    thread = spanning_tree.thread
    rev_thread = spanning_tree.rev_thread
    succ_num = spanning_tree.succ_num
    last_succ = spanning_tree.last_succ
    forward = spanning_tree.forward
    pred = spanning_tree.pred

    u_out = leaving_arc_data.u_out
    u_in = leaving_arc_data.u_in
    v_in = leaving_arc_data.v_in

    # Snapshot values at u_out before they are overwritten below.
    old_rev_thread = rev_thread[u_out]
    old_succ_num = succ_num[u_out]
    old_last_succ = last_succ[u_out]
    v_out = parent[u_out]

    u = last_succ[u_in]  # the last successor of u_in
    right = thread[u]  # the node after it

    # Handle the case when old_rev_thread equals to v_in
    # (it also means that join and v_out coincide)
    if old_rev_thread == v_in:
        last = thread[last_succ[u_out]]
    else:
        last = thread[v_in]

    # Update _thread and _parent along the stem nodes (i.e. the nodes
    # between u_in and u_out, whose parent have to be changed)
    thread[v_in] = stem = u_in
    # Nodes whose thread entry changed; rev_thread is fixed up afterwards.
    dirty_revs = []
    dirty_revs.append(v_in)
    par_stem = v_in
    while stem != u_out:
        # Insert the next stem node into the thread list
        new_stem = parent[stem]
        thread[u] = new_stem
        dirty_revs.append(u)

        # Remove the subtree of stem from the thread list
        w = rev_thread[stem]
        thread[w] = right
        rev_thread[right] = w

        # Change the parent node and shift stem nodes
        parent[stem] = par_stem
        par_stem = stem
        stem = new_stem

        # Update u and right
        if last_succ[stem] == last_succ[par_stem]:
            u = rev_thread[par_stem]
        else:
            u = last_succ[stem]

        right = thread[u]

    parent[u_out] = par_stem
    thread[u] = last
    rev_thread[last] = u
    last_succ[u_out] = u

    # Remove the subtree of u_out from the thread list except for
    # the case when old_rev_thread equals to v_in
    # (it also means that join and v_out coincide)
    if old_rev_thread != v_in:
        thread[old_rev_thread] = right
        rev_thread[right] = old_rev_thread

    # Update _rev_thread using the new _thread values
    for i in range(len(dirty_revs)):
        u = dirty_revs[i]
        rev_thread[thread[u]] = u

    # Update _pred, _forward, _last_succ and _succ_num for the
    # stem nodes from u_out to u_in
    tmp_sc = 0
    tmp_ls = last_succ[u_out]
    u = u_out
    while u != u_in:
        w = parent[u]
        pred[u] = pred[w]
        # Arc direction flips because parent/child roles are swapped.
        forward[u] = not forward[w]
        tmp_sc += succ_num[u] - succ_num[w]
        succ_num[u] = tmp_sc
        last_succ[w] = tmp_ls
        u = w

    pred[u_in] = in_arc
    forward[u_in] = u_in == source[in_arc]
    succ_num[u_in] = old_succ_num

    # Set limits for updating _last_succ form v_in and v_out
    # towards the root
    up_limit_in = -1
    up_limit_out = -1
    if last_succ[join] == v_in:
        up_limit_out = join
    else:
        up_limit_in = join

    # Update _last_succ from v_in towards the root
    u = v_in
    while u != up_limit_in and last_succ[u] == v_in:
        last_succ[u] = last_succ[u_out]
        u = parent[u]

    # Update _last_succ from v_out towards the root
    if join != old_rev_thread and v_in != old_rev_thread:
        u = v_out
        while u != up_limit_out and last_succ[u] == old_last_succ:
            last_succ[u] = old_rev_thread
            u = parent[u]

    else:
        u = v_out
        while u != up_limit_out and last_succ[u] == old_last_succ:
            last_succ[u] = last_succ[u_out]
            u = parent[u]

    # Update _succ_num from v_in to join
    u = v_in
    while u != join:
        succ_num[u] += old_succ_num
        u = parent[u]

    # Update _succ_num from v_out to join
    u = v_out
    while u != join:
        succ_num[u] -= old_succ_num
        u = parent[u]
|
||
|
|
||
|
|
||
|
# Update node potentials after a pivot.
# modifies: pi
@numba.njit(
    fastmath=True,
    inline="always",
    locals={"u": numba.uint16, "u_in": numba.uint16, "v_in": numba.uint16},
)
def update_potential(leaving_arc_data, pi, cost, spanning_tree):
    """Shift the potentials of the relocated subtree by a constant.

    After the pivot only the subtree now hanging from ``u_in`` needs new
    potentials; each is shifted by the (signed) reduced cost of the new
    tree arc connecting ``u_in`` to ``v_in``.
    """
    thread = spanning_tree.thread
    pred = spanning_tree.pred
    forward = spanning_tree.forward
    last_succ = spanning_tree.last_succ

    u_in = leaving_arc_data.u_in
    v_in = leaving_arc_data.v_in

    # Sign of the shift depends on the orientation of the new tree arc.
    if forward[u_in]:
        shift = pi[v_in] - pi[u_in] - cost[pred[u_in]]
    else:
        shift = pi[v_in] - pi[u_in] + cost[pred[u_in]]

    # The moved subtree is a contiguous segment of the thread (preorder)
    # list beginning at u_in and ending after its last successor.
    stop = thread[last_succ[u_in]]
    u = u_in
    while u != stop:
        pi[u] += shift
        u = thread[u]
|
||
|
|
||
|
|
||
|
# If we have mixed arcs (for better random access)
|
||
|
# we need a more complicated function to get the ID of a given arc
|
||
|
@numba.njit()
def arc_id(arc, graph):
    """Translate an external arc number into its internal storage index.

    Without arc mixing the mapping is a simple reversal; with mixing the
    arcs are interleaved in subsequences (see allocate_graph_structures),
    so the index must be reconstructed from the subsequence layout.
    """
    k = graph.n_arcs - arc - 1
    if not graph.use_arc_mixing:
        return k

    # smallv is 1 when k falls in the region of "small" (shorter)
    # subsequences, 0 in the "big" region.
    smallv = (k > graph.num_total_big_subsequence_numbers) & 1
    k -= graph.num_total_big_subsequence_numbers * smallv
    subseq_len = graph.subsequence_length - smallv
    subseq_num = (k // subseq_len) + graph.num_big_subsequences * smallv
    subseq_offset = (k % subseq_len) * graph.mixing_coeff

    return subseq_offset + subseq_num
|
||
|
|
||
|
|
||
|
# Heuristic initial pivots
|
||
|
# locals: curr, total, supply_nodes, demand_nodes, u
|
||
|
# modifies:
|
||
|
@numba.njit(locals={"i": numba.uint16})
def construct_initial_pivots(graph, node_arc_data, spanning_tree):
    """Perform heuristic initial pivots to warm-start the simplex search.

    Builds a small set of promising arcs (per LEMON's heuristic: either a
    reverse path search when there is a single source/sink pair, or the
    cheapest incoming arc of every demand node) and pivots on each one
    whose reduced cost is negative.

    Returns ``(bounded, in_arc)`` where ``bounded`` is False if an
    unbounded direction was detected during the pivots.

    Fix over the original: a dead assignment ``u = stack[-1]`` immediately
    before ``v = stack[-1]`` (a leftover from translating the C++
    ``stack.back()`` idiom) has been removed; ``u`` was always overwritten
    before use.
    """

    cost = node_arc_data.cost
    pi = node_arc_data.pi
    source = node_arc_data.source
    target = node_arc_data.target
    supply = node_arc_data.supply

    n1 = graph.n
    n2 = graph.m
    n_nodes = graph.n_nodes
    n_arcs = graph.n_arcs

    state = spanning_tree.state

    total = 0
    supply_nodes = []
    demand_nodes = []

    # Partition nodes by the sign of their supply (note the reversed
    # node-id indexing used throughout this implementation).
    for u in range(n_nodes):
        curr = supply[n_nodes - u - 1]  # _node_id(u)
        if curr > 0:
            total += curr
            supply_nodes.append(u)
        elif curr < 0:
            demand_nodes.append(u)

    arc_vector = []
    if len(supply_nodes) == 1 and len(demand_nodes) == 1:
        # Perform a reverse graph search from the sink to the source
        reached = np.zeros(n_nodes, dtype=np.bool_)
        s = supply_nodes[0]
        t = demand_nodes[0]
        stack = []
        reached[t] = True
        stack.append(t)
        while len(stack) > 0:
            v = stack[-1]
            stack.pop(-1)
            if v == s:
                break

            first_arc = n_arcs + v - n_nodes if v >= n1 else -1
            for a in range(first_arc, -1, -n2):
                u = a // n2
                if reached[u]:
                    continue

                j = arc_id(a, graph)
                if INFINITY >= total:
                    arc_vector.append(j)
                    reached[u] = True
                    stack.append(u)

    else:
        # Find the min. cost incoming arc for each demand node
        for i in range(len(demand_nodes)):
            v = demand_nodes[i]
            c = MAX
            min_cost = MAX
            min_arc = INVALID
            first_arc = n_arcs + v - n_nodes if v >= n1 else -1
            for a in range(first_arc, -1, -n2):
                c = cost[arc_id(a, graph)]
                if c < min_cost:
                    min_cost = c
                    min_arc = a

            if min_arc != INVALID:
                arc_vector.append(arc_id(min_arc, graph))

    # Perform heuristic initial pivots
    in_arc = -1
    for i in range(len(arc_vector)):
        in_arc = arc_vector[i]
        # Skip arcs whose reduced cost is not improving ("bad arcs").
        if (
            state[in_arc] * (cost[in_arc] + pi[source[in_arc]] - pi[target[in_arc]])
            >= 0
        ):
            continue

        join = find_join_node(
            source, target, spanning_tree.succ_num, spanning_tree.parent, in_arc
        )
        leaving_arc_data = find_leaving_arc(join, in_arc, node_arc_data, spanning_tree)
        if leaving_arc_data.delta >= MAX:
            # Unbounded augmenting direction.
            return False, in_arc

        update_flow(join, leaving_arc_data, node_arc_data, spanning_tree, in_arc)
        if leaving_arc_data.change:
            update_spanning_tree(spanning_tree, leaving_arc_data, join, in_arc, source)
            update_potential(leaving_arc_data, pi, cost, spanning_tree)

    return True, in_arc
|
||
|
|
||
|
|
||
|
@numba.njit()
def allocate_graph_structures(n, m, use_arc_mixing=True):
    """Allocate all arrays for an n-by-m bipartite transport problem.

    Returns ``(node_arc_data, spanning_tree, graph)`` namedtuples sized for
    ``n + m`` nodes plus one artificial root, and ``n * m`` arcs plus the
    artificial arcs added later by ``initialize_graph_structures``.
    """

    # Size bipartite graph
    n_nodes = n + m
    n_arcs = n * m

    # Resize vectors: one extra node for the artificial root, and
    # 2 * n_nodes extra slots for artificial arcs.
    all_node_num = n_nodes + 1
    max_arc_num = n_arcs + 2 * n_nodes
    root = n_nodes

    # NOTE: uint16 node indices cap the problem at 65535 nodes —
    # consistent with the uint16 numba locals used elsewhere in this file.
    source = np.zeros(max_arc_num, dtype=np.uint16)
    target = np.zeros(max_arc_num, dtype=np.uint16)
    cost = np.ones(max_arc_num, dtype=np.float64)
    supply = np.zeros(all_node_num, dtype=np.float64)
    flow = np.zeros(max_arc_num, dtype=np.float64)
    pi = np.zeros(all_node_num, dtype=np.float64)

    parent = np.zeros(all_node_num, dtype=np.int32)
    pred = np.zeros(all_node_num, dtype=np.int32)
    forward = np.zeros(all_node_num, dtype=np.bool_)
    thread = np.zeros(all_node_num, dtype=np.int32)
    rev_thread = np.zeros(all_node_num, dtype=np.int32)
    succ_num = np.zeros(all_node_num, dtype=np.int32)
    last_succ = np.zeros(all_node_num, dtype=np.int32)
    state = np.zeros(max_arc_num, dtype=np.int8)

    if use_arc_mixing:
        # Store the arcs in a mixed order: stride through the array with
        # step k so that consecutive arc ids land far apart (better
        # behavior for the block pivot search).
        k = max(np.int32(np.sqrt(n_arcs)), 10)
        mixing_coeff = k
        subsequence_length = (n_arcs // mixing_coeff) + 1
        num_big_subsequences = n_arcs % mixing_coeff
        num_total_big_subsequence_numbers = subsequence_length * num_big_subsequences

        i = 0
        j = 0
        for a in range(n_arcs - 1, -1, -1):
            source[i] = n_nodes - (a // m) - 1
            target[i] = n_nodes - ((a % m) + n) - 1
            i += k
            if i >= n_arcs:
                # Wrap around, starting one slot further in each pass.
                j += 1
                i = j

    else:
        # dummy values
        subsequence_length = 0
        mixing_coeff = 0
        num_big_subsequences = 0
        num_total_big_subsequence_numbers = 0
        # Store the arcs in the original order
        i = 0
        for a in range(n_arcs - 1, -1, -1):
            source[i] = n_nodes - (a // m) - 1
            target[i] = n_nodes - ((a % m) + n) - 1
            i += 1

    node_arc_data = NodeArcData(cost, supply, flow, pi, source, target)
    spanning_tree = SpanningTree(
        parent, pred, thread, rev_thread, succ_num, last_succ, forward, state, root
    )
    graph = DiGraph(
        n_nodes,
        n_arcs,
        n,
        m,
        use_arc_mixing,
        num_total_big_subsequence_numbers,
        subsequence_length,
        num_big_subsequences,
        mixing_coeff,
    )

    return node_arc_data, spanning_tree, graph
|
||
|
|
||
|
|
||
|
@numba.njit(locals={"u": numba.uint16, "e": numba.uint32})
def initialize_graph_structures(graph, node_arc_data, spanning_tree):
    """Reset flow/state and build the initial artificial spanning tree.

    Connects every node to an artificial root via artificial arcs with a
    prohibitively large cost, giving a trivially feasible starting basis.
    Returns False when the problem is degenerate (no nodes) or the supplies
    do not balance within ``NET_SUPPLY_ERROR_TOLERANCE``.
    """

    n_nodes = graph.n_nodes
    n_arcs = graph.n_arcs

    # unpack arrays
    cost = node_arc_data.cost
    supply = node_arc_data.supply
    flow = node_arc_data.flow
    pi = node_arc_data.pi
    source = node_arc_data.source
    target = node_arc_data.target

    parent = spanning_tree.parent
    pred = spanning_tree.pred
    thread = spanning_tree.thread
    rev_thread = spanning_tree.rev_thread
    succ_num = spanning_tree.succ_num
    last_succ = spanning_tree.last_succ
    forward = spanning_tree.forward
    state = spanning_tree.state

    if n_nodes == 0:
        return False

    # Check the sum of supply values
    net_supply = 0
    for i in range(n_nodes):
        net_supply += supply[i]

    if np.fabs(net_supply) > NET_SUPPLY_ERROR_TOLERANCE:
        return False

    # Fix using doubles
    # Initialize artificial cost: strictly larger than any path of real
    # arc costs, so artificial arcs never stay in an optimal basis.
    artificial_cost = 0.0
    for i in range(n_arcs):
        if cost[i] > artificial_cost:
            artificial_cost = cost[i]
        # reset flow and state vectors
        if flow[i] != 0:
            flow[i] = 0
        state[i] = ArcState.STATE_LOWER

    artificial_cost = (artificial_cost + 1) * n_nodes

    # Set data for the artificial root node
    root = n_nodes
    parent[root] = -1
    pred[root] = -1
    thread[root] = 0
    rev_thread[0] = root
    succ_num[root] = n_nodes + 1
    last_succ[root] = root - 1
    # Root absorbs any residual imbalance (zero when supplies balance).
    supply[root] = -net_supply
    pi[root] = 0

    # Add artificial arcs and initialize the spanning tree data structure
    # EQ supply constraints
    e = n_arcs
    for u in range(n_nodes):
        parent[u] = root
        pred[u] = e
        thread[u] = u + 1
        rev_thread[u + 1] = u
        succ_num[u] = 1
        last_succ[u] = u
        state[e] = ArcState.STATE_TREE
        if supply[u] >= 0:
            # Supply node: artificial arc points toward the root.
            forward[u] = True
            pi[u] = 0
            source[e] = u
            target[e] = root
            flow[e] = supply[u]
            cost[e] = 0
        else:
            # Demand node: artificial arc points from the root.
            forward[u] = False
            pi[u] = artificial_cost
            source[e] = root
            target[e] = u
            flow[e] = -supply[u]
            cost[e] = artificial_cost
        e += 1

    return True
|
||
|
|
||
|
|
||
|
@numba.njit()
def initialize_supply(left_node_supply, right_node_supply, graph, supply):
    """Copy the left/right marginals into the internal supply array.

    Node ids are stored in reverse order internally, so node ``node`` maps
    to slot ``n_nodes - node - 1``.
    """
    n_left = graph.n
    last = graph.n_nodes - 1
    for node in range(graph.n_nodes):
        if node < n_left:
            supply[last - node] = left_node_supply[node]
        else:
            supply[last - node] = right_node_supply[node - n_left]
|
||
|
|
||
|
|
||
|
@numba.njit(inline="always")
def set_cost(arc, cost_val, cost, graph):
    """Store ``cost_val`` for ``arc``, translating to its internal index."""
    idx = arc_id(arc, graph)
    cost[idx] = cost_val
|
||
|
|
||
|
|
||
|
@numba.njit(locals={"i": numba.uint16, "j": numba.uint16})
def initialize_cost(cost_matrix, graph, cost):
    """Load a dense cost matrix into the flat internal cost array."""
    n_rows = cost_matrix.shape[0]
    n_cols = cost_matrix.shape[1]
    for i in range(n_rows):
        row_base = i * n_cols
        for j in range(n_cols):
            set_cost(row_base + j, cost_matrix[i, j], cost, graph)
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, locals={"i": numba.uint32})
def total_cost(flow, cost):
    """Return the objective value ``sum_i flow[i] * cost[i]``."""
    result = 0.0
    n = flow.shape[0]
    for i in range(n):
        result += flow[i] * cost[i]
    return result
|
||
|
|
||
|
|
||
|
@numba.njit(nogil=True)
def network_simplex_core(node_arc_data, spanning_tree, graph, max_iter):
    """Run the network simplex algorithm to optimality (or iteration cap).

    Expects structures prepared by ``allocate_graph_structures`` /
    ``initialize_graph_structures``.  Returns a ``ProblemStatus``; on
    OPTIMAL, ``node_arc_data.flow`` holds the optimal transport plan and
    ``node_arc_data.pi`` the (shifted) dual potentials.
    """

    # pivot_block = PivotBlock(
    #     max(np.int32(np.sqrt(graph.n_arcs)), 10),
    #     np.zeros(1, dtype=np.int32),
    #     graph.n_arcs,
    # )
    pivot_block_size = max(np.int32(np.sqrt(graph.n_arcs)), 10)
    search_arc_num = graph.n_arcs
    solution_status = ProblemStatus.OPTIMAL

    # Perform heuristic initial pivots
    bounded, in_arc = construct_initial_pivots(graph, node_arc_data, spanning_tree)
    if not bounded:
        return ProblemStatus.UNBOUNDED

    iter_number = 0
    # pivot.setDantzig(true);
    # Execute the Network Simplex algorithm
    in_arc, pivot_next_arc = find_entering_arc(
        pivot_block_size, 0, search_arc_num, spanning_tree.state, node_arc_data, in_arc
    )
    # in_arc < 0 signals that no improving arc exists (optimality).
    while in_arc >= 0:
        iter_number += 1
        if max_iter > 0 and iter_number >= max_iter:
            solution_status = ProblemStatus.MAX_ITER_REACHED
            break

        join = find_join_node(
            node_arc_data.source,
            node_arc_data.target,
            spanning_tree.succ_num,
            spanning_tree.parent,
            in_arc,
        )
        leaving_arc_data = find_leaving_arc(join, in_arc, node_arc_data, spanning_tree)
        if leaving_arc_data.delta >= MAX:
            return ProblemStatus.UNBOUNDED

        update_flow(join, leaving_arc_data, node_arc_data, spanning_tree, in_arc)

        # The tree/potential updates are only needed when the entering arc
        # actually replaced a different leaving arc.
        if leaving_arc_data.change:
            update_spanning_tree(
                spanning_tree, leaving_arc_data, join, in_arc, node_arc_data.source
            )
            update_potential(
                leaving_arc_data, node_arc_data.pi, node_arc_data.cost, spanning_tree
            )

        in_arc, pivot_next_arc = find_entering_arc(
            pivot_block_size,
            pivot_next_arc,
            search_arc_num,
            spanning_tree.state,
            node_arc_data,
            in_arc,
        )

    flow = node_arc_data.flow
    pi = node_arc_data.pi

    # Check feasibility: any nonnegligible flow left on an artificial arc
    # means the original problem was infeasible.
    if solution_status == ProblemStatus.OPTIMAL:
        for e in range(graph.n_arcs, graph.n_arcs + graph.n_nodes):
            if flow[e] != 0:
                if np.abs(flow[e]) > EPSILON:
                    return ProblemStatus.INFEASIBLE
                else:
                    flow[e] = 0

    # Shift potentials to meet the requirements of the GEQ/LEQ type
    # optimality conditions
    max_pot = -INFINITY
    for i in range(graph.n_nodes):
        if pi[i] > max_pot:
            max_pot = pi[i]
    if max_pot > 0:
        for i in range(graph.n_nodes):
            pi[i] -= max_pot

    return solution_status
|
||
|
|
||
|
|
||
|
#######################################################
|
||
|
# SINKHORN distances in various variations
|
||
|
#######################################################
|
||
|
|
||
|
|
||
|
@numba.njit(
    fastmath=True,
    parallel=True,
    locals={"diff": numba.float32, "result": numba.float32},
    cache=False,
)
def right_marginal_error(u, K, v, y):
    """L2 distance between the target marginal ``y`` and ``(u @ K) * v``."""
    row_marginal = u @ K
    result = 0.0
    # Parallel reduction over the marginal entries.
    for i in numba.prange(row_marginal.shape[0]):
        diff = y[i] - row_marginal[i] * v[i]
        result += diff * diff
    return np.sqrt(result)
|
||
|
|
||
|
|
||
|
@numba.njit(
    fastmath=True,
    parallel=True,
    locals={"diff": numba.float32, "result": numba.float32},
    cache=False,
)
def right_marginal_error_batch(u, K, v, y):
    """Batched L2 marginal error: ``y`` holds one target per batch column."""
    marginal = K.T @ u
    result = 0.0
    for i in numba.prange(marginal.shape[0]):
        for j in range(marginal.shape[1]):
            # Note the transposed indexing into y relative to the marginal.
            diff = y[j, i] - marginal[i, j] * v[i, j]
            result += diff * diff
    return np.sqrt(result)
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, parallel=True, cache=False)
def transport_plan(K, u, v):
    """Assemble the transport plan ``diag(u) @ K @ diag(v)``."""
    n_rows, n_cols = K.shape
    plan = np.empty_like(K)
    for i in numba.prange(n_rows):
        u_i = u[i]
        for j in range(n_cols):
            plan[i, j] = u_i * K[i, j] * v[j]

    return plan
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, parallel=True, locals={"result": numba.float32}, cache=False)
def relative_change_in_plan(old_u, old_v, new_u, new_v):
    """Mean relative elementwise change between two scaling-vector pairs.

    Compares the rank-one factors ``old_u ⊗ old_v`` and ``new_u ⊗ new_v``
    without materializing the full plans.
    """
    n_rows = old_u.shape[0]
    n_cols = old_v.shape[0]
    result = 0.0
    for i in numba.prange(n_rows):
        for j in range(n_cols):
            previous = old_u[i] * old_v[j]
            current = new_u[i] * new_v[j]
            result += np.float32(np.abs(previous - current) / previous)

    return result / (n_rows * n_cols)
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, parallel=True, cache=False)
def precompute_K_prime(K, x):
    """Return ``diag(1/x) @ K``, substituting INFINITY where ``x[i] <= 0``."""
    n_rows, n_cols = K.shape
    result = np.empty_like(K)
    for i in numba.prange(n_rows):
        # Guard against division by a non-positive marginal entry.
        scale = 1.0 / x[i] if x[i] > 0.0 else INFINITY
        for j in range(n_cols):
            result[i, j] = scale * K[i, j]

    return result
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, parallel=True, cache=False)
def K_from_cost(cost, regularization):
    """Gibbs kernel ``K = exp(-cost / regularization)``, elementwise."""
    n_rows, n_cols = cost.shape
    result = np.empty_like(cost)
    for i in numba.prange(n_rows):
        for j in range(n_cols):
            result[i, j] = np.exp(-(cost[i, j] / regularization))

    return result
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, cache=True)
def sinkhorn_iterations(
    x, y, u, v, K, max_iter=1000, error_tolerance=1e-9, change_tolerance=1e-9
):
    """Run Sinkhorn scaling iterations on the kernel ``K``.

    Alternately rescales ``v`` against the target marginal ``y`` and ``u``
    against the source marginal (folded into ``K_prime``).  Stops on
    numerical blow-up (non-finite scalings — the last finite pair is
    returned), on small relative change of the implied plan, or on small
    right-marginal error.  Returns the final ``(u, v)``.
    """
    K_prime = precompute_K_prime(K, x)

    prev_u = u
    prev_v = v

    for iteration in range(max_iter):

        next_v = y / (K.T @ u)

        # Bail out before committing a non-finite update: u, v keep their
        # last finite values.
        if np.any(~np.isfinite(next_v)):
            break

        next_u = 1.0 / (K_prime @ next_v)

        if np.any(~np.isfinite(next_u)):
            break

        u = next_u
        v = next_v

        if iteration % 20 == 0:
            # Check if values in plan have changed significantly since last 20 iterations
            relative_change = relative_change_in_plan(prev_u, prev_v, next_u, next_v)
            if relative_change <= change_tolerance:
                break

            prev_u = u
            prev_v = v

        if iteration % 10 == 0:
            # Check if right marginal error is less than tolerance every 10 iterations
            err = right_marginal_error(u, K, v, y)
            if err <= error_tolerance:
                break

    return u, v
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, cache=True)
def sinkhorn_iterations_batch(x, y, u, v, K, max_iter=1000, error_tolerance=1e-9):
    """Batched Sinkhorn scaling iterations (one target per column of ``y``).

    Same update scheme as ``sinkhorn_iterations`` but with matrix-valued
    ``u``/``v`` and no relative-change stopping criterion.  Returns the
    final ``(u, v)``.
    """
    K_prime = precompute_K_prime(K, x)

    for iteration in range(max_iter):

        next_v = y.T / (K.T @ u)

        # Stop before committing a non-finite update; u, v keep their last
        # finite values.
        if np.any(~np.isfinite(next_v)):
            break

        next_u = 1.0 / (K_prime @ next_v)

        if np.any(~np.isfinite(next_u)):
            break

        u = next_u
        v = next_v

        if iteration % 10 == 0:
            # Check if right marginal error is less than tolerance every 10 iterations
            err = right_marginal_error_batch(u, K, v, y)
            if err <= error_tolerance:
                break

    return u, v
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, cache=True)
def sinkhorn_transport_plan(
    x,
    y,
    cost=_dummy_cost,
    regularization=1.0,
    max_iter=1000,
    error_tolerance=1e-9,
    change_tolerance=1e-9,
):
    """Entropically regularized optimal transport plan between ``x`` and ``y``.

    Builds the Gibbs kernel from ``cost``, runs Sinkhorn iterations from
    uniform scalings, and returns the resulting plan
    ``diag(u) @ K @ diag(v)``.
    """
    n_x = x.shape[0]
    n_y = y.shape[0]

    # Uniform initial scaling vectors.
    u = np.full(n_x, 1.0 / n_x, dtype=cost.dtype)
    v = np.full(n_y, 1.0 / n_y, dtype=cost.dtype)

    K = K_from_cost(cost, regularization)
    u, v = sinkhorn_iterations(
        x,
        y,
        u,
        v,
        K,
        max_iter=max_iter,
        error_tolerance=error_tolerance,
        change_tolerance=change_tolerance,
    )

    return transport_plan(K, u, v)
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, cache=True)
def sinkhorn_distance(x, y, cost=_dummy_cost, regularization=1.0):
    """Sinkhorn distance ``<plan, cost>`` between histograms ``x`` and ``y``.

    Computes the regularized transport plan and returns its inner product
    with the cost matrix.

    Fix over the original: the local holding the plan was named
    ``transport_plan``, shadowing the module-level ``transport_plan``
    function; renamed to ``plan``.
    """
    plan = sinkhorn_transport_plan(
        x, y, cost=cost, regularization=regularization
    )
    dim_i = plan.shape[0]
    dim_j = plan.shape[1]
    result = 0.0
    for i in range(dim_i):
        for j in range(dim_j):
            result += plan[i, j] * cost[i, j]

    return result
|
||
|
|
||
|
|
||
|
@numba.njit(fastmath=True, parallel=True, cache=False)
def sinkhorn_distance_batch(x, y, cost=_dummy_cost, regularization=1.0):
    """Sinkhorn distances between ``x`` and each column of ``y``.

    Runs batched Sinkhorn iterations and accumulates, per batch column,
    the inner product of the implied plan with the cost matrix — without
    materializing any plan.
    """
    n_x = x.shape[0]
    n_y = y.shape[0]
    batch_size = y.shape[1]

    # Uniform initial scalings, one column per batch element.
    u = np.full((n_x, batch_size), 1.0 / n_x, dtype=cost.dtype)
    v = np.full((n_y, batch_size), 1.0 / n_y, dtype=cost.dtype)

    K = K_from_cost(cost, regularization)
    u, v = sinkhorn_iterations_batch(
        x,
        y,
        u,
        v,
        K,
    )

    n_rows = K.shape[0]
    n_cols = K.shape[1]
    result = np.zeros(batch_size)
    for i in range(n_rows):
        for j in range(n_cols):
            # K[i, j] * cost[i, j] is shared across the batch; hoist it.
            weighted = K[i, j] * cost[i, j]
            for batch in range(batch_size):
                result[batch] += u[i, batch] * weighted * v[j, batch]

    return result
|
||
|
|
||
|
|
||
|
def make_fixed_cost_sinkhorn_distance(cost, regularization=1.0):
    """Return a jitted Sinkhorn-distance function for a fixed cost matrix.

    The Gibbs kernel ``K`` and the dimensions are computed once here and
    closed over by the returned function, so repeated calls avoid
    re-exponentiating the cost matrix.

    Fix over the original: the closure previously recomputed
    ``K = K_from_cost(cost, regularization)`` on every call, shadowing the
    precomputed kernel and defeating the purpose of this factory; it now
    uses the closed-over ``K``.
    """

    K = K_from_cost(cost, regularization)
    dim_x = K.shape[0]
    dim_y = K.shape[1]

    @numba.njit(fastmath=True)
    def closure(x, y):
        # Uniform initial scaling vectors.
        u = np.full(dim_x, 1.0 / dim_x, dtype=cost.dtype)
        v = np.full(dim_y, 1.0 / dim_y, dtype=cost.dtype)

        u, v = sinkhorn_iterations(
            x,
            y,
            u,
            v,
            K,
        )

        current_plan = transport_plan(K, u, v)

        result = 0.0
        for i in range(dim_x):
            for j in range(dim_y):
                result += current_plan[i, j] * cost[i, j]

        return result

    return closure
|