mirror of https://github.com/jlizier/jidt
Added common unit tests for conditional MI calculators, and added implementations of the tests for Kraskov and Gaussian conditional MI calculators
This commit is contained in:
parent
f466012556
commit
40e7d4b211
|
@ -0,0 +1,99 @@
|
|||
package infodynamics.measures.continuous;
|
||||
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public abstract class ConditionalMutualInfoMultiVariateAbstractTester
|
||||
extends TestCase {
|
||||
|
||||
/**
|
||||
* Confirm that the local values average correctly back to the average value
|
||||
*
|
||||
* @param condMiCalc a pre-constructed ConditionalMutualInfoCalculatorMultiVariate object
|
||||
* @param dimensions number of dimensions for the source and dest data to use
|
||||
* @param timeSteps number of time steps for the random data
|
||||
*/
|
||||
public void testLocalsAverageCorrectly(ConditionalMutualInfoCalculatorMultiVariate condMiCalc,
|
||||
int dimensions, int timeSteps)
|
||||
throws Exception {
|
||||
|
||||
condMiCalc.initialise(dimensions, dimensions, dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] sourceData = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
double[][] destData = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
double[][] condData = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
condMiCalc.setObservations(sourceData, destData, condData);
|
||||
|
||||
//teCalc.setDebug(true);
|
||||
double condmi = condMiCalc.computeAverageLocalOfObservations();
|
||||
//miCalc.setDebug(false);
|
||||
double[] condMiLocal = condMiCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
System.out.printf("Average was %.5f\n", condmi);
|
||||
|
||||
assertEquals(condmi, MatrixUtils.mean(condMiLocal), 0.00001);
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that significance testing doesn't alter the average that
|
||||
* would be returned.
|
||||
*
|
||||
* @param condMiCalc a pre-constructed ConditionalMutualInfoCalculatorMultiVariate object
|
||||
* @param dimensions number of dimensions for the source and dest data to use
|
||||
* @param timeSteps number of time steps for the random data
|
||||
* @throws Exception
|
||||
*/
|
||||
public void testComputeSignificanceDoesntAlterAverage(ConditionalMutualInfoCalculatorMultiVariate condMiCalc,
|
||||
int dimensions, int timeSteps) throws Exception {
|
||||
|
||||
condMiCalc.initialise(dimensions, dimensions, dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] sourceData = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
double[][] destData = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
double[][] condData = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
condMiCalc.setObservations(sourceData, destData, condData);
|
||||
|
||||
//condMiCalc.setDebug(true);
|
||||
double condMi = condMiCalc.computeAverageLocalOfObservations();
|
||||
//condMiCalc.setDebug(false);
|
||||
//double[] condMiLocal = miCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
System.out.printf("Average was %.5f\n", condMi);
|
||||
|
||||
// Now look at statistical significance tests
|
||||
int[][] newOrderings = rg.generateDistinctRandomPerturbations(
|
||||
timeSteps, 100);
|
||||
|
||||
// Compute significance for permuting first variable
|
||||
condMiCalc.computeSignificance(1, newOrderings);
|
||||
|
||||
// And compute the average value again to check that it's consistent:
|
||||
for (int i = 0; i < 10; i++) {
|
||||
double averageCheck1 = condMiCalc.computeAverageLocalOfObservations();
|
||||
assertEquals(condMi, averageCheck1);
|
||||
}
|
||||
|
||||
// Compute significance for permuting second variable
|
||||
condMiCalc.computeSignificance(2, newOrderings);
|
||||
|
||||
// And compute the average value again to check that it's consistent:
|
||||
for (int i = 0; i < 10; i++) {
|
||||
double averageCheck1 = condMiCalc.computeAverageLocalOfObservations();
|
||||
assertEquals(condMi, averageCheck1);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.measures.continuous.ConditionalMutualInfoMultiVariateAbstractTester;
|
||||
|
||||
public class ConditionalMutualInfoMultiVariateTester extends
|
||||
ConditionalMutualInfoMultiVariateAbstractTester {
|
||||
|
||||
public void testLocalsAverageCorrectly() throws Exception {
|
||||
ConditionalMutualInfoCalculatorMultiVariateGaussian condMiCalc =
|
||||
new ConditionalMutualInfoCalculatorMultiVariateGaussian();
|
||||
super.testLocalsAverageCorrectly(condMiCalc, 2, 100);
|
||||
}
|
||||
|
||||
public void testComputeSignificanceDoesntAlterAverage() throws Exception {
|
||||
ConditionalMutualInfoCalculatorMultiVariateGaussian condMiCalc =
|
||||
new ConditionalMutualInfoCalculatorMultiVariateGaussian();
|
||||
super.testComputeSignificanceDoesntAlterAverage(condMiCalc, 2, 100);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
public class ConditionalMutualInfoMultiVariateTester
|
||||
extends infodynamics.measures.continuous.ConditionalMutualInfoMultiVariateAbstractTester {
|
||||
|
||||
/**
|
||||
* Utility function to create a calculator for the given algorithm number
|
||||
*
|
||||
* @param algNumber
|
||||
* @return
|
||||
*/
|
||||
public ConditionalMutualInfoCalculatorMultiVariateKraskov getNewCalc(int algNumber) {
|
||||
ConditionalMutualInfoCalculatorMultiVariateKraskov condMiCalc = null;
|
||||
if (algNumber == 1) {
|
||||
condMiCalc = new ConditionalMutualInfoCalculatorMultiVariateKraskov1();
|
||||
} else if (algNumber == 2) {
|
||||
condMiCalc = new ConditionalMutualInfoCalculatorMultiVariateKraskov2();
|
||||
}
|
||||
return condMiCalc;
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that the local values average correctly back to the average value
|
||||
*
|
||||
*/
|
||||
public void checkLocalsAverageCorrectly(int algNumber) throws Exception {
|
||||
|
||||
ConditionalMutualInfoCalculatorMultiVariateKraskov miCalc = getNewCalc(algNumber);
|
||||
|
||||
String kraskov_K = "4";
|
||||
|
||||
miCalc.setProperty(
|
||||
MutualInfoCalculatorMultiVariateKraskov.PROP_K,
|
||||
kraskov_K);
|
||||
|
||||
super.testLocalsAverageCorrectly(miCalc, 2, 100);
|
||||
}
|
||||
public void testLocalsAverageCorrectly() throws Exception {
|
||||
checkLocalsAverageCorrectly(1);
|
||||
checkLocalsAverageCorrectly(2);
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that significance testing doesn't alter the average that
|
||||
* would be returned.
|
||||
*
|
||||
* @throws Exception
|
||||
*/
|
||||
public void checkComputeSignificanceDoesntAlterAverage(int algNumber) throws Exception {
|
||||
|
||||
ConditionalMutualInfoCalculatorMultiVariateKraskov condMiCalc = getNewCalc(algNumber);
|
||||
|
||||
String kraskov_K = "4";
|
||||
|
||||
condMiCalc.setProperty(
|
||||
MutualInfoCalculatorMultiVariateKraskov.PROP_K,
|
||||
kraskov_K);
|
||||
|
||||
super.testComputeSignificanceDoesntAlterAverage(condMiCalc, 2, 100);
|
||||
}
|
||||
public void testComputeSignificanceDoesntAlterAverage() throws Exception {
|
||||
checkComputeSignificanceDoesntAlterAverage(1);
|
||||
checkComputeSignificanceDoesntAlterAverage(2);
|
||||
}
|
||||
|
||||
}
|
Loading…
Reference in New Issue