Made Kraskov Conditional MI calculator implement the common Conditional MI interface

This commit is contained in:
joseph.lizier 2013-01-14 12:51:34 +00:00
parent a06efba6e0
commit f466012556
3 changed files with 143 additions and 52 deletions

View File

@ -1,5 +1,6 @@
package infodynamics.measures.continuous.kraskov; package infodynamics.measures.continuous.kraskov;
import infodynamics.measures.continuous.ConditionalMutualInfoCalculatorMultiVariate;
import infodynamics.utils.EuclideanUtils; import infodynamics.utils.EuclideanUtils;
import infodynamics.utils.MatrixUtils; import infodynamics.utils.MatrixUtils;
import infodynamics.utils.EmpiricalMeasurementDistribution; import infodynamics.utils.EmpiricalMeasurementDistribution;
@ -20,7 +21,8 @@ import infodynamics.utils.RandomGenerator;
* *
* @author Joseph Lizier * @author Joseph Lizier
*/ */
public abstract class ConditionalMutualInfoCalculatorMultiVariateKraskov { public abstract class ConditionalMutualInfoCalculatorMultiVariateKraskov
implements ConditionalMutualInfoCalculatorMultiVariate {
/** /**
* we compute distances to the kth neighbour in the joint space * we compute distances to the kth neighbour in the joint space
@ -185,44 +187,57 @@ public abstract class ConditionalMutualInfoCalculatorMultiVariateKraskov {
* The user should ensure that all values 0..N-1 are represented exactly once in the * The user should ensure that all values 0..N-1 are represented exactly once in the
* array reordering and that no other values are included here. * array reordering and that no other values are included here.
* *
* @param variableToReorder 1 for variable 1, 2 for variable 2
* @param reordering * @param reordering
* @return * @return
* @throws Exception * @throws Exception
*/ */
public abstract double computeAverageLocalOfObservations(int[] reordering) throws Exception; public abstract double computeAverageLocalOfObservations(int variableToReorder,
int[] reordering) throws Exception;
/** /**
* Compute the significance of the mutual information of the previously supplied observations. * Compute the significance of the mutual information of the previously supplied observations.
* We destroy the p(x,y,z) correlations, while retaining the p(x,z), p(y) marginals, to check how * We destroy the p(x,y,z) correlations, by permuting the given variable,
* while retaining the joint distribution of the other variable
* and the conditional, and the marginal distribution of the
* permuted variable. This checks how
* significant this conditional mutual information actually was. * significant this conditional mutual information actually was.
* *
* This is in the spirit of Chavez et al., "Statistical assessment of nonlinear causality: * This is in the spirit of Chavez et al., "Statistical assessment of nonlinear causality:
* application to epileptic EEG signals", Journal of Neuroscience Methods 124 (2003) 113-128 * application to epileptic EEG signals", Journal of Neuroscience Methods 124 (2003) 113-128
* which was performed for Transfer entropy. * which was performed for Transfer entropy.
* *
* @param variableToReorder 1 for variable 1, 2 for variable 2
* @param numPermutationsToCheck * @param numPermutationsToCheck
* @return the proportion of MI scores from the distribution which have higher or equal MIs to ours. * @return the proportion of MI scores from the distribution which have higher or equal MIs to ours.
*/ */
public synchronized EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck) throws Exception { public synchronized EmpiricalMeasurementDistribution computeSignificance(
int variableToReorder, int numPermutationsToCheck) throws Exception {
// Generate the re-ordered indices: // Generate the re-ordered indices:
RandomGenerator rg = new RandomGenerator(); RandomGenerator rg = new RandomGenerator();
int[][] newOrderings = rg.generateDistinctRandomPerturbations(data1.length, numPermutationsToCheck); int[][] newOrderings = rg.generateDistinctRandomPerturbations(data1.length, numPermutationsToCheck);
return computeSignificance(newOrderings); return computeSignificance(variableToReorder, newOrderings);
} }
/** /**
* Compute the significance of the mutual information of the previously supplied observations. * Compute the significance of the mutual information of the previously supplied observations.
* We destroy the p(x,y,z) correlations, while retaining the p(x,z), p(y) marginals, to check how * We destroy the p(x,y,z) correlations, by permuting the given variable,
* significant this mutual information actually was. * while retaining the joint distribution of the other variable
* and the conditional, and the marginal distribution of the
* permuted variable. This checks how
* significant this conditional mutual information actually was.
* *
* This is in the spirit of Chavez et al., "Statistical assessment of nonlinear causality: * This is in the spirit of Chavez et al., "Statistical assessment of nonlinear causality:
* application to epileptic EEG signals", Journal of Neuroscience Methods 124 (2003) 113-128 * application to epileptic EEG signals", Journal of Neuroscience Methods 124 (2003) 113-128
* which was performed for Transfer entropy. * which was performed for Transfer entropy.
* *
* @param variableToReorder 1 for variable 1, 2 for variable 2
* @param newOrderings the specific new orderings to use * @param newOrderings the specific new orderings to use
* @return the proportion of conditional MI scores from the distribution which have higher or equal MIs to ours. * @return the proportion of conditional MI scores from the distribution which have higher or equal MIs to ours.
*/ */
public EmpiricalMeasurementDistribution computeSignificance(int[][] newOrderings) throws Exception { public EmpiricalMeasurementDistribution computeSignificance(
int variableToReorder, int[][] newOrderings) throws Exception {
int numPermutationsToCheck = newOrderings.length; int numPermutationsToCheck = newOrderings.length;
if (!condMiComputed) { if (!condMiComputed) {
computeAverageLocalOfObservations(); computeAverageLocalOfObservations();
@ -235,7 +250,8 @@ public abstract class ConditionalMutualInfoCalculatorMultiVariateKraskov {
int countWhereMiIsMoreSignificantThanOriginal = 0; int countWhereMiIsMoreSignificantThanOriginal = 0;
for (int i = 0; i < numPermutationsToCheck; i++) { for (int i = 0; i < numPermutationsToCheck; i++) {
// Compute the MI under this reordering // Compute the MI under this reordering
double newMI = computeAverageLocalOfObservations(newOrderings[i]); double newMI = computeAverageLocalOfObservations(
variableToReorder, newOrderings[i]);
measDistribution.distribution[i] = newMI; measDistribution.distribution[i] = newMI;
if (debug){ if (debug){
System.out.println("New MI was " + newMI); System.out.println("New MI was " + newMI);
@ -256,7 +272,8 @@ public abstract class ConditionalMutualInfoCalculatorMultiVariateKraskov {
public abstract double[] computeLocalOfPreviousObservations() throws Exception; public abstract double[] computeLocalOfPreviousObservations() throws Exception;
public double[] computeLocalUsingPreviousObservations(double[][] states1, double[][] states2) throws Exception { public double[] computeLocalUsingPreviousObservations(double[][] states1,
double[][] states2, double[][] condStates) throws Exception {
// If implemented, will need to incorporate any normalisation here // If implemented, will need to incorporate any normalisation here
// (normalising the incoming data the same way the previously // (normalising the incoming data the same way the previously
// supplied observations were normalised). // supplied observations were normalised).

View File

@ -28,32 +28,49 @@ public class ConditionalMutualInfoCalculatorMultiVariateKraskov1
* Compute the average conditional MI from the previously set observations * Compute the average conditional MI from the previously set observations
*/ */
public double computeAverageLocalOfObservations() throws Exception { public double computeAverageLocalOfObservations() throws Exception {
return computeAverageLocalOfObservations(null); return computeAverageLocalOfObservations(1, null);
} }
/** /**
* Compute what the average conditional MI would look like were the second time series reordered * Compute what the average conditional MI would look like were the given
* time series reordered
* as per the array of time indices in reordering. * as per the array of time indices in reordering.
* The user should ensure that all values 0..N-1 are represented exactly once in the * The user should ensure that all values 0..N-1 are represented exactly once in the
* array reordering and that no other values are included here. * array reordering and that no other values are included here.
*
* If reordering is null, it is assumed there is no reordering of * If reordering is null, it is assumed there is no reordering of
* the y variable. * the given variable.
* *
* @param reordering the reordered time steps of the y variable * @param variableToReorder 1 for variable 1, 2 for variable 2
* @param reordering the reordered time steps of the given variable
* @return * @return
* @throws Exception * @throws Exception
*/ */
public double computeAverageLocalOfObservations(int[] reordering) throws Exception { public double computeAverageLocalOfObservations(int variableToReorder,
int[] reordering) throws Exception {
if (!tryKeepAllPairsNorms || (data1.length > MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM)) { if (!tryKeepAllPairsNorms || (data1.length > MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM)) {
double[][] originalData2 = data2; double[][] originalData;
if (variableToReorder == 1) {
originalData = data1;
} else {
originalData = data2;
}
if (reordering != null) { if (reordering != null) {
// Generate a new re-ordered data2 // Generate a new re-ordered data array
data2 = MatrixUtils.extractSelectedTimePointsReusingArrays(originalData2, reordering); if (variableToReorder == 1) {
data1 = MatrixUtils.extractSelectedTimePointsReusingArrays(originalData, reordering);
} else {
data2 = MatrixUtils.extractSelectedTimePointsReusingArrays(originalData, reordering);
}
} }
// Compute the MI // Compute the MI
double newMI = computeAverageLocalOfObservationsWhileComputingDistances(); double newMI = computeAverageLocalOfObservationsWhileComputingDistances();
// restore data2 // restore original data
data2 = originalData2; if (variableToReorder == 1) {
data1 = originalData;
} else {
data2 = originalData;
}
return newMI; return newMI;
} }
@ -76,13 +93,21 @@ public class ConditionalMutualInfoCalculatorMultiVariateKraskov1
// using x, y and z norms to all neighbours // using x, y and z norms to all neighbours
// (note that norm of point t to itself will be set to infinity). // (note that norm of point t to itself will be set to infinity).
int tForY = (reordering == null) ? t : reordering[t]; int tForReorderedVar = (reordering == null) ? t : reordering[t];
double[] jointNorm = new double[N]; double[] jointNorm = new double[N];
for (int t2 = 0; t2 < N; t2++) { for (int t2 = 0; t2 < N; t2++) {
int t2ForY = (reordering == null) ? t2 : reordering[t2]; int t2ForReorderedVar = (reordering == null) ? t2 : reordering[t2];
// Joint norm is the max of all three marginals // Joint norm is the max of all three marginals
jointNorm[t2] = Math.max(xNorms[t][t2], Math.max(yNorms[tForY][t2ForY], zNorms[t][t2])); if (variableToReorder == 1) {
jointNorm[t2] = Math.max(xNorms[tForReorderedVar][t2ForReorderedVar],
Math.max(yNorms[t][t2],
zNorms[t][t2]));
} else {
jointNorm[t2] = Math.max(xNorms[t][t2],
Math.max(yNorms[tForReorderedVar][t2ForReorderedVar],
zNorms[t][t2]));
}
} }
// Then find the kth closest neighbour, using a heuristic to // Then find the kth closest neighbour, using a heuristic to
// select whether to keep the k mins only or to do a sort. // select whether to keep the k mins only or to do a sort.
@ -107,12 +132,21 @@ public class ConditionalMutualInfoCalculatorMultiVariateKraskov1
for (int t2 = 0; t2 < N; t2++) { for (int t2 = 0; t2 < N; t2++) {
if (zNorms[t][t2] < epsilon) { if (zNorms[t][t2] < epsilon) {
n_z++; n_z++;
if (xNorms[t][t2] < epsilon) { int t2ForReorderedVar = (reordering == null) ? t2 : reordering[t2];
n_xz++; if (variableToReorder == 1) {
} if (xNorms[tForReorderedVar][t2ForReorderedVar] < epsilon) {
int t2ForY = (reordering == null) ? t2 : reordering[t2]; n_xz++;
if (yNorms[tForY][t2ForY] < epsilon) { }
n_yz++; if (yNorms[t][t2] < epsilon) {
n_yz++;
}
} else {
if (xNorms[t][t2] < epsilon) {
n_xz++;
}
if (yNorms[tForReorderedVar][t2ForReorderedVar] < epsilon) {
n_yz++;
}
} }
} }
} }

View File

@ -38,24 +38,40 @@ public class ConditionalMutualInfoCalculatorMultiVariateKraskov2
protected static final double CUTOFF_MULTIPLIER = 1.5; protected static final double CUTOFF_MULTIPLIER = 1.5;
/** /**
* Compute what the average conditional MI would look like were the second time series reordered * Compute what the average conditional MI would look like were the given
* time series reordered
* as per the array of time indices in reordering. * as per the array of time indices in reordering.
* The user should ensure that all values 0..N-1 are represented exactly once in the * The user should ensure that all values 0..N-1 are represented exactly once in the
* array reordering and that no other values are included here. * array reordering and that no other values are included here.
* *
* @param reordering * @param variableToReorder 1 for variable 1, 2 for variable 2
* @param reordering the reordered time steps of the given variable
* @return * @return
* @throws Exception * @throws Exception
*/ */
public double computeAverageLocalOfObservations(int[] reordering) throws Exception { public double computeAverageLocalOfObservations(int variableToReorder,
int[] reordering) throws Exception {
if (!tryKeepAllPairsNorms || (data1.length > MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM)) { if (!tryKeepAllPairsNorms || (data1.length > MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM)) {
double[][] originalData2 = data2; double[][] originalData;
// Generate a new re-ordered data2 if (variableToReorder == 1) {
data2 = MatrixUtils.extractSelectedTimePointsReusingArrays(originalData2, reordering); originalData = data1;
} else {
originalData = data2;
}
// Generate a new re-ordered data array
if (variableToReorder == 1) {
data1 = MatrixUtils.extractSelectedTimePointsReusingArrays(originalData, reordering);
} else {
data2 = MatrixUtils.extractSelectedTimePointsReusingArrays(originalData, reordering);
}
// Compute the MI // Compute the MI
double newMI = computeAverageLocalOfObservationsWhileComputingDistances(); double newMI = computeAverageLocalOfObservationsWhileComputingDistances();
// restore data2 // restore original data
data2 = originalData2; if (variableToReorder == 1) {
data1 = originalData;
} else {
data2 = originalData;
}
return newMI; return newMI;
} }
@ -80,13 +96,19 @@ public class ConditionalMutualInfoCalculatorMultiVariateKraskov2
// First get x and y and z norms to all neighbours // First get x and y and z norms to all neighbours
// (note that norm of point t to itself will be set to infinity). // (note that norm of point t to itself will be set to infinity).
int tForY = reordering[t]; int tForReorderedVar = reordering[t];
double[][] jointNorm = new double[N][2]; double[][] jointNorm = new double[N][2];
for (int t2 = 0; t2 < N; t2++) { for (int t2 = 0; t2 < N; t2++) {
int t2ForY = reordering[t2]; int t2ForReorderedVar = reordering[t2];
jointNorm[t2][JOINT_NORM_VAL_COLUMN] = Math.max(xNorms[t][t2], if (variableToReorder == 1) {
Math.max(yNorms[tForY][t2ForY], zNorms[t][t2])); jointNorm[t2][JOINT_NORM_VAL_COLUMN] = Math.max(
xNorms[tForReorderedVar][t2ForReorderedVar],
Math.max(yNorms[t][t2], zNorms[t][t2]));
} else {
jointNorm[t2][JOINT_NORM_VAL_COLUMN] = Math.max(xNorms[t][t2],
Math.max(yNorms[tForReorderedVar][t2ForReorderedVar], zNorms[t][t2]));
}
// And store the time step for back reference after the // And store the time step for back reference after the
// array is sorted. // array is sorted.
jointNorm[t2][JOINT_NORM_TIMESTEP_COLUMN] = t2; jointNorm[t2][JOINT_NORM_TIMESTEP_COLUMN] = t2;
@ -112,11 +134,20 @@ public class ConditionalMutualInfoCalculatorMultiVariateKraskov2
// Find eps_{x,y,z} as the maximum x and y and z norms amongst this set: // Find eps_{x,y,z} as the maximum x and y and z norms amongst this set:
for (int j = 0; j < k; j++) { for (int j = 0; j < k; j++) {
int timeStepOfJthPoint = timeStepsOfKthMins[j]; int timeStepOfJthPoint = timeStepsOfKthMins[j];
if (xNorms[t][timeStepOfJthPoint] > eps_x) { if (variableToReorder == 1) {
eps_x = xNorms[t][timeStepOfJthPoint]; if (xNorms[tForReorderedVar][reordering[timeStepOfJthPoint]] > eps_x) {
} eps_x = xNorms[tForReorderedVar][reordering[timeStepOfJthPoint]];
if (yNorms[tForY][reordering[timeStepOfJthPoint]] > eps_y) { }
eps_y = yNorms[tForY][reordering[timeStepOfJthPoint]]; if (yNorms[t][timeStepOfJthPoint] > eps_y) {
eps_y = yNorms[t][timeStepOfJthPoint];
}
} else {
if (xNorms[t][timeStepOfJthPoint] > eps_x) {
eps_x = xNorms[t][timeStepOfJthPoint];
}
if (yNorms[tForReorderedVar][reordering[timeStepOfJthPoint]] > eps_y) {
eps_y = yNorms[tForReorderedVar][reordering[timeStepOfJthPoint]];
}
} }
if (zNorms[t][timeStepOfJthPoint] > eps_z) { if (zNorms[t][timeStepOfJthPoint] > eps_z) {
eps_z = zNorms[t][timeStepOfJthPoint]; eps_z = zNorms[t][timeStepOfJthPoint];
@ -133,11 +164,20 @@ public class ConditionalMutualInfoCalculatorMultiVariateKraskov2
for (int t2 = 0; t2 < N; t2++) { for (int t2 = 0; t2 < N; t2++) {
if (zNorms[t][t2] <= eps_z) { if (zNorms[t][t2] <= eps_z) {
n_z++; n_z++;
if (xNorms[t][t2] <= eps_x) { if (variableToReorder == 1) {
n_xz++; if (xNorms[tForReorderedVar][reordering[t2]] <= eps_x) {
} n_xz++;
if (yNorms[tForY][reordering[t2]] <= eps_y) { }
n_yz++; if (yNorms[t][t2] <= eps_y) {
n_yz++;
}
} else {
if (xNorms[t][t2] <= eps_x) {
n_xz++;
}
if (yNorms[tForReorderedVar][reordering[t2]] <= eps_y) {
n_yz++;
}
} }
} }
} }