""" Joint Random Variables Module See Also ======== sympy.stats.rv sympy.stats.frv sympy.stats.crv sympy.stats.drv """ from math import prod from sympy.core.basic import Basic from sympy.core.function import Lambda from sympy.core.singleton import S from sympy.core.symbol import (Dummy, Symbol) from sympy.core.sympify import sympify from sympy.sets.sets import ProductSet from sympy.tensor.indexed import Indexed from sympy.concrete.products import Product from sympy.concrete.summations import Sum, summation from sympy.core.containers import Tuple from sympy.integrals.integrals import Integral, integrate from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution, ProductDomain, RandomSymbol, random_symbols, SingleDomain, _symbol_converter) from sympy.utilities.iterables import iterable from sympy.utilities.misc import filldedent from sympy.external import import_module # __all__ = ['marginal_distribution'] class JointPSpace(ProductPSpace): """ Represents a joint probability space. Represented using symbols for each component and a distribution. """ def __new__(cls, sym, dist): if isinstance(dist, SingleContinuousDistribution): return SingleContinuousPSpace(sym, dist) if isinstance(dist, SingleDiscreteDistribution): return SingleDiscretePSpace(sym, dist) sym = _symbol_converter(sym) return Basic.__new__(cls, sym, dist) @property def set(self): return self.domain.set @property def symbol(self): return self.args[0] @property def distribution(self): return self.args[1] @property def value(self): return JointRandomSymbol(self.symbol, self) @property def component_count(self): _set = self.distribution.set if isinstance(_set, ProductSet): return S(len(_set.args)) elif isinstance(_set, Product): return _set.limits[0][-1] return S.One @property def pdf(self): sym = [Indexed(self.symbol, i) for i in range(self.component_count)] return self.distribution(*sym) @property def domain(self): rvs = random_symbols(self.distribution) if not rvs: return SingleDomain(self.symbol, self.distribution.set) return ProductDomain(*[rv.pspace.domain for rv in rvs]) def component_domain(self, index): return self.set.args[index] def marginal_distribution(self, *indices): count = self.component_count if count.atoms(Symbol): raise ValueError("Marginal distributions cannot be computed " "for symbolic dimensions. It is a work under progress.") orig = [Indexed(self.symbol, i) for i in range(count)] all_syms = [Symbol(str(i)) for i in orig] replace_dict = dict(zip(all_syms, orig)) sym = tuple(Symbol(str(Indexed(self.symbol, i))) for i in indices) limits = [[i,] for i in all_syms if i not in sym] index = 0 for i in range(count): if i not in indices: limits[index].append(self.distribution.set.args[i]) limits[index] = tuple(limits[index]) index += 1 if self.distribution.is_Continuous: f = Lambda(sym, integrate(self.distribution(*all_syms), *limits)) elif self.distribution.is_Discrete: f = Lambda(sym, summation(self.distribution(*all_syms), *limits)) return f.xreplace(replace_dict) def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs): syms = tuple(self.value[i] for i in range(self.component_count)) rvs = rvs or syms if not any(i in rvs for i in syms): return expr expr = expr*self.pdf for rv in rvs: if isinstance(rv, Indexed): expr = expr.xreplace({rv: Indexed(str(rv.base), rv.args[1])}) elif isinstance(rv, RandomSymbol): expr = expr.xreplace({rv: rv.symbol}) if self.value in random_symbols(expr): raise NotImplementedError(filldedent(''' Expectations of expression with unindexed joint random symbols cannot be calculated yet.''')) limits = tuple((Indexed(str(rv.base),rv.args[1]), self.distribution.set.args[rv.args[1]]) for rv in syms) return Integral(expr, *limits) def where(self, condition): raise NotImplementedError() def compute_density(self, expr): raise NotImplementedError() def sample(self, size=(), library='scipy', seed=None): """ Internal sample method Returns dictionary mapping RandomSymbol to realization value. """ return {RandomSymbol(self.symbol, self): self.distribution.sample(size, library=library, seed=seed)} def probability(self, condition): raise NotImplementedError() class SampleJointScipy: """Returns the sample from scipy of the given distribution""" def __new__(cls, dist, size, seed=None): return cls._sample_scipy(dist, size, seed) @classmethod def _sample_scipy(cls, dist, size, seed): """Sample from SciPy.""" import numpy if seed is None or isinstance(seed, int): rand_state = numpy.random.default_rng(seed=seed) else: rand_state = seed from scipy import stats as scipy_stats scipy_rv_map = { 'MultivariateNormalDistribution': lambda dist, size: scipy_stats.multivariate_normal.rvs( mean=matrix2numpy(dist.mu).flatten(), cov=matrix2numpy(dist.sigma), size=size, random_state=rand_state), 'MultivariateBetaDistribution': lambda dist, size: scipy_stats.dirichlet.rvs( alpha=list2numpy(dist.alpha, float).flatten(), size=size, random_state=rand_state), 'MultinomialDistribution': lambda dist, size: scipy_stats.multinomial.rvs( n=int(dist.n), p=list2numpy(dist.p, float).flatten(), size=size, random_state=rand_state) } sample_shape = { 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape } dist_list = scipy_rv_map.keys() if dist.__class__.__name__ not in dist_list: return None samples = scipy_rv_map[dist.__class__.__name__](dist, size) return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) class SampleJointNumpy: """Returns the sample from numpy of the given distribution""" def __new__(cls, dist, size, seed=None): return cls._sample_numpy(dist, size, seed) @classmethod def _sample_numpy(cls, dist, size, seed): """Sample from NumPy.""" import numpy if seed is None or isinstance(seed, int): rand_state = numpy.random.default_rng(seed=seed) else: rand_state = seed numpy_rv_map = { 'MultivariateNormalDistribution': lambda dist, size: rand_state.multivariate_normal( mean=matrix2numpy(dist.mu, float).flatten(), cov=matrix2numpy(dist.sigma, float), size=size), 'MultivariateBetaDistribution': lambda dist, size: rand_state.dirichlet( alpha=list2numpy(dist.alpha, float).flatten(), size=size), 'MultinomialDistribution': lambda dist, size: rand_state.multinomial( n=int(dist.n), pvals=list2numpy(dist.p, float).flatten(), size=size) } sample_shape = { 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape } dist_list = numpy_rv_map.keys() if dist.__class__.__name__ not in dist_list: return None samples = numpy_rv_map[dist.__class__.__name__](dist, prod(size)) return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) class SampleJointPymc: """Returns the sample from pymc of the given distribution""" def __new__(cls, dist, size, seed=None): return cls._sample_pymc(dist, size, seed) @classmethod def _sample_pymc(cls, dist, size, seed): """Sample from PyMC.""" try: import pymc except ImportError: import pymc3 as pymc pymc_rv_map = { 'MultivariateNormalDistribution': lambda dist: pymc.MvNormal('X', mu=matrix2numpy(dist.mu, float).flatten(), cov=matrix2numpy(dist.sigma, float), shape=(1, dist.mu.shape[0])), 'MultivariateBetaDistribution': lambda dist: pymc.Dirichlet('X', a=list2numpy(dist.alpha, float).flatten()), 'MultinomialDistribution': lambda dist: pymc.Multinomial('X', n=int(dist.n), p=list2numpy(dist.p, float).flatten(), shape=(1, len(dist.p))) } sample_shape = { 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape, 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape, 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape } dist_list = pymc_rv_map.keys() if dist.__class__.__name__ not in dist_list: return None import logging logging.getLogger("pymc3").setLevel(logging.ERROR) with pymc.Model(): pymc_rv_map[dist.__class__.__name__](dist) samples = pymc.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)[:]['X'] return samples.reshape(size + sample_shape[dist.__class__.__name__](dist)) _get_sample_class_jrv = { 'scipy': SampleJointScipy, 'pymc3': SampleJointPymc, 'pymc': SampleJointPymc, 'numpy': SampleJointNumpy } class JointDistribution(Distribution, NamedArgsMixin): """ Represented by the random variables part of the joint distribution. Contains methods for PDF, CDF, sampling, marginal densities, etc. """ _argnames = ('pdf', ) def __new__(cls, *args): args = list(map(sympify, args)) for i in range(len(args)): if isinstance(args[i], list): args[i] = ImmutableMatrix(args[i]) return Basic.__new__(cls, *args) @property def domain(self): return ProductDomain(self.symbols) @property def pdf(self): return self.density.args[1] def cdf(self, other): if not isinstance(other, dict): raise ValueError("%s should be of type dict, got %s"%(other, type(other))) rvs = other.keys() _set = self.domain.set.sets expr = self.pdf(tuple(i.args[0] for i in self.symbols)) for i in range(len(other)): if rvs[i].is_Continuous: density = Integral(expr, (rvs[i], _set[i].inf, other[rvs[i]])) elif rvs[i].is_Discrete: density = Sum(expr, (rvs[i], _set[i].inf, other[rvs[i]])) return density def sample(self, size=(), library='scipy', seed=None): """ A random realization from the distribution """ libraries = ('scipy', 'numpy', 'pymc3', 'pymc') if library not in libraries: raise NotImplementedError("Sampling from %s is not supported yet." % str(library)) if not import_module(library): raise ValueError("Failed to import %s" % library) samps = _get_sample_class_jrv[library](self, size, seed=seed) if samps is not None: return samps raise NotImplementedError( "Sampling for %s is not currently implemented from %s" % (self.__class__.__name__, library) ) def __call__(self, *args): return self.pdf(*args) class JointRandomSymbol(RandomSymbol): """ Representation of random symbols with joint probability distributions to allow indexing." """ def __getitem__(self, key): if isinstance(self.pspace, JointPSpace): if (self.pspace.component_count <= key) == True: raise ValueError("Index keys for %s can only up to %s." % (self.name, self.pspace.component_count - 1)) return Indexed(self, key) class MarginalDistribution(Distribution): """ Represents the marginal distribution of a joint probability space. Initialised using a probability distribution and random variables(or their indexed components) which should be a part of the resultant distribution. """ def __new__(cls, dist, *rvs): if len(rvs) == 1 and iterable(rvs[0]): rvs = tuple(rvs[0]) if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs): raise ValueError(filldedent('''Marginal distribution can be intitialised only in terms of random variables or indexed random variables''')) rvs = Tuple.fromiter(rv for rv in rvs) if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0: return dist return Basic.__new__(cls, dist, rvs) def check(self): pass @property def set(self): rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)] return ProductSet(*[rv.pspace.set for rv in rvs]) @property def symbols(self): rvs = self.args[1] return {rv.pspace.symbol for rv in rvs} def pdf(self, *x): expr, rvs = self.args[0], self.args[1] marginalise_out = [i for i in random_symbols(expr) if i not in rvs] if isinstance(expr, JointDistribution): count = len(expr.domain.args) x = Dummy('x', real=True) syms = tuple(Indexed(x, i) for i in count) expr = expr.pdf(syms) else: syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs) return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x) def compute_pdf(self, expr, rvs): for rv in rvs: lpdf = 1 if isinstance(rv, RandomSymbol): lpdf = rv.pspace.pdf expr = self.marginalise_out(expr*lpdf, rv) return expr def marginalise_out(self, expr, rv): from sympy.concrete.summations import Sum if isinstance(rv, RandomSymbol): dom = rv.pspace.set elif isinstance(rv, Indexed): dom = rv.base.component_domain( rv.pspace.component_domain(rv.args[1])) expr = expr.xreplace({rv: rv.pspace.symbol}) if rv.pspace.is_Continuous: #TODO: Modify to support integration #for all kinds of sets. expr = Integral(expr, (rv.pspace.symbol, dom)) elif rv.pspace.is_Discrete: #incorporate this into `Sum`/`summation` if dom in (S.Integers, S.Naturals, S.Naturals0): dom = (dom.inf, dom.sup) expr = Sum(expr, (rv.pspace.symbol, dom)) return expr def __call__(self, *args): return self.pdf(*args)