Extending MI discrete calculator to allow different bases for each variable. This necessitates immediate removal of the (int,int) constructor (where the 2nd argument was the time difference), this will now be confusing between (base, timeDiff) and (base1, base2). In time we may bring it back, after we can be reasonably comfortable people have switched away from using (base, timeDiff).

This commit is contained in:
jlizier 2018-06-20 10:55:21 +10:00
parent 4170339426
commit 7fafabe451
4 changed files with 47 additions and 33 deletions

View File

@ -213,9 +213,9 @@ public class AutoAnalyserMI extends AutoAnalyserChannelCalculator
}
return new DiscreteCalcAndArguments(
new MutualInformationCalculatorDiscrete(base, timeDiff),
new MutualInformationCalculatorDiscrete(base, base, timeDiff),
base,
base + ", " + timeDiff);
base + ", " + base + ", " + timeDiff);
}
/**

View File

@ -66,10 +66,16 @@ Theory' (John Wiley & Sons, New York, 1991).</li>
public class MutualInformationCalculatorDiscrete extends InfoMeasureCalculatorDiscrete
implements ChannelCalculatorDiscrete, AnalyticNullDistributionComputer {
private int timeDiff = 0;
private int[][] jointCount = null; // Count for (i[t-timeDiff], j[t]) tuples
private int[] iCount = null; // Count for i[t-timeDiff]
private int[] jCount = null; // Count for j[t]
/**
* Store the number of symbols for each variable
*/
protected int base1;
protected int base2;
protected int timeDiff = 0;
protected int[][] jointCount = null; // Count for (i[t-timeDiff], j[t]) tuples
protected int[] iCount = null; // Count for i[t-timeDiff]
protected int[] jCount = null; // Count for j[t]
protected boolean miComputed = false;
@ -77,32 +83,39 @@ public class MutualInformationCalculatorDiscrete extends InfoMeasureCalculatorDi
* Construct a new MI calculator with default time difference of 0
* between the variables
*
* @param base number of symbols for each variable.
* @param base number of symbols for each variable. (same for each)
* E.g. binary variables are in base-2.
* @throws Exception
*/
public MutualInformationCalculatorDiscrete(int base) throws Exception {
this(base, 0);
this(base, base, 0);
}
/**
* Create a new mutual information calculator
*
* @param base number of symbols for each variable.
* @param base1 number of symbols for first variable.
* E.g. binary variables are in base-2.
* @param base2 number of symbols for second variable.
* @param timeDiff number of time steps across which to compute
* MI for given time series
* @throws Exception when timeDiff < 0
*/
public MutualInformationCalculatorDiscrete(int base, int timeDiff) throws Exception {
super(base);
public MutualInformationCalculatorDiscrete(int base1, int base2, int timeDiff) throws Exception {
// Create super object, just with first base
super(base1);
// Store the bases
this.base1 = base1;
this.base2 = base2;
if (timeDiff < 0) {
throw new Exception("timeDiff must be >= 0");
}
this.timeDiff = timeDiff;
jointCount = new int[base][base];
iCount = new int[base];
jCount = new int[base];
jointCount = new int[base1][base2];
iCount = new int[base1];
jCount = new int[base2];
}
@Override
@ -176,10 +189,10 @@ public class MutualInformationCalculatorDiscrete extends InfoMeasureCalculatorDi
if (debug) {
System.out.println("i\tj\tp_i\tp_j\tp_joint\tlocal");
}
for (int i = 0; i < base; i++) {
for (int i = 0; i < base1; i++) {
// compute p_i
double probi = (double) iCount[i] / (double) observations;
for (int j = 0; j < base; j++) {
for (int j = 0; j < base2; j++) {
// compute p_j
double probj = (double) jCount[j] / (double) observations;
// compute p(veci=i, vecj=j)
@ -230,21 +243,23 @@ public class MutualInformationCalculatorDiscrete extends InfoMeasureCalculatorDi
// Reconstruct the values of the first and second variables (not necessarily in order)
int[] iValues = new int[observations];
int[] jValues = new int[observations];
int t_i = 0;
int t_j = 0;
for (int iVal = 0; iVal < base; iVal++) {
for (int iVal = 0; iVal < base1; iVal++) {
int numberOfSamplesI = iCount[iVal];
MatrixUtils.fill(iValues, iVal, t_i, numberOfSamplesI);
t_i += numberOfSamplesI;
int numberOfSamplesJ = jCount[iVal];
MatrixUtils.fill(jValues, iVal, t_j, numberOfSamplesJ);
}
int[] jValues = new int[observations];
int t_j = 0;
for (int jVal = 0; jVal < base2; jVal++) {
int numberOfSamplesJ = jCount[jVal];
MatrixUtils.fill(jValues, jVal, t_j, numberOfSamplesJ);
t_j += numberOfSamplesJ;
}
MutualInformationCalculatorDiscrete mi2;
try {
mi2 = new MutualInformationCalculatorDiscrete(base, timeDiff);
mi2 = new MutualInformationCalculatorDiscrete(base1, base2, timeDiff);
} catch (Exception e) {
// The only possible exception is if timeDiff < 0, which
// it cannot be. Shut down the JVM
@ -286,7 +301,7 @@ public class MutualInformationCalculatorDiscrete extends InfoMeasureCalculatorDi
}
return new ChiSquareMeasurementDistribution(average,
observations,
(base - 1) * (base - 1));
(base1 - 1) * (base2 - 1));
}
/**

View File

@ -92,7 +92,7 @@ public class MutualInfoCalculatorMultiVariateWithDiscreteSymbolic implements
// Make the base the maximum of the number of combinations of orderings of the
// continuous variables and the discrete base.
miCalc = new MutualInformationCalculatorDiscrete(baseToUse,0);
miCalc = new MutualInformationCalculatorDiscrete(baseToUse);
miCalc.initialise();
}

View File

@ -23,7 +23,7 @@ import junit.framework.TestCase;
public class MutualInformationTester extends TestCase {
public void testFullyDependent() throws Exception {
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2, 0);
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2);
// X2 is a copy of X1 - MI should be 1 bit
miCalc.initialise();
@ -41,7 +41,7 @@ public class MutualInformationTester extends TestCase {
}
public void testIndependent() throws Exception {
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2, 0);
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2);
// X2 is unrelated to X1 - MI should be 0 bits
miCalc.initialise();
@ -51,10 +51,9 @@ public class MutualInformationTester extends TestCase {
}
public void testAnd() throws Exception {
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2, 0);
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2);
int[] X1 = new int[] {0, 0, 1, 1};
int[] X2 = new int[] {0, 1, 0, 1};
int[] Y = new int[] {0, 0, 0, 1};
// Y is dependent on X1 - MI should be 0.311 bits
@ -66,7 +65,7 @@ public class MutualInformationTester extends TestCase {
}
public void testXor() throws Exception {
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2, 0);
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2);
int[] X1 = new int[] {0, 0, 1, 1};
int[] X2 = new int[] {0, 1, 0, 1};
@ -85,7 +84,7 @@ public class MutualInformationTester extends TestCase {
assertEquals(0.0, miX2Y, 0.000001);
// Y is fully determined from X1, X2 - MI should be 1 bits
MutualInformationCalculatorDiscrete miCalcBase4 = new MutualInformationCalculatorDiscrete(4, 0);
MutualInformationCalculatorDiscrete miCalcBase4 = new MutualInformationCalculatorDiscrete(4);
int[] X12 = new int[] {0, 1, 2, 3};
miCalcBase4.initialise();
miCalcBase4.addObservations(X12, Y);
@ -95,7 +94,7 @@ public class MutualInformationTester extends TestCase {
}
public void test3Xor() throws Exception {
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2, 0);
MutualInformationCalculatorDiscrete miCalc = new MutualInformationCalculatorDiscrete(2);
int[] X1 = new int[] {0, 1, 0, 1, 0, 1, 0, 1};
int[] X2 = new int[] {0, 0, 1, 1, 0, 0, 1, 1};
@ -121,7 +120,7 @@ public class MutualInformationTester extends TestCase {
assertEquals(0.0, miX3Y, 0.000001);
// Y is independent of X1, X2 - MI should be 0 bits
MutualInformationCalculatorDiscrete miCalcBase4 = new MutualInformationCalculatorDiscrete(4, 0);
MutualInformationCalculatorDiscrete miCalcBase4 = new MutualInformationCalculatorDiscrete(4);
int[] X12 = new int[] {0, 1, 2, 3, 0, 1, 2, 3};
miCalcBase4.initialise();
miCalcBase4.addObservations(X12, Y);
@ -130,7 +129,7 @@ public class MutualInformationTester extends TestCase {
assertEquals(0.0, miX12Y, 0.000001);
// Y is fully determined from X1, X2, X3 - MI should be 1 bits
MutualInformationCalculatorDiscrete miCalcBase8 = new MutualInformationCalculatorDiscrete(8, 0);
MutualInformationCalculatorDiscrete miCalcBase8 = new MutualInformationCalculatorDiscrete(8);
int[] X123 = new int[] {0, 1, 2, 3, 4, 5, 6, 7};
miCalcBase8.initialise();
miCalcBase8.addObservations(X123, Y);