[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request #589 from j-towns/fix-588
Browse files Browse the repository at this point in the history
Fix mvn tests (fixes #588)
  • Loading branch information
j-towns committed Mar 22, 2023
2 parents c6d81ce + 7a66f14 commit 81b288e
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions tests/test_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,23 @@ def test_t_logcdf_broadcast(): combo_check(stats.t.logcdf, [0,2])( [R(4,3)],
def make_psd(mat):
    """Return ``mat.T @ mat + I``, a symmetric positive-definite matrix.

    For a real ``mat`` the Gram matrix ``mat.T @ mat`` is positive
    semi-definite; adding the identity lifts every eigenvalue to at least 1.

    Args:
        mat: 2-D array of shape ``(rows, cols)``. It need not be square.

    Returns:
        Array of shape ``(cols, cols)``.
    """
    gram = np.dot(mat.T, mat)
    # Size the identity from the Gram matrix (cols x cols), not from
    # mat.shape[0]; the two only coincide when ``mat`` is square.
    return gram + np.eye(gram.shape[0])
def test_mvn_pdf():
    """Run combo_check on mvn.pdf for argnums 0, 1 and 2.

    mvn.pdf is wrapped by symmetrize_matrix_arg for argument index 2 (the
    covariance slot), and allow_singular is pinned to False.
    """
    # Build the checker first so the R(...) draws happen in the same order
    # as in a single chained call.
    check = combo_check(symmetrize_matrix_arg(mvn.pdf, 2), [0, 1, 2])
    x_vals = [R(4)]
    mean_vals = [R(4)]
    cov_vals = [make_psd(R(4, 4))]
    check(x_vals, mean_vals, cov_vals, allow_singular=[False])
def test_mvn_logpdf():
    """Run combo_check on mvn.logpdf for argnums 0, 1 and 2.

    mvn.logpdf is wrapped by symmetrize_matrix_arg for argument index 2 (the
    covariance slot), and allow_singular is pinned to False.
    """
    # Build the checker first so the R(...) draws happen in the same order
    # as in a single chained call.
    check = combo_check(symmetrize_matrix_arg(mvn.logpdf, 2), [0, 1, 2])
    x_vals = [R(4)]
    mean_vals = [R(4)]
    cov_vals = [make_psd(R(4, 4))]
    check(x_vals, mean_vals, cov_vals, allow_singular=[False])
def test_mvn_entropy():
    """Run combo_check on mvn.entropy for argnums 0 and 1 (mean, covariance)."""
    # Checker is created before the inputs to keep the R(...) call order.
    check = combo_check(mvn.entropy, [0, 1])
    mean_vals = [R(4)]
    cov_vals = [make_psd(R(4, 4))]
    check(mean_vals, cov_vals)

# Singular 4x4 covariance: unit variance on the first two axes, zero on the
# last two (rank 2).
C = np.diag([1.0, 1.0, 0.0, 0.0])
# Disabled ridge term that would make C full rank:
# C += 1e-3 * np.eye(4)
def test_mvn_pdf_sing_cov():
    """Run combo_check on mvn.pdf (argnums 0, 1) with the singular covariance C.

    x and mean are 4-vectors whose last two entries are zero; the fourth
    positional value list is [True] (presumably allow_singular -- confirm
    against combo_check's argument handling).
    """
    check = combo_check(mvn.pdf, [0, 1])
    # Inputs are drawn after the checker is built, x first and then mean.
    x_vals = [np.concatenate((R(2), np.zeros(2)))]
    mean_vals = [np.concatenate((R(2), np.zeros(2)))]
    check(x_vals, mean_vals, [C], [True])
def test_mvn_logpdf_sing_cov():
    """Run combo_check on mvn.logpdf (argnums 0, 1) with the singular covariance C.

    x and mean are 4-vectors whose last two entries are zero; the fourth
    positional value list is [True] (presumably allow_singular -- confirm
    against combo_check's argument handling).
    """
    check = combo_check(mvn.logpdf, [0, 1])
    # Inputs are drawn after the checker is built, x first and then mean.
    x_vals = [np.concatenate((R(2), np.zeros(2)))]
    mean_vals = [np.concatenate((R(2), np.zeros(2)))]
    check(x_vals, mean_vals, [C], [True])
def test_mvn_entropy():
    """Run combo_check on mvn.entropy for argnums 0 and 1.

    mvn.entropy is wrapped by symmetrize_matrix_arg for argument index 1 (the
    covariance slot); the mean input is R(4) scaled by 10.
    """
    # Checker is created before the inputs to keep the R(...) call order.
    check = combo_check(symmetrize_matrix_arg(mvn.entropy, 1), [0, 1])
    mean_vals = [10 * R(4)]
    cov_vals = [make_psd(R(4, 4))]
    check(mean_vals, cov_vals)

def test_mvn_sing_cov():
    """Run combo_check on mvn.pdf and mvn.logpdf with a singular covariance.

    The covariance has unit variance on the first two axes and zero variance
    on the last two. Both densities are called with allow_singular=True and
    wrapped by symmetrize_matrix_arg for argument index 2.
    """
    cov = np.zeros((4, 4))
    cov[0, 0] = cov[1, 1] = 1

    # Only allow variations in x along the first two dimensions, because
    # variance is zero in the last two.
    def pdf(x, mean, cov):
        pinned = np.concatenate([x[:2], mean[2:]])
        density = symmetrize_matrix_arg(partial(mvn.pdf, allow_singular=True), 2)
        return density(pinned, mean, cov)

    # Build each checker before drawing its inputs so the R(...) calls occur
    # in the same order as in a single chained call.
    check_pdf = combo_check(pdf, [0, 1])
    x_vals = [np.concatenate((R(2), np.zeros(2)))]
    mean_vals = [np.concatenate((R(2), np.zeros(2)))]
    check_pdf(x_vals, mean_vals, [cov])

    def logpdf(x, mean, cov):
        pinned = np.concatenate([x[:2], mean[2:]])
        density = symmetrize_matrix_arg(partial(mvn.logpdf, allow_singular=True), 2)
        return density(pinned, mean, cov)

    check_logpdf = combo_check(logpdf, [0, 1])
    x_vals = [np.concatenate((R(2), np.zeros(2)))]
    mean_vals = [np.concatenate((R(2), np.zeros(2)))]
    check_logpdf(x_vals, mean_vals, [cov])

def test_mvn_pdf_broadcast():
    """Run combo_check on mvn.pdf (argnums 0, 1, 2) with a batched x.

    x comes from R(5, 4) while mean comes from R(4), so the density is
    evaluated with x carrying an extra leading axis. mvn.pdf is wrapped by
    symmetrize_matrix_arg for argument index 2 (the covariance slot).
    """
    # Build the checker first so the R(...) draws keep their original order.
    check = combo_check(symmetrize_matrix_arg(mvn.pdf, 2), [0, 1, 2])
    x_vals = [R(5, 4)]
    mean_vals = [R(4)]
    cov_vals = [make_psd(R(4, 4))]
    check(x_vals, mean_vals, cov_vals)
def test_mvn_logpdf_broadcast():
    """Run combo_check on mvn.logpdf (argnums 0, 1, 2) with a batched x.

    x comes from R(5, 4) while mean comes from R(4), so the density is
    evaluated with x carrying an extra leading axis. mvn.logpdf is wrapped by
    symmetrize_matrix_arg for argument index 2 (the covariance slot).
    """
    # Build the checker first so the R(...) draws keep their original order.
    check = combo_check(symmetrize_matrix_arg(mvn.logpdf, 2), [0, 1, 2])
    x_vals = [R(5, 4)]
    mean_vals = [R(4)]
    cov_vals = [make_psd(R(4, 4))]
    check(x_vals, mean_vals, cov_vals)
Expand Down

0 comments on commit 81b288e

Please sign in to comment.