215 lines
8.3 KiB
Python
215 lines
8.3 KiB
Python
|
from typing import Dict
|
||
|
|
||
|
import torch
|
||
|
from torch.distributions import Categorical, constraints
|
||
|
from torch.distributions.distribution import Distribution
|
||
|
|
||
|
# Public names exported by ``from <this module> import *``.
__all__ = ["MixtureSameFamily"]
|
||
|
|
||
|
|
||
|
class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical`-like
            instance. Manages the probability of selecting components.
            The number of categories must match the rightmost batch
            dimension of the `component_distribution`. Its `batch_shape`
            must either be scalar, match
            `component_distribution.batch_shape[:-1]`, or broadcast to it
            (shapes are aligned from the right).
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes components.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                " The Mixture distribution needs to be an "
                " instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution need to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that batch size matches: the mixture batch shape is compared
        # against the component batch shape without its trailing `k`
        # dimension, aligned from the right; size-1 dims are accepted.
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape`({cdbs})"
                )

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution component` ({km}) does not"
                " equal `component_distribution.batch_shape[-1]`"
                f" ({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        # The mixture's batch shape is the component batch shape minus `k`.
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose batch shape is ``batch_shape``.

        Both the component distribution (with the trailing ``k`` dimension
        re-appended) and the mixing distribution are expanded.
        """
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        """The `Categorical` distribution over the ``k`` components."""
        return self._mixture_distribution

    @property
    def component_distribution(self):
        """The batched component distribution (rightmost batch dim is ``k``)."""
        return self._component_distribution

    @property
    def mean(self):
        """Mixture mean: the probability-weighted sum of component means."""
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        """Mixture variance, computed with the law of total variance."""
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        """Mixture CDF: the probability-weighted sum of component CDFs at ``x``."""
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs

        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        """Log density of ``x``: logsumexp over components of
        ``component.log_prob(x) + log_softmax(mixture logits)``.
        """
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        For every sample and batch element, a component index is drawn from
        the mixing distribution, and the matching component sample is then
        selected. Sampling runs under ``torch.no_grad()`` (no ``rsample``).
        """
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # The mixing distribution may have a smaller, broadcastable batch
            # shape (e.g. scalar) than the components. torch.gather needs
            # index and input of equal rank, and each batch element needs its
            # own independent component choice, so expand before sampling.
            mix = self.mixture_distribution
            if mix.batch_shape != self.batch_shape:
                mix = mix.expand(self.batch_shape)

            # mixture samples [n, B]
            mix_sample = mix.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        """Insert a singleton component dim just before the event dims of ``x``."""
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        """Reshape mixture weights ``x`` (shape ``mixture_batch + [k]``) so they
        broadcast against component tensors of shape ``batch + [k] + event``.

        Missing batch dims are inserted on the LEFT, because ``__init__``
        aligns the mixture batch shape with the component batch shape from
        the right. One trailing singleton dim is appended per event dim.
        """
        dist_batch_ndims = len(self.batch_shape)
        cat_batch_ndims = len(self.mixture_distribution.batch_shape)
        pad_ndims = max(dist_batch_ndims - cat_batch_ndims, 0)
        return x.reshape(
            torch.Size(pad_ndims * [1])
            + x.shape
            + torch.Size(self._event_ndims * [1])
        )

    def __repr__(self):
        args_string = (
            f"\n  {self.mixture_distribution},\n  {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"
|