ai-content-maker/.venv/Lib/site-packages/umap/parametric_umap.py

1022 lines
34 KiB
Python

import numpy as np
from umap import UMAP
from warnings import warn, catch_warnings, filterwarnings
from numba import TypingError
import os
from umap.spectral import spectral_layout
from sklearn.utils import check_random_state
import codecs, pickle
from sklearn.neighbors import KDTree
try:
# Used for tf.data.
import tensorflow as tf
except ImportError:
warn(
"""The umap.parametric_umap package requires Tensorflow > 2.0 to be installed.
You can install Tensorflow at https://www.tensorflow.org/install
or you can install the CPU version of Tensorflow using
pip install umap-learn[parametric_umap]
"""
)
raise ImportError("umap.parametric_umap requires Tensorflow >= 2.0") from None
try:
import keras
from keras import ops
except ImportError:
warn(
"""The umap.parametric_umap package requires Keras >= 3 to be installed."""
)
raise ImportError("umap.parametric_umap requires Keras") from None
class ParametricUMAP(UMAP):
    """UMAP estimator whose embedding is produced by a trained Keras encoder.

    The fuzzy graph is built by the parent ``UMAP`` class; this subclass
    overrides ``_fit_embed_data`` so that the layout step trains a
    ``UMAPModel`` (encoder plus optional decoder) on batches of graph edges.
    """

    def __init__(
        self,
        batch_size=None,
        dims=None,
        encoder=None,
        decoder=None,
        parametric_reconstruction=False,
        parametric_reconstruction_loss_fcn=None,
        parametric_reconstruction_loss_weight=1.0,
        autoencoder_loss=False,
        reconstruction_validation=None,
        global_correlation_loss_weight=0,
        keras_fit_kwargs=None,
        **kwargs
    ):
        """
        Parametric UMAP subclassing UMAP-learn, based on keras/tensorflow.
        There is also a non-parametric implementation contained within to compare
        with the base non-parametric implementation.

        Parameters
        ----------
        batch_size : int, optional
            size of batch used for batch training, by default None
        dims : tuple, optional
            dimensionality of data, if not flat (e.g. (32, 32, 3) images for a
            ConvNet), by default None
        encoder : keras.Sequential, optional
            The encoder Keras network
        decoder : keras.Sequential, optional
            the decoder Keras network
        parametric_reconstruction : bool, optional
            Whether the decoder is parametric or non-parametric, by default False
        parametric_reconstruction_loss_fcn : callable, optional
            What loss function to use for parametric reconstruction,
            by default keras.losses.BinaryCrossentropy
        parametric_reconstruction_loss_weight : float, optional
            How to weight the parametric reconstruction loss relative to umap
            loss, by default 1.0
        autoencoder_loss : bool, optional
            If True the reconstruction loss is back-propagated into the encoder;
            if False the embedding is passed through ``stop_gradient`` before
            it is decoded, by default False
        reconstruction_validation : array, optional
            validation X data for reconstruction loss, by default None
        global_correlation_loss_weight : float, optional
            Whether to additionally train on correlation of global pairwise
            relationships (>0), by default 0
        keras_fit_kwargs : dict, optional
            additional arguments for model.fit (like callbacks); ``None``
            (the default) means no extra arguments
        **kwargs
            forwarded unchanged to ``UMAP.__init__``

        Raises
        ------
        ValueError
            If a supplied encoder's output width differs from ``n_components``.
        """
        super().__init__(**kwargs)

        # add to network
        self.dims = dims  # if this is an image, we should reshape for network
        self.encoder = encoder  # neural network used for embedding
        self.decoder = decoder  # neural network used for decoding
        self.parametric_reconstruction = parametric_reconstruction
        self.parametric_reconstruction_loss_weight = (
            parametric_reconstruction_loss_weight
        )
        self.parametric_reconstruction_loss_fcn = parametric_reconstruction_loss_fcn
        self.autoencoder_loss = autoencoder_loss
        self.batch_size = batch_size
        self.loss_report_frequency = 10
        self.global_correlation_loss_weight = global_correlation_loss_weight
        self.reconstruction_validation = (
            reconstruction_validation  # holdout data for reconstruction acc
        )
        # A fresh dict per instance: a literal ``{}`` default would be one
        # object shared (and mutable) across every ParametricUMAP instance.
        self.keras_fit_kwargs = {} if keras_fit_kwargs is None else keras_fit_kwargs
        self.parametric_model = None

        # How many epochs to train for
        # (different than n_epochs which is specific to each sample).
        # model.fit is run for loss_report_frequency * n_training_epochs epochs.
        self.n_training_epochs = 1

        # Adam is better for parametric_embedding. Use gradient clipping by value.
        # NOTE(review): _define_model does not forward this optimizer, so
        # UMAPModel falls back to its own identical Adam default.
        self.optimizer = keras.optimizers.Adam(1e-3, clipvalue=4.0)

        if self.encoder is not None:
            encoder_output_dim = encoder.outputs[0].shape[-1]
            if encoder_output_dim != self.n_components:
                raise ValueError(
                    "Dimensionality of embedder network output ({}) does "
                    "not match n_components ({})".format(
                        encoder_output_dim, self.n_components
                    )
                )

    def _select_graph_input(self, X, precomputed_distances):
        """Return the array the parent class should build the graph from.

        For ``metric == "precomputed"`` the raw data ``X`` is stashed on
        ``self._X`` (the network is trained on it later) and the distance
        matrix is returned; otherwise ``X`` itself is returned.

        Raises
        ------
        ValueError
            If the metric is precomputed but no distances were supplied.
        """
        if self.metric != "precomputed":
            return X
        if precomputed_distances is None:
            raise ValueError(
                "Precomputed distances must be supplied if metric "
                "is precomputed."
            )
        # prepare X for training the network
        self._X = X
        return precomputed_distances

    def fit(self, X, y=None, precomputed_distances=None):
        """Fit the model; see ``UMAP.fit``.

        Parameters
        ----------
        X : array
            Data the encoder is trained on.
        y : array, optional
            Passed through to ``UMAP.fit``.
        precomputed_distances : array, optional
            Required when ``metric == "precomputed"``; the graph is then
            generated from these distances instead of from ``X``.
        """
        return super().fit(self._select_graph_input(X, precomputed_distances), y)

    def fit_transform(self, X, y=None, precomputed_distances=None):
        """Fit the model and return the embedding; see ``UMAP.fit_transform``.

        Parameters are the same as for :meth:`fit`.
        """
        return super().fit_transform(
            self._select_graph_input(X, precomputed_distances), y
        )

    def transform(self, X):
        """Transform X into the existing embedded space and return that
        transformed output.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            New data to be transformed.

        Returns
        -------
        X_new : array, shape (n_samples, n_components)
            Embedding of the new data in low-dimensional space.
        """
        return self.encoder.predict(
            np.asanyarray(X), batch_size=self.batch_size, verbose=self.verbose
        )

    def inverse_transform(self, X):
        """Transform X in the existing embedded space back into the input
        data space and return that transformed output.

        Uses the trained decoder when ``parametric_reconstruction`` is set,
        otherwise defers to ``UMAP.inverse_transform``.

        Parameters
        ----------
        X : array, shape (n_samples, n_components)
            New points to be inverse transformed.

        Returns
        -------
        X_new : array, shape (n_samples, n_features)
            Generated data points new data in data space.
        """
        if self.parametric_reconstruction:
            return self.decoder.predict(
                np.asanyarray(X), batch_size=self.batch_size, verbose=self.verbose
            )
        return super().inverse_transform(X)

    def _define_model(self):
        """Define the model in keras and store it on ``self.parametric_model``."""
        prlw = self.parametric_reconstruction_loss_weight
        self.parametric_model = UMAPModel(
            self._a,
            self._b,
            negative_sample_rate=self.negative_sample_rate,
            encoder=self.encoder,
            decoder=self.decoder,
            parametric_reconstruction_loss_fn=self.parametric_reconstruction_loss_fcn,
            parametric_reconstruction=self.parametric_reconstruction,
            parametric_reconstruction_loss_weight=prlw,
            global_correlation_loss_weight=self.global_correlation_loss_weight,
            autoencoder_loss=self.autoencoder_loss,
        )

    def _fit_embed_data(self, X, n_epochs, init, random_state):
        """Train the networks on the graph edges and embed ``X``.

        ``n_epochs``, ``init`` and ``random_state`` are accepted for
        signature compatibility with the parent class but are not read here
        (``self.n_epochs`` is used for edge sampling).

        Returns
        -------
        tuple
            ``(embedding, {})`` where ``embedding`` is the encoder's
            prediction for ``X``.
        """
        if self.metric == "precomputed":
            # the graph came from distances; train the network on the raw data
            X = self._X

        # get dimensionality of dataset
        if self.dims is None:
            self.dims = [np.shape(X)[-1]]
        elif len(self.dims) > 1:
            # reshape data for network
            X = np.reshape(X, [len(X)] + list(self.dims))

        if self.parametric_reconstruction and (np.max(X) > 1.0 or np.min(X) < 0.0):
            warn(
                "Data should be scaled to the range 0-1 for cross-entropy reconstruction loss."
            )

        # get dataset of edges
        (
            edge_dataset,
            self.batch_size,
            n_edges,
            head,
            tail,
            self.edge_weight,
        ) = construct_edge_dataset(
            X,
            self.graph_,
            self.n_epochs,
            self.batch_size,
            self.parametric_reconstruction,
            self.global_correlation_loss_weight,
        )
        self.head = ops.array(ops.expand_dims(head.astype(np.int64), 0))
        self.tail = ops.array(ops.expand_dims(tail.astype(np.int64), 0))

        init_embedding = None

        # create encoder and decoder model
        n_data = len(X)
        self.encoder, self.decoder = prepare_networks(
            self.encoder,
            self.decoder,
            self.n_components,
            self.dims,
            n_data,
            self.parametric_reconstruction,
            init_embedding,
        )

        # create the model
        self._define_model()

        # Report every loss_report_frequency subdivision of an epoch.
        # The edge dataset repeats forever, so a truncated value of 0 must
        # never reach model.fit: always run at least one step per epoch.
        steps_per_epoch = max(
            1, int(n_edges / self.batch_size / self.loss_report_frequency)
        )

        # Validation dataset for reconstruction
        if (
            self.parametric_reconstruction
            and self.reconstruction_validation is not None
        ):
            # reshape data for network
            if len(self.dims) > 1:
                self.reconstruction_validation = np.reshape(
                    self.reconstruction_validation,
                    [len(self.reconstruction_validation)] + list(self.dims),
                )

            validation_data = (
                (
                    self.reconstruction_validation,
                    ops.zeros_like(self.reconstruction_validation),
                ),
                {"reconstruction": self.reconstruction_validation},
            )
        else:
            validation_data = None

        # __getstate__ drops attributes that cannot be pickled (e.g. a dict
        # of callbacks), so tolerate the attribute being absent after a reload.
        fit_kwargs = getattr(self, "keras_fit_kwargs", None) or {}

        # create embedding
        history = self.parametric_model.fit(
            edge_dataset,
            epochs=self.loss_report_frequency * self.n_training_epochs,
            steps_per_epoch=steps_per_epoch,
            validation_data=validation_data,
            **fit_kwargs
        )
        # save loss history dictionary
        self._history = history.history

        # get the final embedding
        embedding = self.encoder.predict(X, verbose=self.verbose)

        return embedding, {}

    def __getstate__(self):
        """Return the picklable subset of the instance dictionary.

        The Keras objects are always left out (``save`` stores the networks
        separately); any other attribute is kept only if ``should_pickle``
        accepts it.
        """
        excluded = ("optimizer", "encoder", "decoder", "parametric_model")
        # Test the cheap name exclusion first so that no attempt is made to
        # pickle the Keras objects that are discarded anyway.
        return {
            key: value
            for key, value in self.__dict__.items()
            if key not in excluded and should_pickle(key, value)
        }

    def save(self, save_location, verbose=True):
        """Write the Keras networks and a pickle of this object to a folder.

        Parameters
        ----------
        save_location : str
            Folder to write into; created if it does not exist.
        verbose : bool, optional
            Whether to print each saved path, by default True
        """
        os.makedirs(save_location, exist_ok=True)

        # save encoder
        if self.encoder is not None:
            encoder_output = os.path.join(save_location, "encoder.keras")
            self.encoder.save(encoder_output)
            if verbose:
                print("Keras encoder model saved to {}".format(encoder_output))

        # save decoder
        if self.decoder is not None:
            decoder_output = os.path.join(save_location, "decoder.keras")
            self.decoder.save(decoder_output)
            if verbose:
                print("Keras decoder model saved to {}".format(decoder_output))

        # save parametric_model
        if self.parametric_model is not None:
            parametric_model_output = os.path.join(
                save_location, "parametric_model.keras"
            )
            self.parametric_model.save(parametric_model_output)
            if verbose:
                print("Keras full model saved to {}".format(parametric_model_output))

        # save model.pkl (ignoring unpickleable warnings)
        with catch_warnings():
            filterwarnings("ignore")
            model_output = os.path.join(save_location, "model.pkl")
            with open(model_output, "wb") as output:
                pickle.dump(self, output, pickle.HIGHEST_PROTOCOL)
            if verbose:
                print("Pickle of ParametricUMAP model saved to {}".format(model_output))
def get_graph_elements(graph_, n_epochs):
    """
    Extract the edge list of a UMAP graph together with per-edge epoch counts.

    Parameters
    ----------
    graph_ : scipy.sparse.csr.csr_matrix
        umap graph of probabilities
    n_epochs : int or None
        maximum number of epochs per edge; when None, 500 is used for graphs
        with at most 10000 rows and 200 otherwise

    Returns
    -------
    graph : scipy sparse COO matrix
        the graph after duplicate merging and pruning of weak edges
    epochs_per_sample : np.array
        ``n_epochs * weight`` for every remaining edge
    head : np.array
        edge head (row index)
    tail : np.array
        edge tail (column index)
    weight : np.array
        edge weight
    n_vertices : int
        number of vertices in graph (number of columns)
    """
    coo = graph_.tocoo()
    # entries that address the same cell are merged by summation
    coo.sum_duplicates()
    n_vertices = coo.shape[1]

    # smaller datasets can afford more epochs
    if n_epochs is None:
        n_epochs = 500 if coo.shape[0] <= 10000 else 200

    # drop edges too weak to be sampled even once within n_epochs
    cutoff = coo.data.max() / float(n_epochs)
    coo.data[coo.data < cutoff] = 0.0
    coo.eliminate_zeros()

    weight = coo.data
    epochs_per_sample = n_epochs * weight
    return coo, epochs_per_sample, coo.row, coo.col, weight, n_vertices
def init_embedding_from_graph(
    _raw_data, graph, n_components, random_state, metric, _metric_kwds, init="spectral"
):
    """Initialize embedding using graph. This is for direct embeddings.

    Parameters
    ----------
    _raw_data : array
        Data handed to ``spectral_layout`` (only read for ``init="spectral"``).
    graph : sparse matrix
        UMAP graph; ``graph.shape[0]`` fixes the number of embedded points.
    n_components : int
        Dimensionality of the embedding.
    random_state : numpy RandomState or None
        Random source; when None one is obtained from
        ``check_random_state(None)``.
    metric : str or callable
        Forwarded to ``spectral_layout``.
    _metric_kwds : dict
        Forwarded to ``spectral_layout`` as ``metric_kwds``.
    init : str or array-like, optional
        Type of initialization to use. Either "random", "spectral", or a 2-D
        array of initial coordinates, by default "spectral"

    Returns
    -------
    embedding : np.array
        the initialized embedding

    Raises
    ------
    ValueError
        If ``init`` is neither a recognised string nor a 2-D array.
    """
    if random_state is None:
        random_state = check_random_state(None)

    if isinstance(init, str) and init == "random":
        embedding = random_state.uniform(
            low=-10.0, high=10.0, size=(graph.shape[0], n_components)
        ).astype(np.float32)
    elif isinstance(init, str) and init == "spectral":
        initialisation = spectral_layout(
            _raw_data,
            graph,
            n_components,
            random_state,
            metric=metric,
            metric_kwds=_metric_kwds,
        )
        # rescale so the largest absolute coordinate is 10
        expansion = 10.0 / np.abs(initialisation).max()
        # We add a little noise to avoid local minima for optimization to come
        noise = random_state.normal(
            scale=0.0001, size=[graph.shape[0], n_components]
        ).astype(np.float32)
        embedding = (initialisation * expansion).astype(np.float32) + noise
    else:
        init_data = np.array(init)
        if init_data.ndim != 2:
            # Previously this fell through and failed with an
            # UnboundLocalError on the return statement.
            raise ValueError(
                "init must be 'random', 'spectral' or a 2-D array of initial "
                "positions; got {!r}".format(init)
            )
        if np.unique(init_data, axis=0).shape[0] < init_data.shape[0]:
            # Duplicate rows: jitter every point by a small fraction of the
            # mean nearest-neighbour distance so no two points coincide.
            tree = KDTree(init_data)
            dist, _ = tree.query(init_data, k=2)
            nndist = np.mean(dist[:, 1])
            embedding = init_data + random_state.normal(
                scale=0.001 * nndist, size=init_data.shape
            ).astype(np.float32)
        else:
            embedding = init_data

    return embedding
def convert_distance_to_log_probability(distances, a=1.0, b=1.0):
    """
    Map embedding-space distances to log probabilities.

    Evaluates ``-log1p(a * distances ** (2 * b))``.

    Parameters
    ----------
    distances : array
        euclidean distance between two points in embedding
    a : float, optional
        parameter based on min_dist, by default 1.0
    b : float, optional
        parameter based on min_dist, by default 1.0

    Returns
    -------
    array
        log probability in embedding space
    """
    scaled_distance = a * distances ** (2 * b)
    return -ops.log1p(scaled_distance)
def compute_cross_entropy(
    probabilities_graph, log_probabilities_distance, EPS=1e-4, repulsion_strength=1.0
):
    """
    Cross entropy between the graph probabilities and the embedding logits.

    Parameters
    ----------
    probabilities_graph : array
        high dimensional probabilities
    log_probabilities_distance : array
        low dimensional log probabilities (used as logits)
    EPS : float, optional
        accepted for compatibility but not read by this implementation,
        by default 1e-4
    repulsion_strength : float, optional
        strength of repulsion between negative samples, by default 1.0

    Returns
    -------
    attraction_term : array
        attraction term for cross entropy loss
    repellant_term : array
        repellent term for cross entropy loss
    cross_entropy : array
        cross entropy umap loss (sum of the two terms)
    """
    logits = log_probabilities_distance
    log_sigmoid = ops.log_sigmoid(logits)

    attraction_term = -probabilities_graph * log_sigmoid

    # Numerically stable repellent term,
    # Shi et al. 2022 (https://arxiv.org/abs/2111.08851):
    # log(1 - sigmoid(logits)) = log(sigmoid(logits)) - logits
    log_one_minus_sigmoid = log_sigmoid - logits
    repellant_term = (
        -(1.0 - probabilities_graph) * log_one_minus_sigmoid * repulsion_strength
    )

    cross_entropy = attraction_term + repellant_term
    return attraction_term, repellant_term, cross_entropy
def prepare_networks(
    encoder,
    decoder,
    n_components,
    dims,
    n_data,
    parametric_reconstruction,
    init_embedding=None,
):
    """
    Generates a set of keras networks for the encoder and decoder if one has not already
    been predefined.

    Parameters
    ----------
    encoder : keras.Sequential
        The encoder Keras network; built here when None
    decoder : keras.Sequential
        the decoder Keras network; built here when None and
        ``parametric_reconstruction`` is set
    n_components : int
        the dimensionality of the latent space
    dims : tuple of shape (dim1, dim2, dim3...)
        dimensionality of data
    n_data : int
        number of elements in training dataset (not read by this function)
    parametric_reconstruction : bool
        Whether the decoder is parametric or non-parametric
    init_embedding : array (optional, default None)
        The initial embedding, for nonparametric embeddings
        (not read by this function)

    Returns
    -------
    encoder: keras.Sequential
        encoder keras network
    decoder: keras.Sequential
        decoder keras network, or None if none was given and
        ``parametric_reconstruction`` is False
    """
    if encoder is None:
        encoder = keras.Sequential(
            [
                keras.layers.Input(shape=dims),
                keras.layers.Flatten(),
                keras.layers.Dense(units=100, activation="relu"),
                keras.layers.Dense(units=100, activation="relu"),
                keras.layers.Dense(units=100, activation="relu"),
                keras.layers.Dense(units=n_components, name="z"),
            ]
        )

    if decoder is None and parametric_reconstruction:
        # np.product was removed in NumPy 2.0; np.prod is the supported
        # spelling. Cast to a plain int for the layer's ``units`` argument.
        n_outputs = int(np.prod(dims))
        decoder = keras.Sequential(
            [
                keras.layers.Input(shape=(n_components,)),
                keras.layers.Dense(units=100, activation="relu"),
                keras.layers.Dense(units=100, activation="relu"),
                keras.layers.Dense(units=100, activation="relu"),
                # no activation: the default reconstruction loss works on logits
                keras.layers.Dense(units=n_outputs, name="recon", activation=None),
                keras.layers.Reshape(dims),
            ]
        )

    return encoder, decoder
def construct_edge_dataset(
    X,
    graph_,
    n_epochs,
    batch_size,
    parametric_reconstruction,
    global_correlation_loss_weight,
):
    """
    Construct a tf.data.Dataset of edges, sampled by edge weight.

    Each edge is repeated ``int(epochs_per_sample)`` times, the repeated list
    is shuffled, and the dataset yields endlessly repeating batches of
    ``((to_batch, from_batch), outputs)``.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Data the edge indices refer to.
    graph_ : scipy.sparse.csr.csr_matrix
        Generated UMAP graph
    n_epochs : int
        # of epochs to train each edge
    batch_size : int or None
        batch size; when None, ``min(n_vertices, 1000)`` is used
    parametric_reconstruction : bool
        Whether the decoder is parametric or non-parametric
    global_correlation_loss_weight : float
        when > 0 the "global_correlation" target is added to the outputs

    Returns
    -------
    tuple
        ``(edge_dataset, batch_size, n_expanded_edges, head, tail, weight)``
    """
    # X larger than 512Mb is gathered through a (slower) python callback
    # instead of being handed to tf.gather.
    gather_indices_in_python = X.nbytes * 1e-9 > 0.5

    def gather_index(index):
        return X[index]

    def gather_X(edge_to, edge_from):
        # look up the data rows for both ends of every edge in the batch
        if not gather_indices_in_python:
            return tf.gather(X, edge_to), tf.gather(X, edge_from)
        edge_to_batch = tf.py_function(gather_index, [edge_to], [tf.float32])[0]
        edge_from_batch = tf.py_function(gather_index, [edge_from], [tf.float32])[0]
        return edge_to_batch, edge_from_batch

    def get_outputs(edge_to_batch, edge_from_batch):
        # targets are keyed by the name of the loss that consumes them
        targets = {"umap": ops.repeat(0, batch_size)}
        if global_correlation_loss_weight > 0:
            targets["global_correlation"] = edge_to_batch
        if parametric_reconstruction:
            targets["reconstruction"] = edge_to_batch
        return (edge_to_batch, edge_from_batch), targets

    # get data from graph
    _, epochs_per_sample, head, tail, weight, n_vertices = get_graph_elements(
        graph_, n_epochs
    )

    # number of elements per batch for embedding
    if batch_size is None:
        # batch size can be larger if its just over embeddings
        batch_size = int(np.min([n_vertices, 1000]))

    # repeat every edge according to how many epochs it should be seen
    repeat_counts = epochs_per_sample.astype("int")
    edges_to_exp = np.repeat(head, repeat_counts)
    edges_from_exp = np.repeat(tail, repeat_counts)

    # shuffle edges
    order = np.random.permutation(range(len(edges_to_exp)))
    edges_to_exp = edges_to_exp[order].astype(np.int64)
    edges_from_exp = edges_from_exp[order].astype(np.int64)

    # create edge iterator
    autotune = tf.data.experimental.AUTOTUNE
    edge_dataset = (
        tf.data.Dataset.from_tensor_slices((edges_to_exp, edges_from_exp))
        .repeat()
        .shuffle(10000)
        .batch(batch_size, drop_remainder=True)
        .map(gather_X, num_parallel_calls=autotune)
        .map(get_outputs, num_parallel_calls=autotune)
        .prefetch(10)
    )

    return edge_dataset, batch_size, len(edges_to_exp), head, tail, weight
def should_pickle(key, val):
    """
    Report whether a dictionary item survives a pickle round trip.

    The value is pickled, base64-encoded to text, decoded and unpickled
    again; any of the handled errors produces a warning and a False result.

    Parameters
    ----------
    key : object
        key for dictionary element (only used in the warning text)
    val : object
        element of dictionary

    Returns
    -------
    picklable: bool
        whether the dictionary item can be pickled
    """
    try:
        # serialise and move through a text-safe base64 representation
        as_text = codecs.encode(pickle.dumps(val), "base64").decode()
        # and back again, to prove the payload can be re-read
        pickle.loads(codecs.decode(as_text.encode(), "base64"))
    except (
        pickle.PicklingError,
        tf.errors.InvalidArgumentError,
        TypeError,
        tf.errors.InternalError,
        tf.errors.NotFoundError,
        OverflowError,
        TypingError,
        AttributeError,
    ) as e:
        warn("Did not pickle {}: {}".format(key, e))
        return False
    except ValueError as e:
        warn(f"Failed at pickling {key}:{val} due to {e}")
        return False
    else:
        return True
def load_ParametricUMAP(save_location, verbose=True):
    """
    Load a parametric UMAP model consisting of a umap-learn UMAP object
    and corresponding keras models.

    Parameters
    ----------
    save_location : str
        the folder that the model was saved in
    verbose : bool, optional
        Whether to print the loading steps, by default True

    Returns
    -------
    parametric_umap.ParametricUMAP
        Parametric UMAP objects
    """
    # NOTE(security): unpickling executes arbitrary code from the file; only
    # load folders that come from a trusted source.
    model_output = os.path.join(save_location, "model.pkl")
    with open(model_output, "rb") as pickle_file:
        model = pickle.load(pickle_file)
    if verbose:
        print("Pickle of ParametricUMAP model loaded from {}".format(model_output))

    # load encoder
    encoder_output = os.path.join(save_location, "encoder.keras")
    if os.path.exists(encoder_output):
        model.encoder = keras.models.load_model(encoder_output)
        if verbose:
            print("Keras encoder model loaded from {}".format(encoder_output))

    # load decoder
    decoder_output = os.path.join(save_location, "decoder.keras")
    if os.path.exists(decoder_output):
        model.decoder = keras.models.load_model(decoder_output)
        if verbose:
            print("Keras decoder model loaded from {}".format(decoder_output))

    # load parametric_model: ParametricUMAP.save writes
    # "parametric_model.keras"; the extension-less name is still accepted
    # for folders that happen to use it.
    for filename in ("parametric_model.keras", "parametric_model"):
        parametric_model_output = os.path.join(save_location, filename)
        if not os.path.exists(parametric_model_output):
            continue
        try:
            model.parametric_model = keras.models.load_model(
                parametric_model_output,
                custom_objects={"UMAPModel": UMAPModel},
            )
        except Exception as err:
            # Best effort: the encoder/decoder above are what transform and
            # inverse_transform use, and the full model is rebuilt on the
            # next fit, so a failure here should not abort loading.
            warn(
                "Could not load Keras full model from {}: {}".format(
                    parametric_model_output, err
                )
            )
        else:
            if verbose:
                print(
                    "Keras full model loaded from {}".format(parametric_model_output)
                )
        break

    return model
def covariance(x,
               y=None,
               keepdims=False):
    """Sample covariance between the columns of x and y.

    Adapted from TF Probability.

    Parameters
    ----------
    x : tensor-like
        Samples along axis 0. NOTE(review): the reshapes below effectively
        assume a 2-D ``(n_samples, n_events)`` input - confirm before
        passing higher-rank data.
    y : tensor-like, optional
        Second variable, converted to ``x``'s dtype; when None the
        covariance of ``x`` with itself is computed.
    keepdims : bool, optional
        When True the leading axis of size 1 is kept, giving shape
        ``(1, n_events, n_events)``; otherwise it is squeezed away.

    Returns
    -------
    tensor
        Covariance matrix, normalised by the number of samples.
    """
    x = ops.convert_to_tensor(x)
    # Covariance *only* uses the centered versions of x (and y).
    x = x - ops.mean(x, axis=0, keepdims=True)

    if y is None:
        y = x
    else:
        y = ops.convert_to_tensor(y, dtype=x.dtype)
        y = y - ops.mean(y, axis=0, keepdims=True)

    # events first, samples second
    x_permed = ops.transpose(x)
    y_permed = ops.transpose(y)

    n_events = ops.shape(x_permed)[0]
    n_samples = ops.shape(x_permed)[1]

    x_permed_flat = ops.reshape(x_permed, (n_events, n_samples))
    y_permed_flat = ops.reshape(y_permed, (n_events, n_samples))

    # After matmul, cov.shape = [n_events, n_events]
    cov = ops.matmul(
        x_permed_flat, ops.transpose(y_permed_flat)) / ops.cast(
            n_samples, x.dtype)

    # Flatten to a column, transpose to a single row, then unfold the row
    # into (1, n_events, n_events) so a leading sample axis of size 1 exists.
    cov = ops.reshape(cov, (n_events**2, 1))
    cov = ops.transpose(cov)
    cov = ops.reshape(
        cov,
        ops.shape(cov)[:1] + (n_events, n_events),
    )

    if not keepdims:
        cov = ops.squeeze(cov, axis=0)
    return cov
def correlation(x,
                y=None,
                keepdims=False):
    """Covariance of x and y after each is divided by its standard deviation.

    The standard deviation is taken along axis 0. ``y`` is passed through
    as None when not given; the result comes from ``covariance``.
    """
    x_scaled = x / ops.std(x, axis=0, keepdims=True)
    if y is None:
        y_scaled = None
    else:
        y_scaled = y / ops.std(y, axis=0, keepdims=True)
    return covariance(x=x_scaled, y=y_scaled, keepdims=keepdims)
class StopGradient(keras.layers.Layer):
    """Layer that returns its input wrapped in ``ops.stop_gradient``."""

    def call(self, x):
        detached = ops.stop_gradient(x)
        return detached
class UMAPModel(keras.Model):
    """Keras model that embeds both ends of a batch of edges.

    The total loss, assembled in ``compute_loss``, is the UMAP cross entropy
    plus, when enabled, a global correlation term and a reconstruction term.
    """

    def __init__(self,
                 umap_loss_a,
                 umap_loss_b,
                 negative_sample_rate,
                 encoder,
                 decoder,
                 optimizer=None,
                 parametric_reconstruction_loss_fn=None,
                 parametric_reconstruction=False,
                 parametric_reconstruction_loss_weight=1.,
                 global_correlation_loss_weight=0.,
                 autoencoder_loss=False,
                 name="umap_model"):
        """Store the networks and loss settings and compile the model.

        Parameters
        ----------
        umap_loss_a, umap_loss_b : float
            curve parameters handed to ``convert_distance_to_log_probability``
        negative_sample_rate : int
            number of negative samples drawn per positive edge
        encoder, decoder : keras model
            embedding network and (possibly None) reconstruction network
        optimizer : keras optimizer, optional
            defaults to Adam(1e-3) with gradient clipping by value at 4.0
        parametric_reconstruction_loss_fn : callable, optional
            defaults to ``keras.losses.BinaryCrossentropy(from_logits=True)``
        parametric_reconstruction : bool, optional
            whether the decoder is trained and its loss added
        parametric_reconstruction_loss_weight : float, optional
            multiplier of the reconstruction loss
        global_correlation_loss_weight : float, optional
            multiplier of the global correlation loss; active when > 0
        autoencoder_loss : bool, optional
            whether the reconstruction gradient reaches the encoder
        name : str, optional
            keras model name
        """
        super().__init__(name=name)

        # networks
        self.encoder = encoder
        self.decoder = decoder

        # loss configuration
        self.parametric_reconstruction = parametric_reconstruction
        self.global_correlation_loss_weight = global_correlation_loss_weight
        self.parametric_reconstruction_loss_weight = (
            parametric_reconstruction_loss_weight
        )
        self.negative_sample_rate = negative_sample_rate
        self.umap_loss_a = umap_loss_a
        self.umap_loss_b = umap_loss_b
        self.autoencoder_loss = autoencoder_loss

        optimizer = optimizer or keras.optimizers.Adam(1e-3, clipvalue=4.0)
        self.compile(optimizer=optimizer)

        self.flatten = keras.layers.Flatten()
        self.seed_generator = keras.random.SeedGenerator()

        if parametric_reconstruction_loss_fn is None:
            parametric_reconstruction_loss_fn = keras.losses.BinaryCrossentropy(
                from_logits=True
            )
        self.parametric_reconstruction_loss_fn = parametric_reconstruction_loss_fn

    def call(self, inputs):
        """Embed both edge ends; optionally decode the "to" embedding.

        Returns a dict with "embedding_to" and "embedding_from", plus
        "reconstruction" when parametric reconstruction is enabled.
        """
        to_x, from_x = inputs
        embedding_to = self.encoder(to_x)
        embedding_from = self.encoder(from_x)

        outputs = {
            "embedding_to": embedding_to,
            "embedding_from": embedding_from,
        }
        if not self.parametric_reconstruction:
            return outputs

        if self.autoencoder_loss:
            decoder_input = embedding_to
        else:
            # stop gradient of reconstruction loss before it reaches the encoder
            decoder_input = ops.stop_gradient(embedding_to)
        outputs["reconstruction"] = self.decoder(decoder_input)
        return outputs

    def compute_loss(
        self, x=None, y=None, y_pred=None, sample_weight=None, **kwargs
    ):
        """Sum the regularization losses and every enabled loss term."""
        # Regularization losses.
        terms = [
            ops.cast(regularizer, dtype=keras.backend.floatx())
            for regularizer in self.losses
        ]

        # umap loss
        terms.append(self._umap_loss(y_pred))

        # global correlation loss
        if self.global_correlation_loss_weight > 0:
            terms.append(self._global_correlation_loss(y, y_pred))

        # parametric reconstruction loss
        if self.parametric_reconstruction:
            terms.append(self._parametric_reconstruction_loss(y, y_pred))

        return ops.sum(terms)

    def _umap_loss(self, y_pred, repulsion_strength=1.0):
        """Mean UMAP cross entropy over positive edges and negative samples."""
        positive_to = y_pred["embedding_to"]
        positive_from = y_pred["embedding_from"]
        rate = self.negative_sample_rate

        # Negative samples: every "to" point is repeated `rate` times and
        # paired with a randomly permuted "from" point.
        negative_to = ops.repeat(positive_to, rate, axis=0)
        negative_pool = ops.repeat(positive_from, rate, axis=0)
        n_negative = ops.shape(negative_pool)[0]
        permutation = keras.random.shuffle(
            ops.arange(n_negative), seed=self.seed_generator)
        if keras.config.backend() == "tensorflow":
            negative_from = tf.gather(negative_pool, permutation)
        else:
            negative_from = negative_pool[permutation]

        # distances between samples (and negative samples)
        positive_distance = ops.norm(positive_to - positive_from, axis=1)
        negative_distance = ops.norm(negative_to - negative_from, axis=1)
        distances = ops.concatenate(
            [positive_distance, negative_distance], axis=0
        )

        # convert distances to probabilities
        log_probabilities = convert_distance_to_log_probability(
            distances, self.umap_loss_a, self.umap_loss_b
        )

        # true probabilities: 1 for graph edges, 0 for negative samples
        n_positive = ops.shape(positive_to)[0]
        targets = ops.concatenate(
            [
                ops.ones((n_positive,)),
                ops.zeros((n_positive * rate,)),
            ],
            axis=0
        )

        # compute cross entropy
        _, _, cross_entropy = compute_cross_entropy(
            targets,
            log_probabilities,
            repulsion_strength=repulsion_strength,
        )
        return ops.mean(cross_entropy)

    def _global_correlation_loss(self, y, y_pred):
        """Negative, weighted correlation of consecutive-sample distances.

        Distances between neighbouring rows of the batch are computed in
        data space and in embedding space and then correlated.
        """
        def standardize(values):
            # z score using the mean and std of the whole tensor
            return (values - ops.mean(values)) / ops.std(values)

        # flatten data
        data = self.flatten(y["global_correlation"])
        latent = self.flatten(y_pred["embedding_to"])

        # z score, then clip to 10 standard deviations for stability
        data = ops.clip(standardize(data), -10, 10)
        latent = ops.clip(standardize(latent), -10, 10)

        data_steps = ops.norm(data[1:] - data[:-1], axis=1)
        latent_steps = ops.norm(latent[1:] - latent[:-1], axis=1)

        # jitter the latent distances to prevent mode collapse
        jitter = keras.random.uniform(latent_steps.shape, seed=self.seed_generator)
        latent_steps = latent_steps + jitter * 1e-10

        # compute correlation
        corr_d = ops.squeeze(
            correlation(
                x=ops.expand_dims(data_steps, -1),
                y=ops.expand_dims(latent_steps, -1),
            )
        )
        return -corr_d * self.global_correlation_loss_weight

    def _parametric_reconstruction_loss(self, y, y_pred):
        """Weighted reconstruction loss between target and decoder output."""
        target = y["reconstruction"]
        reconstructed = y_pred["reconstruction"]
        loss = self.parametric_reconstruction_loss_fn(target, reconstructed)
        return loss * self.parametric_reconstruction_loss_weight