mirror of https://github.com/jlizier/jidt
Added test for Cholesky decomposition, inversion of symmetric positive-definite matrices, and normal PDF/CDF
This commit is contained in:
parent
d19045014a
commit
4b3c84cd22
|
@ -12,6 +12,8 @@ import junit.framework.TestCase;
|
|||
*/
|
||||
public class MathsUtilsTest extends TestCase {
|
||||
|
||||
private static double OCTAVE_RESOLUTION = 0.00001;
|
||||
|
||||
/**
|
||||
* Confirm that our erf() function is correct to 6 dp
|
||||
*/
|
||||
|
@ -132,5 +134,316 @@ public class MathsUtilsTest extends TestCase {
|
|||
MathsUtils.chiSquareCdf(n*0.1, 10), 0.000001);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Test our Cholesky decomposition implementation
|
||||
*
|
||||
* @throws Exception
|
||||
*/
|
||||
public void testCholesky() throws Exception {
|
||||
|
||||
// Check some ordinary Cholesky decompositions:
|
||||
|
||||
double[][] A = {{6, 2, 3}, {2, 5, 1}, {3, 1, 4}};
|
||||
// Expected result from Octave:
|
||||
double[][] expectedL = {{2.44949, 0, 0}, {0.81650, 2.08167, 0},
|
||||
{1.22474, 0, 1.58114}};
|
||||
double[][] L = MatrixUtils.CholeskyDecomposition(A);
|
||||
checkMatrix(expectedL, L, OCTAVE_RESOLUTION);
|
||||
|
||||
double[][] A2 = {{6, 2, 3, 1}, {2, 5, 1, 0.5}, {3, 1, 4, 2}, {1, 0.5, 2, 3}};
|
||||
// Expected result from Octave:
|
||||
double[][] expectedL2 = {{2.44949, 0, 0, 0}, {0.81650, 2.08167, 0, 0},
|
||||
{1.22474, 0, 1.58114, 0}, {0.40825, 0.08006, 0.94868, 1.38814}};
|
||||
double[][] L2 = MatrixUtils.CholeskyDecomposition(A2);
|
||||
checkMatrix(expectedL2, L2, OCTAVE_RESOLUTION);
|
||||
|
||||
// Now check that it picks up asymmetric A:
|
||||
double[][] asymmetricA = {{6, 2, 3}, {2, 5, 1}, {3, 1.0001, 4}};
|
||||
boolean flaggedException = false;
|
||||
try {
|
||||
MatrixUtils.CholeskyDecomposition(asymmetricA);
|
||||
} catch (Exception e) {
|
||||
flaggedException = true;
|
||||
}
|
||||
assertTrue(flaggedException);
|
||||
|
||||
// Now check that it picks up if A is not positive definite:
|
||||
double[][] notpositiveDefiniteA = {{1, 2, 3}, {2, 4, 5}, {3, 5, 6}};
|
||||
flaggedException = false;
|
||||
try {
|
||||
MatrixUtils.CholeskyDecomposition(notpositiveDefiniteA);
|
||||
} catch (Exception e) {
|
||||
flaggedException = true;
|
||||
}
|
||||
assertTrue(flaggedException);
|
||||
}
|
||||
|
||||
/**
|
||||
* Test the inversion of symmetric positive definite matrices
|
||||
*
|
||||
* @throws Exception
|
||||
*/
|
||||
public void testInverseOfSymmPosDefMatrices() throws Exception {
|
||||
// Check some ordinary matrices:
|
||||
|
||||
double[][] A = {{6, 2, 3}, {2, 5, 1}, {3, 1, 4}};
|
||||
// Expected result from Octave:
|
||||
double[][] expectedInv = {{0.29231, -0.07692, -0.2},
|
||||
{-0.07692, 0.23077, 0}, {-0.2, 0, 0.4}};
|
||||
double[][] inv = MatrixUtils.invertSymmPosDefMatrix(A);
|
||||
checkMatrix(expectedInv, inv, OCTAVE_RESOLUTION);
|
||||
|
||||
double[][] A2 = {{6, 2, 3, 1}, {2, 5, 1, 0.5}, {3, 1, 4, 2}, {1, 0.5, 2, 3}};
|
||||
// Expected result from Octave:
|
||||
double[][] expectedInv2 = {{0.303393, -0.079840, -0.245509, 0.075848},
|
||||
{-0.079840, 0.231537, 0.011976, -0.019960},
|
||||
{-0.245509, 0.011976, 0.586826, -0.311377},
|
||||
{0.075848, -0.019960, -0.311377, 0.518962}};
|
||||
double[][] inv2 = MatrixUtils.invertSymmPosDefMatrix(A2);
|
||||
checkMatrix(expectedInv2, inv2, OCTAVE_RESOLUTION);
|
||||
|
||||
// Now check that it picks up asymmetric A:
|
||||
double[][] asymmetricA = {{6, 2, 3}, {2, 5, 1}, {3, 1.0001, 4}};
|
||||
boolean flaggedException = false;
|
||||
try {
|
||||
MatrixUtils.invertSymmPosDefMatrix(asymmetricA);
|
||||
} catch (Exception e) {
|
||||
flaggedException = true;
|
||||
}
|
||||
assertTrue(flaggedException);
|
||||
|
||||
// Now check that it picks up if A is not positive definite:
|
||||
double[][] notpositiveDefiniteA = {{1, 2, 3}, {2, 4, 5}, {3, 5, 6}};
|
||||
flaggedException = false;
|
||||
try {
|
||||
MatrixUtils.invertSymmPosDefMatrix(notpositiveDefiniteA);
|
||||
} catch (Exception e) {
|
||||
flaggedException = true;
|
||||
}
|
||||
assertTrue(flaggedException);
|
||||
}
|
||||
|
||||
public void testNormalPdf() throws Exception {
|
||||
// Check values for x = -4:0.1:4 generated by octave
|
||||
double[] expectedPdfMu0Std1 = {0.00013, 0.00020, 0.00029, 0.00042,
|
||||
0.00061, 0.00087, 0.00123, 0.00172, 0.00238, 0.00327, 0.00443,
|
||||
0.00595, 0.00792, 0.01042, 0.01358, 0.01753, 0.02239, 0.02833,
|
||||
0.03547, 0.04398, 0.05399, 0.06562, 0.07895, 0.09405, 0.11092,
|
||||
0.12952, 0.14973, 0.17137, 0.19419, 0.21785, 0.24197, 0.26609,
|
||||
0.28969, 0.31225, 0.33322, 0.35207, 0.36827, 0.38139, 0.39104,
|
||||
0.39695, 0.39894, 0.39695, 0.39104, 0.38139, 0.36827, 0.35207,
|
||||
0.33322, 0.31225, 0.28969, 0.26609, 0.24197, 0.21785, 0.19419,
|
||||
0.17137, 0.14973, 0.12952, 0.11092, 0.09405, 0.07895, 0.06562,
|
||||
0.05399, 0.04398, 0.03547, 0.02833, 0.02239, 0.01753, 0.01358,
|
||||
0.01042, 0.00792, 0.00595, 0.00443, 0.00327, 0.00238, 0.00172,
|
||||
0.00123, 0.00087, 0.00061, 0.00042, 0.00029, 0.00020, 0.00013};
|
||||
|
||||
// Mean 5.5, std 2.3
|
||||
double[] expectedPdfMu5_5Std2_3 = {0.00003, 0.00004, 0.00005, 0.00006,
|
||||
0.00007, 0.00008, 0.00010, 0.00011, 0.00014, 0.00016, 0.00019,
|
||||
0.00022, 0.00026, 0.00030, 0.00035, 0.00041, 0.00048, 0.00055,
|
||||
0.00064, 0.00074, 0.00085, 0.00098, 0.00113, 0.00129, 0.00148,
|
||||
0.00169, 0.00193, 0.00219, 0.00249, 0.00283, 0.00320, 0.00361,
|
||||
0.00407, 0.00458, 0.00515, 0.00577, 0.00646, 0.00722, 0.00804,
|
||||
0.00895, 0.00994, 0.01102, 0.01219, 0.01347, 0.01484, 0.01633,
|
||||
0.01793, 0.01965, 0.02150, 0.02347, 0.02558, 0.02783, 0.03021,
|
||||
0.03274, 0.03541, 0.03823, 0.04119, 0.04430, 0.04756, 0.05096,
|
||||
0.05449, 0.05816, 0.06197, 0.06589, 0.06994, 0.07409, 0.07834,
|
||||
0.08267, 0.08708, 0.09156, 0.09608, 0.10063, 0.10520, 0.10978,
|
||||
0.11433, 0.11885, 0.12331, 0.12770, 0.13199, 0.13618, 0.14022};
|
||||
|
||||
double x = -4.0;
|
||||
for (int xIndex = 0; xIndex < 81; xIndex++) {
|
||||
// System.out.printf("%.1f: %.5f (expected %.5f)\n",
|
||||
// x, MathsUtils.normalPdf(x, mean, stddev), expectedPdfMu0Std1[xIndex]);
|
||||
assertEquals(expectedPdfMu0Std1[xIndex], MathsUtils.normalPdf(x, 0.0, 1.0), OCTAVE_RESOLUTION);
|
||||
assertEquals(expectedPdfMu5_5Std2_3[xIndex], MathsUtils.normalPdf(x, 5.5, 2.3), OCTAVE_RESOLUTION);
|
||||
x += 0.1;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public void testNormalCdf() throws Exception {
|
||||
// Check values for x = -4:0.1:4 generated by octave
|
||||
double[] expectedCdfMu0Std1 = {0.00003, 0.00005, 0.00007, 0.00011, 0.00016,
|
||||
0.00023, 0.00034, 0.00048, 0.00069, 0.00097, 0.00135, 0.00187,
|
||||
0.00256, 0.00347, 0.00466, 0.00621, 0.00820, 0.01072, 0.01390,
|
||||
0.01786, 0.02275, 0.02872, 0.03593, 0.04457, 0.05480, 0.06681,
|
||||
0.08076, 0.09680, 0.11507, 0.13567, 0.15866, 0.18406, 0.21186,
|
||||
0.24196, 0.27425, 0.30854, 0.34458, 0.38209, 0.42074, 0.46017,
|
||||
0.50000, 0.53983, 0.57926, 0.61791, 0.65542, 0.69146, 0.72575,
|
||||
0.75804, 0.78814, 0.81594, 0.84134, 0.86433, 0.88493, 0.90320,
|
||||
0.91924, 0.93319, 0.94520, 0.95543, 0.96407, 0.97128, 0.97725,
|
||||
0.98214, 0.98610, 0.98928, 0.99180, 0.99379, 0.99534, 0.99653,
|
||||
0.99744, 0.99813, 0.99865, 0.99903, 0.99931, 0.99952, 0.99966,
|
||||
0.99977, 0.99984, 0.99989, 0.99993, 0.99995, 0.99997};
|
||||
|
||||
// Mean 5.5, std 2.3
|
||||
double[] expectedCdfMu5_5Std2_3 = {0.00002, 0.00002, 0.00003, 0.00003,
|
||||
0.00004, 0.00005, 0.00005, 0.00007, 0.00008, 0.00009, 0.00011,
|
||||
0.00013, 0.00015, 0.00018, 0.00021, 0.00025, 0.00030, 0.00035,
|
||||
0.00041, 0.00048, 0.00056, 0.00065, 0.00075, 0.00087, 0.00101,
|
||||
0.00117, 0.00135, 0.00156, 0.00179, 0.00206, 0.00236, 0.00270,
|
||||
0.00308, 0.00351, 0.00400, 0.00454, 0.00516, 0.00584, 0.00660,
|
||||
0.00745, 0.00839, 0.00944, 0.01060, 0.01188, 0.01330, 0.01486,
|
||||
0.01657, 0.01845, 0.02050, 0.02275, 0.02520, 0.02787, 0.03077,
|
||||
0.03392, 0.03733, 0.04101, 0.04498, 0.04925, 0.05384, 0.05877,
|
||||
0.06404, 0.06967, 0.07567, 0.08207, 0.08886, 0.09606, 0.10368,
|
||||
0.11173, 0.12021, 0.12915, 0.13853, 0.14836, 0.15866, 0.16940,
|
||||
0.18061, 0.19227, 0.20438, 0.21693, 0.22991, 0.24332, 0.25714};
|
||||
|
||||
double x = -4.0;
|
||||
for (int xIndex = 0; xIndex < 81; xIndex++) {
|
||||
// System.out.printf("%.1f: %.5f (expected %.5f)\n",
|
||||
// x, MathsUtils.normalPdf(x, mean, stddev), expectedPdfMu0Std1[xIndex]);
|
||||
assertEquals(expectedCdfMu0Std1[xIndex], MathsUtils.normalCdf(x, 0.0, 1.0), OCTAVE_RESOLUTION);
|
||||
assertEquals(expectedCdfMu5_5Std2_3[xIndex], MathsUtils.normalCdf(x, 5.5, 2.3), OCTAVE_RESOLUTION);
|
||||
x += 0.1;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public void testMultivariateNormalPdf() throws Exception {
|
||||
// Testing against data generated using public domain octave code
|
||||
// by Paul Kienzle available at
|
||||
// http://octave.1599824.n4.nabble.com/Multivariate-pdf-of-a-normal-distribution-td1601886.html
|
||||
// http://octave.1599824.n4.nabble.com/attachment/1601887/0/mvnpdf.m
|
||||
|
||||
double[] means = {1, -1};
|
||||
double[][] covariance = {{.9, .4}, {.4, .3}};
|
||||
// Evaluated using:
|
||||
// for x = -3:0.5:5
|
||||
// for y = -5:0.5:3
|
||||
// printf("%.5f, ", mvnpdf([x, y], mu, sigma));
|
||||
// end
|
||||
// end
|
||||
double[] expectedMvnPdf1 = {0.00000, 0.00000, 0.00000, 0.00001, 0.00005,
|
||||
0.00005, 0.00001, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00001, 0.00024, 0.00052, 0.00015, 0.00001, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00001, 0.00052, 0.00289,
|
||||
0.00205, 0.00019, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00001, 0.00059, 0.00803, 0.01417, 0.00323, 0.00010, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00033, 0.01129, 0.04944,
|
||||
0.02801, 0.00205, 0.00002, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00010, 0.00803, 0.08727, 0.12272, 0.02232, 0.00052, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00001, 0.00289, 0.07789, 0.27187,
|
||||
0.12272, 0.00716, 0.00005, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00052, 0.03516, 0.30459, 0.34125, 0.04944, 0.00093, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00005, 0.00803, 0.17257, 0.47987,
|
||||
0.17257, 0.00803, 0.00005, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00093, 0.04944, 0.34125, 0.30459, 0.03516, 0.00052, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00005, 0.00716, 0.12272, 0.27187,
|
||||
0.07789, 0.00289, 0.00001, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00052, 0.02232, 0.12272, 0.08727, 0.00803, 0.00010, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00002, 0.00205, 0.02801, 0.04944,
|
||||
0.01129, 0.00033, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00010, 0.00323, 0.01417, 0.00803, 0.00059, 0.00001, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00019, 0.00205, 0.00289,
|
||||
0.00052, 0.00001, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00001, 0.00015, 0.00052, 0.00024, 0.00001, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00001, 0.00005, 0.00005,
|
||||
0.00001, 0.00000, 0.00000, 0.00000};
|
||||
|
||||
int expIndex = 0;
|
||||
double[] observations = new double[2];
|
||||
for (double x1 = -3.0; x1 <= 5.0; x1 += 0.5) {
|
||||
for (double x2 = -5.0; x2 <= 3.0; x2 += 0.5) {
|
||||
observations[0] = x1;
|
||||
observations[1] = x2;
|
||||
assertEquals(expectedMvnPdf1[expIndex],
|
||||
MathsUtils.normalPdf(observations, means, covariance), OCTAVE_RESOLUTION);
|
||||
expIndex++;
|
||||
}
|
||||
}
|
||||
|
||||
double[] means2 = {4, 5, -3};
|
||||
double[][] covariance2 = {{.9, .4, .25}, {.4, .3, .2}, {.25, .2, .6}};
|
||||
// Evaluated using:
|
||||
// for x = 2:1:6
|
||||
// for y = 3:1:7
|
||||
// for z = -5:1:-1
|
||||
// printf("%.5f, ", mvnpdf([x, y, z], means2, covariance2));
|
||||
// end
|
||||
// end
|
||||
// end
|
||||
double[] expectedMvnPdf2 = {0.00013, 0.00017, 0.00003, 0.00000, 0.00000,
|
||||
0.00393, 0.02507, 0.01871, 0.00163, 0.00002, 0.00001, 0.00033,
|
||||
0.00119, 0.00049, 0.00002, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00001,
|
||||
0.00001, 0.00000, 0.00000, 0.00000, 0.00705, 0.04084, 0.02764,
|
||||
0.00219, 0.00002, 0.00080, 0.02220, 0.07156, 0.02698, 0.00119,
|
||||
0.00000, 0.00000, 0.00002, 0.00003, 0.00001, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00082, 0.00433, 0.00266, 0.00019, 0.00000, 0.00383,
|
||||
0.09590, 0.28047, 0.09590, 0.00383, 0.00000, 0.00019, 0.00266,
|
||||
0.00433, 0.00082, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00001, 0.00003,
|
||||
0.00002, 0.00000, 0.00000, 0.00119, 0.02698, 0.07156, 0.02220,
|
||||
0.00080, 0.00002, 0.00219, 0.02764, 0.04084, 0.00705, 0.00000,
|
||||
0.00000, 0.00000, 0.00001, 0.00001, 0.00000, 0.00000, 0.00000,
|
||||
0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
|
||||
0.00002, 0.00049, 0.00119, 0.00033, 0.00001, 0.00002, 0.00163,
|
||||
0.01871, 0.02507, 0.00393, 0.00000, 0.00000, 0.00003, 0.00017,
|
||||
0.00013};
|
||||
|
||||
expIndex = 0;
|
||||
observations = new double[3];
|
||||
for (double x1 = 2.0; x1 <= 6.0; x1 += 1.0) {
|
||||
observations[0] = x1;
|
||||
for (double x2 = 3.0; x2 <= 7.0; x2 += 1.0) {
|
||||
observations[1] = x2;
|
||||
for (double x3 = -5.0; x3 <= -1.0; x3 += 1.0) {
|
||||
observations[2] = x3;
|
||||
assertEquals(expectedMvnPdf2[expIndex],
|
||||
MathsUtils.normalPdf(observations, means2, covariance2),
|
||||
OCTAVE_RESOLUTION);
|
||||
expIndex++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// And now let's test some really random numbers:
|
||||
double[] observations3 = {142.1, -10.25, 0.13, 5.13, 5.935};
|
||||
double[] means3 = {142.2, -10.3, 0.1, 5, 6};
|
||||
double[][] covariance3 = {{1.2, 0.4, 0.6, 0.1, 0.8},
|
||||
{0.4, 2, 0.3, 0.2, 0.4}, {0.6, 0.3, 0.9, 0.5, 0.45},
|
||||
{0.1, 0.2, 0.5, 1.1, 0.8}, {0.8, 0.4, 0.45, 0.8, 1.3}};
|
||||
assertEquals(0.028085,
|
||||
MathsUtils.normalPdf(observations3, means3, covariance3),
|
||||
0.000001);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check that all entries in the given matrix match those of the expected
|
||||
* matrix
|
||||
*
|
||||
* @param expected
|
||||
* @param actual
|
||||
* @param resolution
|
||||
*/
|
||||
protected void checkMatrix(double[][] expected, double[][] actual, double resolution) {
|
||||
for (int r = 0; r < expected.length; r++) {
|
||||
for (int c = 0; c < expected[r].length; c++) {
|
||||
assertEquals(expected[r][c], actual[r][c], resolution);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue