437 lines
18 KiB
Python
437 lines
18 KiB
Python
from sympy.concrete.products import Product
|
|
from sympy.concrete.summations import Sum
|
|
from sympy.core.numbers import (Rational, oo, pi)
|
|
from sympy.core.relational import Eq
|
|
from sympy.core.singleton import S
|
|
from sympy.core.symbol import symbols
|
|
from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
|
|
from sympy.functions.elementary.complexes import polar_lift
|
|
from sympy.functions.elementary.exponential import exp
|
|
from sympy.functions.elementary.miscellaneous import sqrt
|
|
from sympy.functions.elementary.piecewise import Piecewise
|
|
from sympy.functions.special.bessel import besselk
|
|
from sympy.functions.special.gamma_functions import gamma
|
|
from sympy.matrices.dense import eye
|
|
from sympy.matrices.expressions.determinant import Determinant
|
|
from sympy.sets.fancysets import Range
|
|
from sympy.sets.sets import (Interval, ProductSet)
|
|
from sympy.simplify.simplify import simplify
|
|
from sympy.tensor.indexed import (Indexed, IndexedBase)
|
|
from sympy.core.numbers import comp
|
|
from sympy.integrals.integrals import integrate
|
|
from sympy.matrices import Matrix, MatrixSymbol
|
|
from sympy.matrices.expressions.matexpr import MatrixElement
|
|
from sympy.stats import density, median, marginal_distribution, Normal, Laplace, E, sample
|
|
from sympy.stats.joint_rv_types import (JointRV, MultivariateNormalDistribution,
|
|
JointDistributionHandmade, MultivariateT, NormalGamma,
|
|
GeneralizedMultivariateLogGammaOmega as GMVLGO, MultivariateBeta,
|
|
GeneralizedMultivariateLogGamma as GMVLG, MultivariateEwens,
|
|
Multinomial, NegativeMultinomial, MultivariateNormal,
|
|
MultivariateLaplace)
|
|
from sympy.testing.pytest import raises, XFAIL, skip, slow
|
|
from sympy.external import import_module
|
|
|
|
from sympy.abc import x, y
|
|
|
|
|
|
|
|
def test_Normal():
    """Exercise ``Normal`` / ``MultivariateNormal`` with vector parameters.

    Covers, in order: concrete parameters (density values, support set,
    indexing, marginals, normalisation), argument validation, symbolic
    ``MatrixSymbol`` parameters of symbolic size, and densities evaluated
    after substituting concrete matrices.
    """
    m = Normal('A', [1, 2], [[1, 0], [0, 1]])
    A = MultivariateNormal('A', [1, 2], [[1, 0], [0, 1]])
    # ``Normal`` called with a mean vector and a covariance matrix must give
    # the same object as the explicit ``MultivariateNormal`` constructor.
    assert m == A
    assert density(m)(1, 2) == 1/(2*pi)
    assert m.pspace.distribution.set == ProductSet(S.Reals, S.Reals)
    # ``m`` was built from a two-entry mean; index 2 is expected to fail.
    raises(ValueError, lambda: m[2])
    n = Normal('B', [1, 2, 3], [[1, 0, 0], [0, 1, 0], [0, 0, 1]])
    p = Normal('C', Matrix([1, 2]), Matrix([[1, 0], [0, 1]]))
    # List and ``Matrix`` arguments must lead to the same density.
    assert density(m)(x, y) == density(p)(x, y)
    assert marginal_distribution(n, 0, 1)(1, 2) == 1/(2*pi)
    raises(ValueError, lambda: marginal_distribution(m))
    assert integrate(density(m)(x, y), (x, -oo, oo), (y, -oo, oo)).evalf() == 1.0
    N = Normal('N', [1, 2], [[x, 0], [0, y]])
    assert density(N)(0, 0) == exp(-((4*x + y)/(2*x*y)))/(2*pi*sqrt(x*y))

    raises(ValueError, lambda: Normal('M', [1, 2], [[1, 1], [1, -1]]))
    # symbolic
    n = symbols('n', integer=True, positive=True)
    mu = MatrixSymbol('mu', n, 1)
    sigma = MatrixSymbol('sigma', n, n)
    X = Normal('X', mu, sigma)
    assert density(X) == MultivariateNormalDistribution(mu, sigma)
    raises(NotImplementedError, lambda: median(m))
    # Below tests should work after issue #17267 is resolved
    # assert E(X) == mu
    # assert variance(X) == sigma

    # test symbolic multivariate normal densities
    n = 3

    Sg = MatrixSymbol('Sg', n, n)
    mu = MatrixSymbol('mu', n, 1)
    obs = MatrixSymbol('obs', n, 1)

    X = MultivariateNormal('X', mu, Sg)
    density_X = density(X)

    eval_a = density_X(obs).subs({Sg: eye(3),
        mu: Matrix([0, 0, 0]), obs: Matrix([0, 0, 0])}).doit()
    eval_b = density_X(0, 0, 0).subs({Sg: eye(3), mu: Matrix([0, 0, 0])}).doit()

    # Build the exponent from an integer pair rather than from the float
    # ``3/2``, matching how ``Rational`` is used elsewhere in this file.
    density_at_origin = sqrt(2)/(4*pi**Rational(3, 2))
    assert eval_a == density_at_origin
    assert eval_b == density_at_origin

    n = symbols('n', integer=True, positive=True)

    Sg = MatrixSymbol('Sg', n, n)
    mu = MatrixSymbol('mu', n, 1)
    obs = MatrixSymbol('obs', n, 1)

    X = MultivariateNormal('X', mu, Sg)
    density_X_at_obs = density(X)(obs)

    expected_density = MatrixElement(
        exp((S(1)/2) * (mu.T - obs.T) * Sg**(-1) * (-mu + obs)) /
        sqrt((2*pi)**n * Determinant(Sg)), 0, 0)

    assert density_X_at_obs == expected_density
|
|
|
|
|
|
def test_MultivariateTDist():
    """Check density values, support, normalisation and validation of ``MultivariateT``."""
    standard_t = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2)

    assert density(standard_t)(1, 1) == 1/(8*pi)
    plane = ProductSet(S.Reals, S.Reals)
    assert standard_t.pspace.distribution.set == plane

    # The density must integrate to one over the whole plane.
    total_mass = integrate(
        density(standard_t)(x, y), (x, -oo, oo), (y, -oo, oo))
    assert total_mass.evalf() == 1.0

    # This shape matrix has a negative eigenvalue; construction must fail.
    raises(ValueError, lambda: MultivariateT('T', [1, 2], [[1, 1], [1, -1]], 1))

    # Symbolic diagonal shape matrix, evaluated at the location vector.
    symbolic_t = MultivariateT('t2', [1, 2], [[x, 0], [0, y]], 1)
    assert density(symbolic_t)(1, 2) == 1/(2*pi*sqrt(x*y))
|
|
|
|
|
|
def test_multivariate_laplace():
    """Check ``Laplace`` with vector arguments against ``MultivariateLaplace``."""
    # Invalid covariance argument must be rejected at construction time.
    raises(ValueError, lambda: Laplace('T', [1, 2], [[1, 2], [2, 1]]))

    via_laplace = Laplace('L', [1, 0], [[1, 0], [0, 1]])
    via_explicit = MultivariateLaplace('L2', [1, 0], [[1, 0], [0, 1]])

    expected_numeric = exp(2)*besselk(0, sqrt(39))/pi
    assert density(via_laplace)(2, 3) == expected_numeric

    # Symbolic diagonal covariance.
    symbolic_laplace = Laplace('L1', [1, 2], [[x, 0], [0, y]])
    expected_symbolic = (
        exp(2/y)*besselk(0, sqrt((2 + 4/y + 1/x)/y))/(pi*sqrt(x*y)))
    assert density(symbolic_laplace)(0, 1) == expected_symbolic

    laplace_dist = via_laplace.pspace.distribution
    assert laplace_dist.set == ProductSet(S.Reals, S.Reals)
    # Both constructors must produce the same distribution object.
    assert laplace_dist == via_explicit.pspace.distribution
|
|
|
|
|
|
def test_NormalGamma():
    """Check density, support, validation and marginals of ``NormalGamma``."""
    normal_gamma = NormalGamma('G', 1, 2, 3, 4)

    assert density(normal_gamma)(1, 1) == 32*exp(-4)/sqrt(pi)

    expected_support = ProductSet(S.Reals, Interval(0, oo))
    assert normal_gamma.pspace.distribution.set == expected_support

    # A negative last parameter is expected to be rejected.
    raises(ValueError, lambda: NormalGamma('G', 1, 2, 3, -1))

    # Marginals selected by index, by a symbol, and by a list of indices.
    marginal_by_index = marginal_distribution(normal_gamma, 0)(1)
    assert marginal_by_index == (
        3*sqrt(10)*gamma(Rational(7, 4))/(10*sqrt(pi)*gamma(Rational(5, 4))))

    marginal_by_symbol = marginal_distribution(normal_gamma, y)(1)
    assert marginal_by_symbol == exp(Rational(-1, 4))/128

    marginal_by_list = marginal_distribution(normal_gamma, [0, 1])(x)
    assert marginal_by_list == x**2*exp(-x/4)/128
|
|
|
|
|
|
def test_GeneralizedMultivariateLogGammaDistribution():
    """Check densities, one marginal and argument validation of the GMVLG family.

    Symbolic results are compared through their ``str`` form against
    reference strings, first for the ``delta`` parametrisation (``GMVLG``)
    and then for the ``omega`` parametrisation (``GMVLGO``).
    """
    half = S.Half
    omega = Matrix([[1, half, half, half],
                    [half, 1, half, half],
                    [half, half, 1, half],
                    [half, half, half, 1]])
    v = 4
    lam = [1, 2, 3, 4]
    mu = [1, 2, 3, 4]
    ys = symbols('y_1:5', real=True)
    delta = symbols('d', positive=True)

    from_omega = GMVLGO('G', omega, v, lam, mu)
    from_delta = GMVLG('Gd', delta, v, lam, mu)

    expected_delta_density = (
        "d**4*Sum(4*24**(-n - 4)*(1 - d)**n*exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 "
        "+ 4*y_4) - exp(y_1) - exp(2*y_2)/2 - exp(3*y_3)/3 - exp(4*y_4)/4)/"
        "(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))")
    assert str(density(from_delta)(*ys)) == expected_delta_density

    expected_omega_density = (
        "5*2**(2/3)*5**(1/3)*Sum(4*24**(-n - 4)*(-2**(2/3)*5**(1/3)/4 + 1)**n*"
        "exp((n + 4)*(y_1 + 2*y_2 + 3*y_3 + 4*y_4) - exp(y_1) - exp(2*y_2)/2 - "
        "exp(3*y_3)/3 - exp(4*y_4)/4)/(gamma(n + 1)*gamma(n + 4)**3), (n, 0, oo))/64")
    assert str(density(from_omega)(*ys)) == expected_omega_density

    expected_marginal = (
        "5*2**(2/3)*5**(1/3)*exp(4*y_1)*exp(-exp(y_1))*Integral(exp(-exp(4*G[3])"
        "/4)*exp(16*G[3])*Integral(exp(-exp(3*G[2])/3)*exp(12*G[2])*Integral(exp("
        "-exp(2*G[1])/2)*exp(8*G[1])*Sum((-1/4)**n*(-4 + 2**(2/3)*5**(1/3"
        "))**n*exp(n*y_1)*exp(2*n*G[1])*exp(3*n*G[2])*exp(4*n*G[3])/(24**n*gamma(n + 1)"
        "*gamma(n + 4)**3), (n, 0, oo)), (G[1], -oo, oo)), (G[2], -oo, oo)), (G[3]"
        ", -oo, oo))/5308416")
    assert str(marginal_distribution(from_omega, from_omega[0])(ys[0])) == expected_marginal

    # --- arguments that must be rejected -----------------------------------
    omega_row_only = Matrix([[1, half, half]])
    omega_bad_entry = Matrix([[1, half, half, half],
                              [half, 1, 2, half],
                              [half, half, 1, half],
                              [half, half, half, 1]])
    omega_bad_diagonal = Matrix([[6, half, half, half],
                                 [half, 1, 2, half],
                                 [half, half, 1, half],
                                 [half, half, half, 1]])
    v_f = symbols("v_f", positive=False, real=True)
    lam_with_v_f = [1, 2, v_f, 4]
    mu_with_v_f = [v_f, 2, 3, 4]
    omega_5x5 = Matrix([[1, half, half, half, half],
                        [half, 1, half, half, half],
                        [half, half, 1, half, half],
                        [half, half, half, 1, half],
                        [half, half, half, half, 1]])
    lam_len_5 = [1, 2, 3, 4, 5]
    omega_1x1 = Matrix([[1]])
    # The same single-entry list object is passed for both trailing arguments.
    single_entry = [1]

    # Same cases and same order as a sequence of individual ``raises`` calls;
    # ``raises`` invokes the lambda immediately, so closing over the loop
    # variable is safe.
    rejected_omega_arguments = (
        (omega_row_only, v, lam, mu),
        (omega_bad_entry, v, lam, mu),
        (omega_bad_diagonal, v, lam, mu),
        (omega, v_f, lam, mu),
        (omega, v, lam_with_v_f, mu),
        (omega, v, lam, mu_with_v_f),
        (omega_5x5, v, lam, mu),
        (omega, v, lam_len_5, mu),
        (omega_1x1, v, single_entry, single_entry),
    )
    for bad_arguments in rejected_omega_arguments:
        raises(ValueError, lambda: GMVLGO('G', *bad_arguments))

    raises(ValueError, lambda: GMVLG('G', Rational(3, 2), v, lam, mu))
|
|
|
|
|
|
def test_MultivariateBeta():
    """Check density, a marginal, validation and support of ``MultivariateBeta``."""
    a1, a2 = symbols('a1, a2', positive=True)
    # Same symbol names, but with assumptions that make them non-positive.
    a1_f, a2_f = symbols('a1, a2', positive=False, real=True)

    # The parameters may be given as one list or as separate arguments.
    beta_from_list = MultivariateBeta('B', [a1, a2])
    beta_from_args = MultivariateBeta('C', a1, a2)

    expected_density = S(2)**(a2 - 1)*gamma(a1 + a2)/(gamma(a1)*gamma(a2))
    assert density(beta_from_list)(1, 2) == expected_density

    expected_marginal = (
        S(3)**(a1 - 1)*gamma(a1 + a2)/(a2*gamma(a1)*gamma(a2)))
    assert marginal_distribution(beta_from_args, 0)(3) == expected_marginal

    # ``raises`` calls the lambda straight away, so closing over the loop
    # variables is safe.
    rejected = (
        ('b1', [a1_f, a2]),
        ('b2', [a1, a2_f]),
        ('b3', [0, 0]),
        ('b4', [a1_f, a2_f]),
    )
    for label, alphas in rejected:
        raises(ValueError, lambda: MultivariateBeta(label, alphas))

    unit = Interval(0, 1)
    assert beta_from_list.pspace.distribution.set == ProductSet(unit, unit)
|
|
|
|
|
|
def test_MultivariateEwens():
    """Check ``MultivariateEwens`` for a fixed and for a symbolic dimension."""
    # The third symbol created here is not used by the test.
    n, theta, _ = symbols('n theta i', positive=True)

    # --- integer dimension -------------------------------------------------
    negative_theta = symbols('t_f', negative=True)
    a1, a2, a3 = symbols('a_1:4', positive = True, integer = True)
    ewens3 = MultivariateEwens('E', 3, theta)

    joint_value = (
        6*2**(-a2)*3**(-a3)*theta**a1*theta**a2*theta**a3
        / (theta*(theta + 1)*(theta + 2)
           * factorial(a1)*factorial(a2)*factorial(a3)))
    expected_joint = Piecewise(
        (joint_value, Eq(a1 + 2*a2 + 3*a3, 3)),
        (0, True))
    assert density(ewens3)(a1, a2, a3) == expected_joint

    marginal_value = (
        6*2**(-a2)*theta**a2/((theta + 1)*(theta + 2)*factorial(a2)))
    expected_marginal = Piecewise(
        (marginal_value, Eq(2*a2 + 1, 3)),
        (0, True))
    assert marginal_distribution(ewens3, ewens3[1])(a2) == expected_marginal

    raises(ValueError, lambda: MultivariateEwens('e1', 5, negative_theta))

    expected_support = ProductSet(
        Range(0, 4, 1), Range(0, 2, 1), Range(0, 2, 1))
    assert ewens3.pspace.distribution.set == expected_support

    # --- symbolic dimension ------------------------------------------------
    ewens_n = MultivariateEwens('E', n, theta)
    counts = IndexedBase('a')
    j, k = symbols('j, k')

    product_over_j = Product(
        theta**counts[j]*(j + 1)**(-counts[j])/factorial(counts[j]),
        (j, 0, n - 1))
    weighted_total = Sum((k + 1)*counts[k], (k, 0, n - 1))
    expected_symbolic = Piecewise(
        (factorial(n)*product_over_j/RisingFactorial(theta, n),
         Eq(n, weighted_total)),
        (0, True))
    # ``dummy_eq`` compares up to the names of bound (dummy) variables.
    assert density(ewens_n)(counts).dummy_eq(expected_symbolic)
|
|
|
|
|
|
def test_Multinomial():
    """Check density, a marginal and argument validation of ``Multinomial``."""
    n, x1, x2, x3, x4 = symbols('n, x1, x2, x3, x4', nonnegative=True, integer=True)
    p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True)
    negative_p, negative_n = symbols('p1_f, n_f', negative=True)

    # Probabilities may be passed as one list or as separate arguments.
    four_way = Multinomial('M', n, [p1, p2, p3, p4])
    three_way = Multinomial('C', 3, p1, p2, p3)

    mass = (p1**x1*p2**x2*p3**x3*p4**x4*factorial(n)
            / (factorial(x1)*factorial(x2)*factorial(x3)*factorial(x4)))
    expected_density = Piecewise(
        (mass, Eq(n, x1 + x2 + x3 + x4)),
        (0, True))
    assert density(four_way)(x1, x2, x3, x4) == expected_density

    marginal_at_one = marginal_distribution(three_way, three_way[0])(x1).subs(x1, 1)
    assert marginal_at_one == 3*p1*p2**2 + 6*p1*p2*p3 + 3*p1*p3**2

    # Negative probability, negative trial count, and the float arguments
    # 0.5, 0.4, 0.3, 0.1 must all be rejected.
    raises(ValueError, lambda: Multinomial('b1', 5, [p1, p2, p3, negative_p]))
    raises(ValueError, lambda: Multinomial('b2', negative_n, [p1, p2, p3, p4]))
    raises(ValueError, lambda: Multinomial('b3', n, 0.5, 0.4, 0.3, 0.1))
|
|
|
|
|
|
def test_NegativeMultinomial():
    """Check density, a marginal, validation and support of ``NegativeMultinomial``."""
    k0, x1, x2, x3, x4 = symbols('k0, x1, x2, x3, x4', nonnegative=True, integer=True)
    p1, p2, p3, p4 = symbols('p1, p2, p3, p4', positive=True)
    negative_p = symbols('p1_f', negative=True)

    symbolic_nm = NegativeMultinomial('N', 4, [p1, p2, p3, p4])
    numeric_nm = NegativeMultinomial('C', 4, 0.1, 0.2, 0.3)

    expected_density = (
        p1**x1*p2**x2*p3**x3*p4**x4*(-p1 - p2 - p3 - p4 + 1)**4
        * gamma(x1 + x2 + x3 + x4 + 4)
        / (6*factorial(x1)*factorial(x2)*factorial(x3)*factorial(x4)))
    # The difference must simplify to the ``S.Zero`` singleton itself.
    difference = simplify(density(symbolic_nm)(x1, x2, x3, x4) - expected_density)
    assert difference is S.Zero

    marginal_value = marginal_distribution(numeric_nm, numeric_nm[0])(1).evalf()
    assert comp(marginal_value, 0.33, .01)

    raises(ValueError, lambda: NegativeMultinomial('b1', 5, [p1, p2, p3, negative_p]))
    raises(ValueError, lambda: NegativeMultinomial('b2', k0, 0.5, 0.4, 0.3, 0.4))

    counts = Range(0, oo, 1)
    assert symbolic_nm.pspace.distribution.set == ProductSet(
        counts, counts, counts, counts)
|
|
|
|
|
|
@slow
def test_JointPSpace_marginal_distribution():
    """Check marginals of ``MultivariateT`` selected by component and by index."""
    bivariate_t = MultivariateT('T', [0, 0], [[1, 0], [0, 1]], 2)

    computed = marginal_distribution(bivariate_t, bivariate_t[1])(x)
    base = x**2/2 + 1
    expected = sqrt(2)*base/(4*polar_lift(base)**(S(5)/2))
    # Show the computed expression if the comparison fails.
    assert computed == expected, computed

    # The marginal selected by index must integrate to one.
    total_mass = integrate(
        marginal_distribution(bivariate_t, 1)(x), (x, -oo, oo))
    assert total_mass == 1

    trivariate_t = MultivariateT(
        'T', [0, 0, 0], [[1, 0, 0], [0, 1, 0], [0, 0, 1]], 3)
    numeric_value = marginal_distribution(trivariate_t, 0)(1).evalf()
    assert comp(numeric_value, 0.2, .01)
|
|
|
|
|
|
def test_JointRV():
    """Check a ``JointRV`` built from a hand-written density over ``Indexed`` symbols."""
    first = Indexed('x', 1)
    second = Indexed('x', 2)
    handmade_pdf = exp(-first**2/2 + first - second**2/2 - S.Half)/(2*pi)
    joint = JointRV('x', handmade_pdf)

    assert density(joint)(1, 2) == exp(-2)/(2*pi)
    assert isinstance(joint.pspace.distribution, JointDistributionHandmade)

    marginal_value = marginal_distribution(joint, 0)(2)
    assert marginal_value == sqrt(2)*exp(Rational(-1, 2))/(2*sqrt(pi))
|
|
|
|
|
|
def test_expectation():
    """The expectation of component 1 of the Normal vector must simplify to ``y``."""
    normal_vector = Normal('A', [x, y], [[1, 0], [0, 1]])
    second_component_mean = simplify(E(normal_vector[1]))
    assert second_component_mean == y
|
|
|
|
|
|
@XFAIL
def test_joint_vector_expectation():
    """Expected failure: expectation of the whole random vector at once."""
    normal_vector = Normal('A', [x, y], [[1, 0], [0, 1]])
    vector_mean = E(normal_vector)
    assert vector_mean == (x, y)
|
|
|
|
|
|
def test_sample_numpy():
    """Sample through the numpy backend and check each draw lies in the support."""
    candidates = (
        MultivariateNormal("M", [3, 4], [[2, 1], [1, 2]]),
        MultivariateBeta("B", [0.4, 5, 15, 50, 203]),
        Multinomial("N", 50, [0.3, 0.2, 0.1, 0.25, 0.15]),
    )
    draws = 3

    if not import_module('numpy'):
        # ``skip`` leaves the test by raising, so no ``else`` is needed.
        skip('Numpy is not installed. Abort tests for _sample_numpy.')

    for rv in candidates:
        drawn = sample(rv, size=draws, library='numpy')
        for point in drawn:
            assert tuple(point) in rv.pspace.distribution.set

    # Sampling this distribution with numpy must report "not implemented".
    unsupported = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
    raises(NotImplementedError, lambda: sample(unsupported, library='numpy'))
|
|
|
|
|
|
def test_sample_scipy():
    """Sample with the default backend, for a flat and for a 2x2 ``size``.

    The test is skipped when scipy is not installed.
    """
    candidates = (
        MultivariateNormal("M", [0, 0], [[0.1, 0.025], [0.025, 0.1]]),
        MultivariateBeta("B", [0.4, 5, 15]),
        Multinomial("N", 8, [0.3, 0.2, 0.1, 0.4]),
    )
    draws = 3

    if not import_module('scipy'):
        # ``skip`` leaves the test by raising, so no ``else`` is needed.
        skip('Scipy not installed. Abort tests for _sample_scipy.')

    for rv in candidates:
        flat_draws = sample(rv, size=draws)
        grid_draws = sample(rv, size=(2, 2))
        for point in flat_draws:
            assert tuple(point) in rv.pspace.distribution.set
        # Walk the 2x2 grid of draws explicitly by index.
        for row in range(2):
            for col in range(2):
                assert tuple(grid_draws[row][col]) in rv.pspace.distribution.set

    # Sampling this distribution must report "not implemented".
    unsupported = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
    raises(NotImplementedError, lambda: sample(unsupported))
|
|
|
|
|
|
def test_sample_pymc():
    """Sample through the pymc backend and check each draw lies in the support."""
    candidates = (
        MultivariateNormal("M", [5, 2], [[1, 0], [0, 1]]),
        MultivariateBeta("B", [0.4, 5, 15]),
        Multinomial("N", 4, [0.3, 0.2, 0.1, 0.4]),
    )
    draws = 3

    if not import_module('pymc'):
        # ``skip`` leaves the test by raising, so no ``else`` is needed.
        skip('PyMC is not installed. Abort tests for _sample_pymc.')

    for rv in candidates:
        drawn = sample(rv, size=draws, library='pymc')
        for point in drawn:
            # Each draw is flattened before the membership test.
            assert tuple(point.flatten()) in rv.pspace.distribution.set

    # Sampling this distribution with pymc must report "not implemented".
    unsupported = NegativeMultinomial('N', 3, 0.1, 0.1, 0.1)
    raises(NotImplementedError, lambda: sample(unsupported, library='pymc'))
|
|
|
|
|
|
def test_sample_seed():
    """Check that ``sample`` is reproducible for a fixed ``seed``.

    For each backend that can be imported, two draws with ``seed=0`` must
    agree element-wise and a draw with ``seed=1`` must differ from them.
    A backend that raises ``NotImplementedError`` for this hand-written
    joint density is silently passed over.
    """
    x1, x2 = (Indexed('x', i) for i in (1, 2))
    pdf = exp(-x1**2/2 + x1 - x2**2/2 - S.Half)/(2*pi)
    X = JointRV('x', pdf)

    libraries = ['scipy', 'numpy', 'pymc']
    for lib in libraries:
        try:
            if not import_module(lib):
                # Backend not installed: nothing to check for it.
                continue
            # No placeholder initialisation is needed: every name is bound
            # by its ``sample`` call before it is read.
            s0 = sample(X, size=10, library=lib, seed=0)
            s1 = sample(X, size=10, library=lib, seed=0)
            s2 = sample(X, size=10, library=lib, seed=1)
            assert all(s0 == s1)
            assert all(s1 != s2)
        except NotImplementedError:
            continue
|
|
|
|
#
|
|
# XXX: This fails for pymc. Previously the test appeared to pass but that is
|
|
# just because the library argument was not passed so the test always used
|
|
# scipy.
|
|
#
|
|
def test_issue_21057():
    """Sample bivariate normals whose covariance matrix contains zero rows.

    Runs for scipy and numpy when they can be imported; a backend raising
    ``NotImplementedError`` is passed over.
    """
    zero_cov = [[0, 0], [0, 0]]
    all_zero = Normal("x", [0, 0], zero_cov)
    all_zero_explicit = MultivariateNormal("x", [0, 0], zero_cov)
    half_zero = Normal("x", [0, 0], [[0, 0], [0, 1]])
    assert all_zero == all_zero_explicit

    # 'pymc' is deliberately left out of this tuple because it fails.
    for library in ('scipy', 'numpy'):
        try:
            if not import_module(library):
                continue
            draws_a = sample(all_zero, size=8, library=library)
            draws_b = sample(all_zero_explicit, size=8, library=library)
            draws_c = sample(half_zero, size=8, library=library)
            # Both all-zero-covariance variables must give the same draws.
            assert tuple(draws_a.flatten()) == tuple(draws_b.flatten())
            for draw in draws_c:
                assert tuple(draw.flatten()) in half_zero.pspace.distribution.set
        except NotImplementedError:
            continue
|
|
|
|
|
|
#
|
|
# When this passes the pymc part can be uncommented in test_issue_21057 above
|
|
# and this can be deleted.
|
|
#
|
|
@XFAIL
def test_issue_21057_pymc():
    """Expected failure: the zero-covariance sampling check run through pymc."""
    zero_cov = [[0, 0], [0, 0]]
    all_zero = Normal("x", [0, 0], zero_cov)
    all_zero_explicit = MultivariateNormal("x", [0, 0], zero_cov)
    half_zero = Normal("x", [0, 0], [[0, 0], [0, 1]])
    assert all_zero == all_zero_explicit

    for library in ('pymc',):
        try:
            if not import_module(library):
                continue
            draws_a = sample(all_zero, size=8, library=library)
            draws_b = sample(all_zero_explicit, size=8, library=library)
            draws_c = sample(half_zero, size=8, library=library)
            # Both all-zero-covariance variables must give the same draws.
            assert tuple(draws_a.flatten()) == tuple(draws_b.flatten())
            for draw in draws_c:
                assert tuple(draw.flatten()) in half_zero.pspace.distribution.set
        except NotImplementedError:
            continue
|