from typing import Dict

import torch
from torch.distributions import Categorical, constraints
from torch.distributions.distribution import Distribution

__all__ = ["MixtureSameFamily"]


class MixtureSameFamily(Distribution):
    r"""
    The `MixtureSameFamily` distribution implements a (batch of) mixture
    distribution where all components are from different parameterizations of
    the same distribution type. It is parameterized by a `Categorical`
    "selecting distribution" (over `k` components) and a component
    distribution, i.e., a `Distribution` with a rightmost batch shape
    (equal to `[k]`) which indexes each (batch of) component.

    Examples::

        >>> # xdoctest: +SKIP("undefined vars")
        >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally
        >>> # weighted normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Normal(torch.randn(5,), torch.rand(5,))
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally
        >>> # weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.ones(5,))
        >>> comp = D.Independent(D.Normal(
        ...          torch.randn(5,2), torch.rand(5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

        >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each
        >>> # consisting of 5 random weighted bivariate normal distributions
        >>> mix = D.Categorical(torch.rand(3,5))
        >>> comp = D.Independent(D.Normal(
        ...         torch.randn(3,5,2), torch.rand(3,5,2)), 1)
        >>> gmm = MixtureSameFamily(mix, comp)

    Args:
        mixture_distribution: `torch.distributions.Categorical` instance
            (any other type raises ``ValueError``). Manages the probability
            of selecting components. The number of categories must match the
            rightmost batch dimension of the `component_distribution`. Its
            `batch_shape` should broadcast to
            `component_distribution.batch_shape[:-1]` (e.g. scalar, or equal
            to it).
        component_distribution: `torch.distributions.Distribution`-like
            instance. Right-most batch dimension indexes components.
    """

    arg_constraints: Dict[str, constraints.Constraint] = {}
    has_rsample = False

    def __init__(
        self, mixture_distribution, component_distribution, validate_args=None
    ):
        self._mixture_distribution = mixture_distribution
        self._component_distribution = component_distribution

        if not isinstance(self._mixture_distribution, Categorical):
            raise ValueError(
                " The Mixture distribution needs to be an "
                " instance of torch.distributions.Categorical"
            )

        if not isinstance(self._component_distribution, Distribution):
            raise ValueError(
                "The Component distribution need to be an "
                "instance of torch.distributions.Distribution"
            )

        # Check that batch size matches: the mixture batch shape and the
        # component batch shape (minus the trailing k dim) are compared
        # right-aligned; size-1 dims on either side are accepted.
        mdbs = self._mixture_distribution.batch_shape
        cdbs = self._component_distribution.batch_shape[:-1]
        for size1, size2 in zip(reversed(mdbs), reversed(cdbs)):
            if size1 != 1 and size2 != 1 and size1 != size2:
                raise ValueError(
                    f"`mixture_distribution.batch_shape` ({mdbs}) is not "
                    "compatible with `component_distribution."
                    f"batch_shape`({cdbs})"
                )

        # Check that the number of mixture components matches
        km = self._mixture_distribution.logits.shape[-1]
        kc = self._component_distribution.batch_shape[-1]
        if km is not None and kc is not None and km != kc:
            raise ValueError(
                f"`mixture_distribution component` ({km}) does not"
                " equal `component_distribution.batch_shape[-1]`"
                f" ({kc})"
            )
        self._num_component = km

        event_shape = self._component_distribution.event_shape
        self._event_ndims = len(event_shape)
        super().__init__(
            batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new instance whose mixture and components are expanded
        to ``batch_shape`` (components get the extra trailing ``k`` dim)."""
        batch_shape = torch.Size(batch_shape)
        batch_shape_comp = batch_shape + (self._num_component,)
        new = self._get_checked_instance(MixtureSameFamily, _instance)
        new._component_distribution = self._component_distribution.expand(
            batch_shape_comp
        )
        new._mixture_distribution = self._mixture_distribution.expand(batch_shape)
        new._num_component = self._num_component
        new._event_ndims = self._event_ndims
        event_shape = new._component_distribution.event_shape
        super(MixtureSameFamily, new).__init__(
            batch_shape=batch_shape, event_shape=event_shape, validate_args=False
        )
        new._validate_args = self._validate_args
        return new

    @constraints.dependent_property
    def support(self):
        # FIXME this may have the wrong shape when support contains batched
        # parameters
        return self._component_distribution.support

    @property
    def mixture_distribution(self):
        return self._mixture_distribution

    @property
    def component_distribution(self):
        return self._component_distribution

    @property
    def mean(self):
        # Weighted sum of component means over the k dimension.
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        return torch.sum(
            probs * self.component_distribution.mean, dim=-1 - self._event_ndims
        )  # [B, E]

    @property
    def variance(self):
        # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X])
        probs = self._pad_mixture_dimensions(self.mixture_distribution.probs)
        mean_cond_var = torch.sum(
            probs * self.component_distribution.variance, dim=-1 - self._event_ndims
        )
        var_cond_mean = torch.sum(
            probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0),
            dim=-1 - self._event_ndims,
        )
        return mean_cond_var + var_cond_mean

    def cdf(self, x):
        """Mixture CDF: probability-weighted sum of the component CDFs."""
        x = self._pad(x)
        cdf_x = self.component_distribution.cdf(x)
        mix_prob = self.mixture_distribution.probs

        return torch.sum(cdf_x * mix_prob, dim=-1)

    def log_prob(self, x):
        """Log density computed as logsumexp over components of
        component log-prob plus normalized mixture log-weight."""
        if self._validate_args:
            self._validate_sample(x)
        x = self._pad(x)
        log_prob_x = self.component_distribution.log_prob(x)  # [S, B, k]
        log_mix_prob = torch.log_softmax(
            self.mixture_distribution.logits, dim=-1
        )  # [B, k]
        return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1)  # [S, B]

    def sample(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        A component index is drawn from the mixture, every component is
        sampled, and the selected component is gathered along the ``k`` dim.
        """
        with torch.no_grad():
            sample_len = len(sample_shape)
            batch_len = len(self.batch_shape)
            gather_dim = sample_len + batch_len
            es = self.event_shape

            # The mixture may have a smaller (e.g. scalar) or size-1 batch
            # shape than the components. Expand it to the full batch shape so
            # each batch member draws its own component index and the index
            # tensor has the same rank/sizes that torch.gather requires.
            # Skipped when shapes already match, leaving sampling unchanged.
            mixture = self.mixture_distribution
            if mixture.batch_shape != self.batch_shape:
                mixture = mixture.expand(self.batch_shape)

            # mixture samples [n, B]
            mix_sample = mixture.sample(sample_shape)
            mix_shape = mix_sample.shape

            # component samples [n, B, k, E]
            comp_samples = self.component_distribution.sample(sample_shape)

            # Gather along the k dimension: reshape the index to
            # [n, B, 1, 1...] and tile it across the event dims.
            mix_sample_r = mix_sample.reshape(
                mix_shape + torch.Size([1] * (len(es) + 1))
            )
            mix_sample_r = mix_sample_r.repeat(
                torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es
            )

            samples = torch.gather(comp_samples, gather_dim, mix_sample_r)
            return samples.squeeze(gather_dim)

    def _pad(self, x):
        # Insert a singleton k dimension just before the event dims.
        return x.unsqueeze(-1 - self._event_ndims)

    def _pad_mixture_dimensions(self, x):
        # Append one singleton per event dim after the trailing k dim so that
        # x ([B_mix, k]) broadcasts against component statistics
        # ([B, k, E]). Leading batch dims are supplied by right-aligned
        # broadcasting; inserting singletons between the mixture batch dims
        # and k would misalign them whenever the mixture batch is shorter
        # than the component batch.
        return x.reshape(x.shape + torch.Size(self._event_ndims * [1]))

    def __repr__(self):
        args_string = (
            f"\n  {self.mixture_distribution},\n  {self.component_distribution}"
        )
        return "MixtureSameFamily" + "(" + args_string + ")"