Added common unit tests for conditional MI calculators, and added implementations of the tests for Kraskov and Gaussian conditional MI calculators

This commit is contained in:
joseph.lizier 2013-01-14 12:57:01 +00:00
parent f466012556
commit 40e7d4b211
3 changed files with 185 additions and 0 deletions

View File

@ -0,0 +1,99 @@
package infodynamics.measures.continuous;
import infodynamics.utils.MatrixUtils;
import infodynamics.utils.RandomGenerator;
import junit.framework.TestCase;
public abstract class ConditionalMutualInfoMultiVariateAbstractTester
extends TestCase {
/**
* Confirm that the local values average correctly back to the average value
*
* @param condMiCalc a pre-constructed ConditionalMutualInfoCalculatorMultiVariate object
* @param dimensions number of dimensions for the source and dest data to use
* @param timeSteps number of time steps for the random data
*/
public void testLocalsAverageCorrectly(ConditionalMutualInfoCalculatorMultiVariate condMiCalc,
int dimensions, int timeSteps)
throws Exception {
condMiCalc.initialise(dimensions, dimensions, dimensions);
// generate some random data
RandomGenerator rg = new RandomGenerator();
double[][] sourceData = rg.generateNormalData(timeSteps, dimensions,
0, 1);
double[][] destData = rg.generateNormalData(timeSteps, dimensions,
0, 1);
double[][] condData = rg.generateNormalData(timeSteps, dimensions,
0, 1);
condMiCalc.setObservations(sourceData, destData, condData);
//teCalc.setDebug(true);
double condmi = condMiCalc.computeAverageLocalOfObservations();
//miCalc.setDebug(false);
double[] condMiLocal = condMiCalc.computeLocalOfPreviousObservations();
System.out.printf("Average was %.5f\n", condmi);
assertEquals(condmi, MatrixUtils.mean(condMiLocal), 0.00001);
}
/**
* Confirm that significance testing doesn't alter the average that
* would be returned.
*
* @param condMiCalc a pre-constructed ConditionalMutualInfoCalculatorMultiVariate object
* @param dimensions number of dimensions for the source and dest data to use
* @param timeSteps number of time steps for the random data
* @throws Exception
*/
public void testComputeSignificanceDoesntAlterAverage(ConditionalMutualInfoCalculatorMultiVariate condMiCalc,
int dimensions, int timeSteps) throws Exception {
condMiCalc.initialise(dimensions, dimensions, dimensions);
// generate some random data
RandomGenerator rg = new RandomGenerator();
double[][] sourceData = rg.generateNormalData(timeSteps, dimensions,
0, 1);
double[][] destData = rg.generateNormalData(timeSteps, dimensions,
0, 1);
double[][] condData = rg.generateNormalData(timeSteps, dimensions,
0, 1);
condMiCalc.setObservations(sourceData, destData, condData);
//condMiCalc.setDebug(true);
double condMi = condMiCalc.computeAverageLocalOfObservations();
//condMiCalc.setDebug(false);
//double[] condMiLocal = miCalc.computeLocalOfPreviousObservations();
System.out.printf("Average was %.5f\n", condMi);
// Now look at statistical significance tests
int[][] newOrderings = rg.generateDistinctRandomPerturbations(
timeSteps, 100);
// Compute significance for permuting first variable
condMiCalc.computeSignificance(1, newOrderings);
// And compute the average value again to check that it's consistent:
for (int i = 0; i < 10; i++) {
double averageCheck1 = condMiCalc.computeAverageLocalOfObservations();
assertEquals(condMi, averageCheck1);
}
// Compute significance for permuting second variable
condMiCalc.computeSignificance(2, newOrderings);
// And compute the average value again to check that it's consistent:
for (int i = 0; i < 10; i++) {
double averageCheck1 = condMiCalc.computeAverageLocalOfObservations();
assertEquals(condMi, averageCheck1);
}
}
}

View File

@ -0,0 +1,20 @@
package infodynamics.measures.continuous.gaussian;
import infodynamics.measures.continuous.ConditionalMutualInfoMultiVariateAbstractTester;
public class ConditionalMutualInfoMultiVariateTester extends
ConditionalMutualInfoMultiVariateAbstractTester {
public void testLocalsAverageCorrectly() throws Exception {
ConditionalMutualInfoCalculatorMultiVariateGaussian condMiCalc =
new ConditionalMutualInfoCalculatorMultiVariateGaussian();
super.testLocalsAverageCorrectly(condMiCalc, 2, 100);
}
public void testComputeSignificanceDoesntAlterAverage() throws Exception {
ConditionalMutualInfoCalculatorMultiVariateGaussian condMiCalc =
new ConditionalMutualInfoCalculatorMultiVariateGaussian();
super.testComputeSignificanceDoesntAlterAverage(condMiCalc, 2, 100);
}
}

View File

@ -0,0 +1,66 @@
package infodynamics.measures.continuous.kraskov;
public class ConditionalMutualInfoMultiVariateTester
extends infodynamics.measures.continuous.ConditionalMutualInfoMultiVariateAbstractTester {
/**
* Utility function to create a calculator for the given algorithm number
*
* @param algNumber
* @return
*/
public ConditionalMutualInfoCalculatorMultiVariateKraskov getNewCalc(int algNumber) {
ConditionalMutualInfoCalculatorMultiVariateKraskov condMiCalc = null;
if (algNumber == 1) {
condMiCalc = new ConditionalMutualInfoCalculatorMultiVariateKraskov1();
} else if (algNumber == 2) {
condMiCalc = new ConditionalMutualInfoCalculatorMultiVariateKraskov2();
}
return condMiCalc;
}
/**
* Confirm that the local values average correctly back to the average value
*
*/
public void checkLocalsAverageCorrectly(int algNumber) throws Exception {
ConditionalMutualInfoCalculatorMultiVariateKraskov miCalc = getNewCalc(algNumber);
String kraskov_K = "4";
miCalc.setProperty(
MutualInfoCalculatorMultiVariateKraskov.PROP_K,
kraskov_K);
super.testLocalsAverageCorrectly(miCalc, 2, 100);
}
public void testLocalsAverageCorrectly() throws Exception {
checkLocalsAverageCorrectly(1);
checkLocalsAverageCorrectly(2);
}
/**
* Confirm that significance testing doesn't alter the average that
* would be returned.
*
* @throws Exception
*/
public void checkComputeSignificanceDoesntAlterAverage(int algNumber) throws Exception {
ConditionalMutualInfoCalculatorMultiVariateKraskov condMiCalc = getNewCalc(algNumber);
String kraskov_K = "4";
condMiCalc.setProperty(
MutualInfoCalculatorMultiVariateKraskov.PROP_K,
kraskov_K);
super.testComputeSignificanceDoesntAlterAverage(condMiCalc, 2, 100);
}
public void testComputeSignificanceDoesntAlterAverage() throws Exception {
checkComputeSignificanceDoesntAlterAverage(1);
checkComputeSignificanceDoesntAlterAverage(2);
}
}