mirror of https://github.com/jlizier/jidt
Merge pull request #86 from pmediano/master
Thanks Pedro, and sorry to take so long to attend to this
This commit is contained in:
commit
0e33393cdc
|
@ -0,0 +1,575 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous;
|
||||
|
||||
import infodynamics.utils.EmpiricalMeasurementDistribution;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import java.util.Vector;
|
||||
|
||||
/**
|
||||
* Implements a base class with common functionality for child class
|
||||
* implementations of multivariate information measures via various estimators.
|
||||
*
|
||||
* <p>Multivariate information measures are functionals of probability
|
||||
* distributions over <code>R^n</code>, and typical examples include multi-information
|
||||
* (a.k.a. total correlation), dual total correlation, O-information, and connected
|
||||
* information.</p>
|
||||
*
|
||||
* <p>These measures can be computed via different kinds of estimators, such as
|
||||
* linear-gaussian, KSG estimators, etc (see the child classes linked above).
|
||||
* </p>
|
||||
*
|
||||
* <p>
|
||||
* Usage of the child classes is intended to follow this paradigm:
|
||||
* </p>
|
||||
* <ol>
|
||||
* <li>Construct the calculator;</li>
|
||||
* <li>Set properties using {@link #setProperty(String, String)};</li>
|
||||
* <li>Initialise the calculator using {@link #initialise(int)};</li>
|
||||
* <li>Provide the observations/samples for the calculator
|
||||
* to set up the PDFs, using:
|
||||
* <ul>
|
||||
* <li>{@link #setObservations(double[][])}
|
||||
* for calculations based on single time-series, OR</li>
|
||||
* <li>The following sequence:<ol>
|
||||
* <li>{@link #startAddObservations()}, then</li>
|
||||
* <li>One or more calls to {@link #addObservations(double[][])} or
|
||||
* {@link #addObservation(double[])}, then</li>
|
||||
* <li>{@link #finaliseAddObservations()};</li>
|
||||
* </ol></li>
|
||||
* </ul>
|
||||
* <li>Compute the required quantities, being one or more of:
|
||||
* <ul>
|
||||
* <li>the average measure:
|
||||
* {@link #computeAverageLocalOfObservations()};</li>
|
||||
* <li>the local values for these samples:
|
||||
* {@link #computeLocalOfPreviousObservations()}</li>
|
||||
* <li>local values for a specific set of samples:
|
||||
* {@link #computeLocalUsingPreviousObservations(double[][])}</li>
|
||||
* </ul>
|
||||
* </li>
|
||||
* <li>
|
||||
* Return to step 2 or 3 to re-use the calculator on a new data set.
|
||||
* </li>
|
||||
* </ol>
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public abstract class MultiVariateInfoMeasureCalculatorCommon
|
||||
implements InfoMeasureCalculatorContinuous {
|
||||
|
||||
/**
|
||||
* Number of joint variables to consider
|
||||
*/
|
||||
protected int dimensions = 1;
|
||||
/**
|
||||
* Number of samples supplied
|
||||
*/
|
||||
protected int totalObservations = 0;
|
||||
/**
|
||||
* Whether we are in debug mode
|
||||
*/
|
||||
protected boolean debug = false;
|
||||
/**
|
||||
* Cached supplied observations
|
||||
*/
|
||||
protected double[][] observations;
|
||||
/**
|
||||
* Set of individually supplied observations
|
||||
*/
|
||||
protected Vector<double[]> individualObservations;
|
||||
/**
|
||||
* Whether the user has supplied more than one (disjoint) set of samples
|
||||
*/
|
||||
protected boolean addedMoreThanOneObservationSet;
|
||||
/**
|
||||
* Whether the measure has been computed for the latest supplied data
|
||||
*/
|
||||
protected boolean isComputed = false;
|
||||
/**
|
||||
* Cached last the measure value calculated
|
||||
*/
|
||||
protected double lastAverage;
|
||||
|
||||
/**
|
||||
* Whether to normalise incoming values
|
||||
*/
|
||||
protected boolean normalise = true;
|
||||
|
||||
/**
|
||||
* Property name for whether to normalise incoming values to mean 0,
|
||||
* standard deviation 1 (default true)
|
||||
*/
|
||||
public static final String PROP_NORMALISE = "NORMALISE";
|
||||
|
||||
|
||||
/**
|
||||
* Initialise the calculator for (re-)use, with the existing
|
||||
* (or default) values of parameters, with number of
|
||||
* joint variables specified.
|
||||
* Clears any PDFs of previously supplied observations.
|
||||
*
|
||||
* @param dimensions the number of joint variables to consider
|
||||
*/
|
||||
public void initialise() {
|
||||
initialise(dimensions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialise the calculator for (re-)use, with the existing
|
||||
* (or default) values of parameters, with number of
|
||||
* joint variables specified.
|
||||
* Clears an PDFs of previously supplied observations.
|
||||
*
|
||||
* @param dimensions the number of joint variables to consider
|
||||
*/
|
||||
public void initialise(int dimensions) {
|
||||
this.dimensions = dimensions;
|
||||
lastAverage = 0.0;
|
||||
totalObservations = 0;
|
||||
isComputed = false;
|
||||
observations = null;
|
||||
addedMoreThanOneObservationSet = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set properties for the calculator.
|
||||
* New property values are not guaranteed to take effect until the next call
|
||||
* to an initialise method.
|
||||
*
|
||||
* <p>Valid property names, and what their
|
||||
* values should represent, include:</p>
|
||||
* <ul>
|
||||
* <li>{@link #PROP_NORMALISE} -- whether to normalise the incoming variables
|
||||
* to mean 0, standard deviation 1, or not (default false).</li>
|
||||
* </ul>
|
||||
*
|
||||
* <p>Unknown property values are ignored.</p>
|
||||
*
|
||||
* @param propertyName name of the property
|
||||
* @param propertyValue value of the property
|
||||
* @throws Exception for invalid property values
|
||||
*/
|
||||
public void setProperty(String propertyName, String propertyValue)
|
||||
throws Exception {
|
||||
boolean propertySet = true;
|
||||
if (propertyName.equalsIgnoreCase(PROP_NORMALISE)) {
|
||||
normalise = Boolean.parseBoolean(propertyValue);
|
||||
} else {
|
||||
// No property was set here
|
||||
propertySet = false;
|
||||
}
|
||||
if (debug && propertySet) {
|
||||
System.out.println(this.getClass().getSimpleName() + ": Set property " + propertyName +
|
||||
" to " + propertyValue);
|
||||
}
|
||||
}
|
||||
|
||||
public String getProperty(String propertyName) throws Exception {
|
||||
if (propertyName.equalsIgnoreCase(PROP_NORMALISE)) {
|
||||
return Boolean.toString(normalise);
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets a single series from which to compute the PDF for the instantiated measure..
|
||||
* Cannot be called in conjunction with other methods for setting/adding
|
||||
* observations.
|
||||
*
|
||||
* <p>The supplied series may be a time-series, or may be simply
|
||||
* a set of separate observations
|
||||
* without a time interpretation.</p>
|
||||
*
|
||||
* <p>Should only be called once, the last call contains the
|
||||
* observations that are used (they are not accumulated).</p>
|
||||
*
|
||||
* @param observations series of multivariate observations
|
||||
* (first index is time or observation index, second is variable number)
|
||||
* @throws Exception
|
||||
*/
|
||||
public void setObservations(double[][] observations) throws Exception {
|
||||
startAddObservations();
|
||||
addObservations(observations);
|
||||
finaliseAddObservations();
|
||||
addedMoreThanOneObservationSet = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Signal that we will add in the samples for computing the PDF
|
||||
* from several disjoint time-series or trials via calls to
|
||||
* {@link #addObservation(double[])} or {@link #addObservations(double[][])}
|
||||
* rather than {@link #setDebug(boolean)}.
|
||||
*/
|
||||
public void startAddObservations() {
|
||||
individualObservations = new Vector<double[]>();
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>Adds a new (single) observation to update the PDFs with - is
|
||||
* intended to be called multiple times.
|
||||
* Must be called after {@link #startAddObservations()}; call
|
||||
* {@link #finaliseAddObservations()} once all observations have
|
||||
* been supplied.</p>
|
||||
*
|
||||
* <p>Note that the arrays must not be over-written by the user
|
||||
* until after finaliseAddObservations() has been called
|
||||
* (they are not copied by this method necessarily, but the method
|
||||
* may simply hold a pointer to them).</p>
|
||||
*
|
||||
* @param observation a single multivariate observation
|
||||
* (index is variable number)
|
||||
*/
|
||||
public void addObservation(double[] observation) {
|
||||
if (individualObservations.size() > 0) {
|
||||
addedMoreThanOneObservationSet = true;
|
||||
}
|
||||
individualObservations.add(observation);
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>Adds a new set of observations to update the PDFs with - is
|
||||
* intended to be called multiple times.
|
||||
* Must be called after {@link #startAddObservations()}; call
|
||||
* {@link #finaliseAddObservations()} once all observations have
|
||||
* been supplied.</p>
|
||||
*
|
||||
* <p>Note that the arrays must not be over-written by the user
|
||||
* until after finaliseAddObservations() has been called
|
||||
* (they are not copied by this method necessarily, but the method
|
||||
* may simply hold a pointer to them).</p>
|
||||
*
|
||||
* @param observations series of multivariate observations
|
||||
* (first index is time or observation index, second is variable number)
|
||||
*/
|
||||
public void addObservations(double[][] observations) {
|
||||
// This implementation is not particularly efficient,
|
||||
// however it will suffice for now.
|
||||
for (int s = 0; s < observations.length; s++) {
|
||||
addObservation(observations[s]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* This class provides a basic implementation, generating
|
||||
* the internal set of samples in observations; child classes
|
||||
* should then process these observations as required.
|
||||
*
|
||||
*/
|
||||
public void finaliseAddObservations() throws Exception {
|
||||
observations = new double[individualObservations.size()][];
|
||||
for (int t = 0; t < observations.length; t++) {
|
||||
observations[t] = individualObservations.elementAt(t);
|
||||
}
|
||||
// Allow vector to be reclaimed
|
||||
individualObservations = null;
|
||||
|
||||
if (observations[0].length != dimensions) {
|
||||
throw new Exception("Incorrect number of dimensions " + observations[0].length +
|
||||
" in supplied observations (expected " + dimensions + ")");
|
||||
}
|
||||
totalObservations = observations.length;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a resampled distribution of what the measure would look like,
|
||||
* under a null hypothesis that the individual values of each
|
||||
* variable in the
|
||||
* samples have no relation to each other.
|
||||
* That is, we destroy the p(x,y,z,..) correlations, while
|
||||
* retaining the p(x), p(y),.. marginals, to check how
|
||||
* significant this measure actually was.
|
||||
*
|
||||
* <p>See Section II.E "Statistical significance testing" of
|
||||
* the JIDT paper below for a description of how this is done for MI,
|
||||
* we are extending that here.
|
||||
* </p>
|
||||
*
|
||||
* <p>Note that if several disjoint time-series have been added
|
||||
* as observations using {@link #addObservations(double[])} etc.,
|
||||
* then these separate "trials" will be mixed up in the generation
|
||||
* of surrogates here.</p>
|
||||
*
|
||||
* <p>This method (in contrast to {@link #computeSignificance(int[][][])})
|
||||
* creates <i>random</i> shufflings of the next values for the surrogate MultiInfo
|
||||
* calculations.</p>
|
||||
*
|
||||
* @param numPermutationsToCheck number of surrogate samples to permute
|
||||
* to generate the distribution.
|
||||
* @return the distribution of surrogate measure values under this null hypothesis.
|
||||
* @see "J.T. Lizier, 'JIDT: An information-theoretic
|
||||
* toolkit for studying the dynamics of complex systems', 2014."
|
||||
* @throws Exception
|
||||
*/
|
||||
public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck) throws Exception {
|
||||
// Generate the re-ordered indices:
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
int[][][] newOrderings = new int[numPermutationsToCheck][][];
|
||||
// Generate numPermutationsToCheck * (dimensions-1) permutations of 0 .. data.length-1
|
||||
for (int n = 0; n < numPermutationsToCheck; n++) {
|
||||
// (Not necessary to check for distinct random perturbations)
|
||||
newOrderings[n] = rg.generateRandomPerturbations(totalObservations, dimensions-1);
|
||||
}
|
||||
return computeSignificance(newOrderings);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a resampled distribution of what the measure would look like,
|
||||
* under a null hypothesis that the individual values of each
|
||||
* variable in the
|
||||
* samples have no relation to eachother.
|
||||
* That is, we destroy the p(x,y,z,..) correlations, while
|
||||
* retaining the p(x), p(y),.. marginals, to check how
|
||||
* significant this measure actually was.
|
||||
*
|
||||
* <p>See Section II.E "Statistical significance testing" of
|
||||
* the JIDT paper below for a description of how this is done for MI,
|
||||
* we are extending that here.
|
||||
* </p>
|
||||
*
|
||||
* <p>Note that if several disjoint time-series have been added
|
||||
* as observations using {@link #addObservations(double[])} etc.,
|
||||
* then these separate "trials" will be mixed up in the generation
|
||||
* of surrogates here.</p>
|
||||
*
|
||||
* <p>This method (in contrast to {@link #computeSignificance(int)})
|
||||
* allows the user to specify how to construct the surrogates,
|
||||
* such that repeatable results may be obtained.</p>
|
||||
*
|
||||
* @param newOrderings a specification of how to shuffle the values
|
||||
* to create the surrogates to generate the distribution with. The first
|
||||
* index is the permutation number (i.e. newOrderings.length is the number
|
||||
* of surrogate samples we use to bootstrap to generate the distribution here.)
|
||||
* The second index is the variable number (minus 1, since we don't reorder
|
||||
* the first variable),
|
||||
* Each array newOrderings[i][v] should be an array of length N (where
|
||||
* would be the value returned by {@link #getNumObservations()}),
|
||||
* containing a permutation of the values in 0..(N-1).
|
||||
* @return the distribution of surrogate measure values under this null hypothesis.
|
||||
* @see "J.T. Lizier, 'JIDT: An information-theoretic
|
||||
* toolkit for studying the dynamics of complex systems', 2014."
|
||||
* @throws Exception where the length of each permutation in newOrderings
|
||||
* is not equal to the number N samples that were previously supplied.
|
||||
*/
|
||||
public EmpiricalMeasurementDistribution computeSignificance(int[][][] newOrderings) throws Exception {
|
||||
|
||||
int numPermutationsToCheck = newOrderings.length;
|
||||
if (!isComputed) {
|
||||
computeAverageLocalOfObservations();
|
||||
}
|
||||
|
||||
// Store the real observations and their measure value:
|
||||
double actualMeasure = lastAverage;
|
||||
|
||||
EmpiricalMeasurementDistribution measDistribution = new EmpiricalMeasurementDistribution(numPermutationsToCheck);
|
||||
|
||||
int countWhereSurrogateIsMoreSignificantThanOriginal = 0;
|
||||
for (int i = 0; i < numPermutationsToCheck; i++) {
|
||||
// Compute the measure under this reordering
|
||||
double newMeasure = computeAverageLocalOfObservations(newOrderings[i]);
|
||||
measDistribution.distribution[i] = newMeasure;
|
||||
if (debug){
|
||||
System.out.println("New measure value was " + newMeasure);
|
||||
}
|
||||
if (newMeasure >= actualMeasure) {
|
||||
countWhereSurrogateIsMoreSignificantThanOriginal++;
|
||||
}
|
||||
}
|
||||
|
||||
// Restore the actual measure and the observations
|
||||
lastAverage = actualMeasure;
|
||||
|
||||
// And return the significance
|
||||
measDistribution.pValue = (double) countWhereSurrogateIsMoreSignificantThanOriginal / (double) numPermutationsToCheck;
|
||||
measDistribution.actualValue = actualMeasure;
|
||||
return measDistribution;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute what the measure would look like were all time series (bar the first)
|
||||
* reordered as per the array of time indices in newOrdering.
|
||||
*
|
||||
* <p>The reordering array contains the reordering for each marginal variable
|
||||
* (first index). The user should ensure that all values 0..N-1 are
|
||||
* represented exactly once in the array reordering and that no other values
|
||||
* are included here.</p>
|
||||
*
|
||||
* <p>Note that if several disjoint time-series have been added as
|
||||
* observations using {@link #addObservations(double[])} etc., then these
|
||||
* separate "trials" will be mixed up in the generation of a shuffled source
|
||||
* series here.</p>
|
||||
*
|
||||
* <p>This method is primarily intended for use in {@link
|
||||
* #computeSignificance(int[][])} however has been made public in case users
|
||||
* wish to access it.</p>
|
||||
*
|
||||
* @param newOrdering the specific permuted new orderings to use. First index
|
||||
* is the variable number (minus 1, since we don't reorder the first
|
||||
* variable), second index is the time step, the value is the reordered time
|
||||
* step to use for that variable at the given time step. The values must be
|
||||
* an array of length N (where
|
||||
* would be the value returned by {@link #getNumObservations()}), containing
|
||||
* a permutation of the values in 0..(N-1). If null, no reordering is
|
||||
* performed.
|
||||
* @return what the average measure would look like under this reordering
|
||||
* @throws Exception
|
||||
*/
|
||||
public double computeAverageLocalOfObservations(int[][] newOrdering)
|
||||
throws Exception {
|
||||
|
||||
if (newOrdering == null) {
|
||||
return computeAverageLocalOfObservations();
|
||||
}
|
||||
|
||||
// Take a clone of the object to compute the measure of the surrogates:
|
||||
// (this is a shallow copy, it doesn't make new copies of all
|
||||
// the arrays)
|
||||
MultiVariateInfoMeasureCalculatorCommon surrogateCalculator =
|
||||
(MultiVariateInfoMeasureCalculatorCommon) this.clone();
|
||||
|
||||
// Generate a new re-ordered source data
|
||||
double[][] shuffledData =
|
||||
MatrixUtils.reorderDataForVariables(
|
||||
observations, newOrdering);
|
||||
// Perform new initialisations
|
||||
surrogateCalculator.initialise(dimensions);
|
||||
// Set new observations
|
||||
surrogateCalculator.setObservations(shuffledData);
|
||||
// Compute the MI
|
||||
return surrogateCalculator.computeAverageLocalOfObservations();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculates the local measure at every sample provided since the last time the
|
||||
* calculator was initialised.
|
||||
*
|
||||
* @return the "time-series" of local measure values in nats (not bits!)
|
||||
* @throws Exception
|
||||
*/
|
||||
public double[] computeLocalOfPreviousObservations() throws Exception {
|
||||
// Cannot do if observations haven't been set
|
||||
if (observations == null) {
|
||||
throw new Exception("Cannot compute local values of previous observations " +
|
||||
"if they have not been set!");
|
||||
}
|
||||
|
||||
return computeLocalUsingPreviousObservations(observations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the local measure values for each of the
|
||||
* supplied samples in <code>states</code>.
|
||||
*
|
||||
* <p>PDFs are computed using all of the previously supplied
|
||||
* observations, but not those in <code>states</code>
|
||||
* (unless they were
|
||||
* some of the previously supplied samples).</p>
|
||||
*
|
||||
* @param states series of multivariate observations
|
||||
* (first index is time or observation index, second is variable number)
|
||||
* @return the series of local measure values.
|
||||
* @throws Exception
|
||||
*/
|
||||
public abstract double[] computeLocalUsingPreviousObservations(double states[][])
|
||||
throws Exception;
|
||||
|
||||
/**
|
||||
* Shortcut method to initialise the calculator, set observations and compute
|
||||
* the average measure in one line.
|
||||
*
|
||||
* @param new_observations series of multivariate observations
|
||||
* (first index is time or observation index, second is variable number)
|
||||
*/
|
||||
public double compute(double[][] new_observations) throws Exception {
|
||||
initialise(new_observations[0].length);
|
||||
setObservations(new_observations);
|
||||
return computeAverageLocalOfObservations();
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut method to initialise the calculator, set observations and compute
|
||||
* the local measure in one line.
|
||||
*
|
||||
* @param new_observations series of multivariate observations
|
||||
* (first index is time or observation index, second is variable number)
|
||||
*/
|
||||
public double[] computeLocals(double[][] new_observations) throws Exception {
|
||||
initialise(new_observations[0].length);
|
||||
setObservations(new_observations);
|
||||
return computeLocalOfPreviousObservations();
|
||||
}
|
||||
|
||||
public int getNumObservations() throws Exception {
|
||||
return totalObservations;
|
||||
}
|
||||
|
||||
public void setDebug(boolean debug) {
|
||||
this.debug = debug;
|
||||
}
|
||||
|
||||
public double getLastAverage() {
|
||||
return lastAverage;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an <code>int[]</code> array with all integers from 0 to
|
||||
* <code>N-1</code>, except <code>idx</code>.
|
||||
*
|
||||
* <p>This method is primarily intended for internal use (to extract blocks
|
||||
* of covariance matrices excluding one variable).</p>
|
||||
*
|
||||
* @param idx index of integer to omit
|
||||
* @param N upper limit of the integer array
|
||||
*/
|
||||
protected int[] allExcept(int idx, int N) {
|
||||
boolean[] v = new boolean[N];
|
||||
Arrays.fill(v, true);
|
||||
v[idx] = false;
|
||||
|
||||
int[] v2 = new int[N - 1];
|
||||
int counter = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (v[i]) {
|
||||
v2[counter] = i;
|
||||
counter++;
|
||||
}
|
||||
}
|
||||
|
||||
return v2;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2017, Joseph T. Lizier, Ipek Oezdemir and Pedro Mediano
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
|
||||
* <p>Computes the differential dual total correlation (DTC) of a given multivariate
|
||||
* <code>double[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorGaussian}),
|
||||
* assuming that the probability distribution function for these observations is
|
||||
* a multivariate Gaussian distribution.</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorCommon}.
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class DualTotalCorrelationCalculatorGaussian
|
||||
extends MultiVariateInfoMeasureCalculatorGaussian {
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*/
|
||||
public DualTotalCorrelationCalculatorGaussian() {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the average DTC in nats (not bits!)
|
||||
* @throws Exception if not sufficient data have been provided, or if the
|
||||
* supplied covariance matrix is invalid.
|
||||
*/
|
||||
public double computeAverageLocalOfObservations() throws Exception {
|
||||
|
||||
if (covariance == null) {
|
||||
throw new Exception("Cannot calculate DTC without having " +
|
||||
"a covariance either supplied or computed via setObservations()");
|
||||
}
|
||||
|
||||
if (!isComputed) {
|
||||
double dtc = - (dimensions - 1)*Math.log(MatrixUtils.determinantSymmPosDefMatrix(covariance));
|
||||
for (int i = 0; i < dimensions; i++) {
|
||||
int[] idx = allExcept(i, dimensions);
|
||||
double[][] marginal_cov = MatrixUtils.selectRowsAndColumns(covariance, idx, idx);
|
||||
dtc += Math.log(MatrixUtils.determinantSymmPosDefMatrix(marginal_cov));
|
||||
}
|
||||
// This "0.5" comes from the entropy formula for Gaussians: h = 0.5*logdet(2*pi*e*Sigma)
|
||||
lastAverage = 0.5*dtc;;
|
||||
isComputed = true;
|
||||
}
|
||||
|
||||
return lastAverage;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the "time-series" of local DTC values in nats (not bits!)
|
||||
* for the supplied states.
|
||||
* @throws Exception if not sufficient data have been provided, or if the
|
||||
* supplied covariance matrix is invalid.
|
||||
*/
|
||||
public double[] computeLocalUsingPreviousObservations(double[][] states) throws Exception {
|
||||
|
||||
if ((means == null) || (covariance == null)) {
|
||||
throw new Exception("Cannot compute local values without having means " +
|
||||
"and covariance either supplied or computed via setObservations()");
|
||||
}
|
||||
|
||||
EntropyCalculatorMultiVariateGaussian hCalc = new EntropyCalculatorMultiVariateGaussian();
|
||||
hCalc.initialise(dimensions);
|
||||
hCalc.setCovarianceAndMeans(covariance, means);
|
||||
double[] localValues = MatrixUtils.multiply(hCalc.computeLocalUsingPreviousObservations(states), -(dimensions - 1));
|
||||
|
||||
for (int i = 0; i < dimensions; i++) {
|
||||
int[] idx = allExcept(i, dimensions);
|
||||
double[][] marginal_cov = MatrixUtils.selectRowsAndColumns(covariance, idx, idx);
|
||||
double[] marginal_means = MatrixUtils.select(means, idx);
|
||||
double[][] marginal_state = MatrixUtils.selectColumns(states, idx);
|
||||
|
||||
hCalc.initialise(dimensions - 1);
|
||||
hCalc.setCovarianceAndMeans(marginal_cov, marginal_means);
|
||||
double[] thisLocals = hCalc.computeLocalUsingPreviousObservations(marginal_state);
|
||||
|
||||
localValues = MatrixUtils.add(localValues, thisLocals);
|
||||
|
||||
}
|
||||
|
||||
return localValues;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2017, Joseph T. Lizier, Ipek Oezdemir and Pedro Mediano
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.measures.continuous.MultiVariateInfoMeasureCalculatorCommon;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
|
||||
* <p>Base class with common functionality for child class implementations of
|
||||
* multivariate information measures on a given multivariate
|
||||
* <code>double[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorCommon}),
|
||||
* assuming that the probability distribution function for these observations is
|
||||
* a multivariate Gaussian distribution.</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorCommon},
|
||||
* with:
|
||||
* <ul>
|
||||
* <li>For constructors see the child classes.</li>
|
||||
* <li>Further properties are defined in {@link #setProperty(String, String)}.</li>
|
||||
* <li>Computed values are in <b>nats</b>, not bits!</li>
|
||||
* </ul>
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public abstract class MultiVariateInfoMeasureCalculatorGaussian
|
||||
extends MultiVariateInfoMeasureCalculatorCommon {
|
||||
|
||||
/**
 * Covariance of the system. Can be calculated from supplied observations
 * or supplied directly by the user.
 */
double[][] covariance = null;

/**
 * Means of the system. Can be calculated from supplied observations
 * or supplied directly by the user.
 */
double[] means = null;

/**
 * Whether the current covariance matrix has been determined from data or
 * supplied directly. This changes the approach to local measures and
 * significance testing.
 */
// Set to true by finaliseAddObservations() (covariance computed from data),
//  false by the setCovariance*() family of methods (covariance supplied directly).
boolean covFromObservations;
|
||||
|
||||
@Override
|
||||
public void finaliseAddObservations() throws Exception {
|
||||
super.finaliseAddObservations();
|
||||
|
||||
this.means = MatrixUtils.means(observations);
|
||||
setCovariance(MatrixUtils.covarianceMatrix(observations), true);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>Set the covariance of the distribution for which we will compute the
|
||||
* measure directly, without supplying observations.</p>
|
||||
*
|
||||
* <p>See {@link #setCovariance(double[][], boolean)}.
|
||||
*
|
||||
* @param covariance covariance matrix of the system
|
||||
* @throws Exception for covariance matrix not matching the expected dimensions,
|
||||
* being non-square, asymmetric or non-positive definite
|
||||
*/
|
||||
public void setCovariance(double[][] covariance) throws Exception {
|
||||
setCovariance(covariance, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>Set the covariance of the distribution for which we will compute the
|
||||
* measure.</p>
|
||||
*
|
||||
* <p>This is an alternative to sequences of calls to {@link #setObservations(double[][])} or
|
||||
* {@link #addObservations(double[][])} etc.
|
||||
* Note that without setting any observations, you cannot later
|
||||
* call {@link #computeLocalOfPreviousObservations()}.</p>
|
||||
*
|
||||
* @param covariance covariance matrix of the system
|
||||
* @param means mean of the system
|
||||
*/
|
||||
public void setCovarianceAndMeans(double[][] covariance, double[] means)
|
||||
throws Exception {
|
||||
this.means = means;
|
||||
setCovariance(covariance, false);
|
||||
}
|
||||
|
||||
/**
|
||||
* <p>Set the covariance of the distribution for which we will compute the
|
||||
* measure.</p>
|
||||
*
|
||||
* <p>This is an alternative to sequences of calls to {@link #setObservations(double[][])} or
|
||||
* {@link #addObservations(double[][])} etc.
|
||||
* Note that without setting any observations, you cannot later
|
||||
* call {@link #computeLocalOfPreviousObservations()}, and without
|
||||
* providing the means of the variables, you cannot later call
|
||||
* {@link #computeLocalUsingPreviousObservations(double[][])}.</p>
|
||||
*
|
||||
* @param covariance covariance matrix of the system
|
||||
* @param covFromObservations whether the covariance matrix
|
||||
* was determined internally from observations or not
|
||||
* @throws Exception for covariance matrix not matching the expected dimensions,
|
||||
* being non-square, asymmetric or non-positive definite
|
||||
*/
|
||||
public void setCovariance(double[][] cov, boolean covFromObservations) throws Exception {
|
||||
|
||||
if (!covFromObservations) {
|
||||
// Make sure we're not keeping any observations
|
||||
observations = null;
|
||||
}
|
||||
if (cov.length != dimensions) {
|
||||
throw new Exception("Supplied covariance matrix does not match initialised number of dimensions");
|
||||
}
|
||||
if (cov.length != cov[0].length) {
|
||||
throw new Exception("Covariance matrices must be square");
|
||||
}
|
||||
|
||||
this.covFromObservations = covFromObservations;
|
||||
this.covariance = cov;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2017, Joseph T. Lizier, Ipek Oezdemir and Pedro Mediano
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
|
||||
* <p>Computes the differential O-information of a given multivariate
|
||||
* <code>double[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorGaussian}),
|
||||
* assuming that the probability distribution function for these observations is
|
||||
* a multivariate Gaussian distribution.</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorCommon}.
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class OInfoCalculatorGaussian
|
||||
extends MultiVariateInfoMeasureCalculatorGaussian {
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*/
|
||||
public OInfoCalculatorGaussian() {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the average O-info in nats (not bits!)
|
||||
* @throws Exception if not sufficient data have been provided, or if the
|
||||
* supplied covariance matrix is invalid.
|
||||
*/
|
||||
public double computeAverageLocalOfObservations() throws Exception {
|
||||
|
||||
if (covariance == null) {
|
||||
throw new Exception("Cannot calculate O-Info without having " +
|
||||
"a covariance either supplied or computed via setObservations()");
|
||||
}
|
||||
|
||||
if (!isComputed) {
|
||||
double oinfo = (dimensions - 2)*Math.log(MatrixUtils.determinantSymmPosDefMatrix(covariance));
|
||||
for (int i = 0; i < dimensions; i++) {
|
||||
int[] idx = allExcept(i, dimensions);
|
||||
double[][] marginal_cov = MatrixUtils.selectRowsAndColumns(covariance, idx, idx);
|
||||
oinfo += Math.log(covariance[i][i]) - Math.log(MatrixUtils.determinantSymmPosDefMatrix(marginal_cov));
|
||||
}
|
||||
// This "0.5" comes from the entropy formula for Gaussians: h = 0.5*logdet(2*pi*e*Sigma)
|
||||
lastAverage = 0.5*oinfo;;
|
||||
isComputed = true;
|
||||
}
|
||||
|
||||
return lastAverage;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the "time-series" of local O-info values in nats (not bits!)
|
||||
* for the supplied states.
|
||||
* @throws Exception if not sufficient data have been provided, or if the
|
||||
* supplied covariance matrix is invalid.
|
||||
*/
|
||||
public double[] computeLocalUsingPreviousObservations(double[][] states) throws Exception {
|
||||
|
||||
if ((means == null) || (covariance == null)) {
|
||||
throw new Exception("Cannot compute local values without having means " +
|
||||
"and covariance either supplied or computed via setObservations()");
|
||||
}
|
||||
|
||||
EntropyCalculatorMultiVariateGaussian hCalc = new EntropyCalculatorMultiVariateGaussian();
|
||||
hCalc.initialise(dimensions);
|
||||
hCalc.setCovarianceAndMeans(covariance, means);
|
||||
double[] localValues = MatrixUtils.multiply(hCalc.computeLocalUsingPreviousObservations(states), dimensions - 2);
|
||||
|
||||
for (int i = 0; i < dimensions; i++) {
|
||||
int[] idx = allExcept(i, dimensions);
|
||||
|
||||
// Local entropy of this variable (i) only
|
||||
double[][] this_cov = MatrixUtils.selectRowsAndColumns(covariance, i, 1, i, 1);
|
||||
double[] this_means = MatrixUtils.select(means, i, 1);
|
||||
double[][] this_state = MatrixUtils.selectColumns(states, i, 1);
|
||||
|
||||
hCalc.initialise(1);
|
||||
hCalc.setCovarianceAndMeans(this_cov, this_means);
|
||||
double[] thisLocals = hCalc.computeLocalUsingPreviousObservations(this_state);
|
||||
|
||||
// Local entropy of the rest of the variables (0, ... i-1, i+1, ... D)
|
||||
double[][] rest_cov = MatrixUtils.selectRowsAndColumns(covariance, idx, idx);
|
||||
double[] rest_means = MatrixUtils.select(means, idx);
|
||||
double[][] rest_state = MatrixUtils.selectColumns(states, idx);
|
||||
|
||||
hCalc.initialise(dimensions - 1);
|
||||
hCalc.setCovarianceAndMeans(rest_cov, rest_means);
|
||||
double[] restLocals = hCalc.computeLocalUsingPreviousObservations(rest_state);
|
||||
|
||||
localValues = MatrixUtils.add(localValues, MatrixUtils.subtract(thisLocals, restLocals));
|
||||
|
||||
}
|
||||
|
||||
return localValues;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2017, Joseph T. Lizier, Ipek Oezdemir and Pedro Mediano
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
|
||||
* <p>Computes the differential S-information of a given multivariate
|
||||
* <code>double[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorGaussian}),
|
||||
* assuming that the probability distribution function for these observations is
|
||||
* a multivariate Gaussian distribution.</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorCommon}.
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class SInfoCalculatorGaussian
|
||||
extends MultiVariateInfoMeasureCalculatorGaussian {
|
||||
|
||||
/**
|
||||
* Constructor.
|
||||
*/
|
||||
public SInfoCalculatorGaussian() {
|
||||
// Nothing to do
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the average S-info in nats (not bits!)
|
||||
* @throws Exception if not sufficient data have been provided, or if the
|
||||
* supplied covariance matrix is invalid.
|
||||
*/
|
||||
public double computeAverageLocalOfObservations() throws Exception {
|
||||
|
||||
if (covariance == null) {
|
||||
throw new Exception("Cannot calculate O-Info without having " +
|
||||
"a covariance either supplied or computed via setObservations()");
|
||||
}
|
||||
|
||||
if (!isComputed) {
|
||||
|
||||
MultiInfoCalculatorGaussian tcCalc = new MultiInfoCalculatorGaussian();
|
||||
tcCalc.initialise(dimensions);
|
||||
tcCalc.setCovariance(covariance);
|
||||
double tc = tcCalc.computeAverageLocalOfObservations();
|
||||
|
||||
DualTotalCorrelationCalculatorGaussian dtcCalc = new DualTotalCorrelationCalculatorGaussian();
|
||||
dtcCalc.initialise(dimensions);
|
||||
dtcCalc.setCovariance(covariance);
|
||||
double dtc = dtcCalc.computeAverageLocalOfObservations();
|
||||
|
||||
lastAverage = tc + dtc;
|
||||
isComputed = true;
|
||||
}
|
||||
|
||||
return lastAverage;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the "time-series" of local S-info values in nats (not bits!)
|
||||
* for the supplied states.
|
||||
* @throws Exception if not sufficient data have been provided, or if the
|
||||
* supplied covariance matrix is invalid.
|
||||
*/
|
||||
public double[] computeLocalUsingPreviousObservations(double[][] states) throws Exception {
|
||||
|
||||
if ((means == null) || (covariance == null)) {
|
||||
throw new Exception("Cannot compute local values without having means " +
|
||||
"and covariance either supplied or computed via setObservations()");
|
||||
}
|
||||
|
||||
MultiInfoCalculatorGaussian tcCalc = new MultiInfoCalculatorGaussian();
|
||||
tcCalc.initialise(dimensions);
|
||||
tcCalc.setCovarianceAndMeans(covariance, means);
|
||||
double[] localTC = tcCalc.computeLocalUsingPreviousObservations(states);
|
||||
|
||||
DualTotalCorrelationCalculatorGaussian dtcCalc = new DualTotalCorrelationCalculatorGaussian();
|
||||
dtcCalc.initialise(dimensions);
|
||||
dtcCalc.setCovarianceAndMeans(covariance, means);
|
||||
double[] localDTC = dtcCalc.computeLocalUsingPreviousObservations(states);
|
||||
|
||||
double[] localValues = MatrixUtils.add(localTC, localDTC);
|
||||
|
||||
return localValues;
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
import infodynamics.utils.EuclideanUtils;
|
||||
import infodynamics.utils.NeighbourNodeData;
|
||||
import infodynamics.utils.KdTree;
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Calendar;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* <p>Computes the differential dual total correlation of a given multivariate
|
||||
* set of observations, using Kraskov-Stoegbauer-Grassberger (KSG) estimation
|
||||
* (see Kraskov et al., below).</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for
|
||||
* {@link MultiVariateInfoMeasureCalculatorCommon}.</p>
|
||||
*
|
||||
* <p>Finally, note that {@link Cloneable} is implemented allowing clone()
|
||||
* to produce only an automatic shallow copy, which is fine
|
||||
* for the statistical significance calculation it is intended for
|
||||
* (none of the array
|
||||
* data will be changed there).
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
*
|
||||
* <li>Kraskov, A., Stoegbauer, H., Grassberger, P.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.69.066138">"Estimating mutual information"</a>,
|
||||
* Physical Review E 69, (2004) 066138.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class DualTotalCorrelationCalculatorKraskov
|
||||
extends MultiVariateInfoMeasureCalculatorKraskov
|
||||
implements Cloneable { // See comments on clonability above
|
||||
|
||||
|
||||
protected double[] partialComputeFromObservations(
|
||||
int startTimePoint, int numTimePoints, boolean returnLocals) throws Exception {
|
||||
|
||||
double startTime = Calendar.getInstance().getTimeInMillis();
|
||||
|
||||
double[] localMi = null;
|
||||
if (returnLocals) {
|
||||
localMi = new double[numTimePoints];
|
||||
}
|
||||
|
||||
// Constants:
|
||||
double dimensionsMinus1TimesDiGammaN = (double) (dimensions - 1) * digammaN;
|
||||
|
||||
// Count the average number of points within eps_x for each marginal x of each point
|
||||
double totalSumF = 0.0;
|
||||
|
||||
for (int t = startTimePoint; t < startTimePoint + numTimePoints; t++) {
|
||||
// Compute eps for this time step by
|
||||
// finding the kth closest neighbour for point t:
|
||||
PriorityQueue<NeighbourNodeData> nnPQ =
|
||||
kdTreeJoint.findKNearestNeighbours(k, t, dynCorrExclTime);
|
||||
// First element in the PQ is the kth NN,
|
||||
// and epsilon = kthNnData.distance
|
||||
NeighbourNodeData kthNnData = nnPQ.poll();
|
||||
|
||||
// Distance to kth neighbour in joint space
|
||||
double eps = kthNnData.distance;
|
||||
|
||||
double sumF = 0.0;
|
||||
|
||||
sumF += (digammaK - digammaN);
|
||||
|
||||
for (int d = 0; d < dimensions; d++) {
|
||||
int n = rangeSearchersInBigMarginals[d].countPointsStrictlyWithinR(
|
||||
t, eps, dynCorrExclTime);
|
||||
sumF -= (MathsUtils.digamma(n + 1) - digammaN)/(dimensions - 1);
|
||||
}
|
||||
|
||||
sumF *= (dimensions - 1);
|
||||
|
||||
totalSumF += sumF;
|
||||
|
||||
if (returnLocals) {
|
||||
localMi[t-startTimePoint] = sumF;
|
||||
}
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
Calendar rightNow2 = Calendar.getInstance();
|
||||
long endTime = rightNow2.getTimeInMillis();
|
||||
System.out.println("Subset " + startTimePoint + ":" +
|
||||
(startTimePoint + numTimePoints) + " Calculation time: " +
|
||||
((endTime - startTime)/1000.0) + " sec" );
|
||||
}
|
||||
|
||||
// Select what to return:
|
||||
if (returnLocals) {
|
||||
return localMi;
|
||||
} else {
|
||||
double[] returnArray = new double[] {totalSumF/((double) totalObservations)};
|
||||
return returnArray;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,538 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
import infodynamics.measures.continuous.MultiVariateInfoMeasureCalculatorCommon;
|
||||
import infodynamics.utils.EuclideanUtils;
|
||||
import infodynamics.utils.NeighbourNodeData;
|
||||
import infodynamics.utils.KdTree;
|
||||
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Calendar;
|
||||
import java.util.Random;
|
||||
|
||||
/**
|
||||
* <p>Base class with common functionality for child class implementations of
|
||||
* multivariate information measures on a given multivariate
|
||||
* <code>double[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorCommon}),
|
||||
* using Kraskov-Stoegbauer-Grassberger (KSG) estimation
|
||||
* (see Kraskov et al., below).</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorCommon},
|
||||
* with:
|
||||
* <ul>
|
||||
* <li>For constructors see the child classes.</li>
|
||||
* <li>Further properties are defined in {@link #setProperty(String, String)}.</li>
|
||||
* <li>Computed values are in <b>nats</b>, not bits!</li>
|
||||
* </ul>
|
||||
* </p>
|
||||
*
|
||||
* <p>Finally, note that {@link Cloneable} is implemented allowing clone()
|
||||
* to produce only an automatic shallow copy, which is fine
|
||||
* for the statistical significance calculation it is intended for
|
||||
* (none of the array data will be changed there).
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
*
|
||||
* <li>Kraskov, A., Stoegbauer, H., Grassberger, P.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.69.066138">"Estimating mutual information"</a>,
|
||||
* Physical Review E 69, (2004) 066138.</li>
|
||||
* </ul>
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public abstract class MultiVariateInfoMeasureCalculatorKraskov
|
||||
extends MultiVariateInfoMeasureCalculatorCommon
|
||||
implements Cloneable { // See comments on clonability above
|
||||
|
||||
/**
|
||||
* we compute distances to the kth nearest neighbour
|
||||
*/
|
||||
protected int k = 4;
|
||||
|
||||
/**
|
||||
* The norm type in use (see {@link #PROP_NORM_TYPE})
|
||||
*/
|
||||
protected int normType = EuclideanUtils.NORM_MAX_NORM;
|
||||
|
||||
/**
|
||||
* Property name for the number of K nearest neighbours used in
|
||||
* the KSG algorithm (default 4).
|
||||
*/
|
||||
public final static String PROP_K = "k";
|
||||
/**
|
||||
* Property name for what type of norm to use between data points
|
||||
* for each marginal variable -- Options are defined by
|
||||
* {@link KdTree#setNormType(String)} and the
|
||||
* default is {@link EuclideanUtils#NORM_MAX_NORM}.
|
||||
*/
|
||||
public final static String PROP_NORM_TYPE = "NORM_TYPE";
|
||||
/**
|
||||
* Property name for an amount of random Gaussian noise to be
|
||||
* added to the data (default 1e-8 to match the noise order in MILCA toolkit.).
|
||||
*/
|
||||
public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
|
||||
/**
|
||||
* Property name for a dynamics exclusion time window
|
||||
* otherwise known as Theiler window (see Kantz and Schreiber).
|
||||
* Default is 0 which means no dynamic exclusion window.
|
||||
*/
|
||||
public static final String PROP_DYN_CORR_EXCL_TIME = "DYN_CORR_EXCL";
|
||||
/**
|
||||
* Property name for the number of parallel threads to use in the
|
||||
* computation (default is to use all available)
|
||||
*/
|
||||
public static final String PROP_NUM_THREADS = "NUM_THREADS";
|
||||
/**
|
||||
* Valid property value for {@link #PROP_NUM_THREADS} to indicate
|
||||
* that all available processors should be used.
|
||||
*/
|
||||
public static final String USE_ALL_THREADS = "USE_ALL";
|
||||
|
||||
/**
|
||||
* Whether to add an amount of random noise to the incoming data
|
||||
*/
|
||||
protected boolean addNoise = true;
|
||||
/**
|
||||
* Amount of random Gaussian noise to add to the incoming data
|
||||
*/
|
||||
protected double noiseLevel = (double) 1e-8;
|
||||
/**
|
||||
* Whether we use dynamic correlation exclusion
|
||||
*/
|
||||
protected boolean dynCorrExcl = false;
|
||||
/**
|
||||
* Size of dynamic correlation exclusion window.
|
||||
*/
|
||||
protected int dynCorrExclTime = 0;
|
||||
/**
|
||||
* Number of parallel threads to use in the computation;
|
||||
* defaults to use all available.
|
||||
*/
|
||||
protected int numThreads = Runtime.getRuntime().availableProcessors();
|
||||
/**
|
||||
* Protected k-d tree data structure (for fast nearest neighbour searches)
|
||||
* representing the joint space
|
||||
*/
|
||||
protected KdTree kdTreeJoint;
|
||||
// /**
|
||||
// * protected data structures (for fast nearest neighbour searches)
|
||||
// * representing the marginal spaces
|
||||
// */
|
||||
// protected KdTree[] rangeSearchersInMarginals;
|
||||
/**
|
||||
* Protected data structures (for fast nearest neighbour searches)
|
||||
* representing the marginal spaces of each individual variable
|
||||
*/
|
||||
protected UnivariateNearestNeighbourSearcher[] rangeSearchersInSmallMarginals;
|
||||
/**
|
||||
* Protected data structures (for fast nearest neighbour searches)
|
||||
* representing the marginal spaces of each set of (D-1) variables
|
||||
*/
|
||||
protected KdTree[] rangeSearchersInBigMarginals;
|
||||
/**
|
||||
* Constant for digamma(k), with k the number of nearest neighbours selected
|
||||
*/
|
||||
protected double digammaK;
|
||||
/**
|
||||
* Constant for digamma(N), with N the number of samples.
|
||||
*/
|
||||
protected double digammaN;
|
||||
|
||||
|
||||
public void initialise(int dimensions) {
|
||||
this.dimensions = dimensions;
|
||||
lastAverage = 0.0;
|
||||
totalObservations = 0;
|
||||
isComputed = false;
|
||||
observations = null;
|
||||
kdTreeJoint = null;
|
||||
rangeSearchersInSmallMarginals = null;
|
||||
rangeSearchersInBigMarginals = null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sets properties for the KSG multivariate measure calculator.
|
||||
* New property values are not guaranteed to take effect until the next call
|
||||
* to an initialise method.
|
||||
*
|
||||
* <p>Valid property names, and what their
|
||||
* values should represent, include:</p>
|
||||
* <ul>
|
||||
* <li>{@link #PROP_K} -- number of k nearest neighbours to use in joint kernel space
|
||||
* in the KSG algorithm (default is 4).</li>
|
||||
* <li>{@link #PROP_NORM_TYPE} -- normalization type to apply to
|
||||
* working out the norms between the points in each marginal space.
|
||||
* Options are defined by {@link KdTree#setNormType(String)} -
|
||||
* default is {@link EuclideanUtils#NORM_MAX_NORM}.</li>
|
||||
* <li>{@link #PROP_DYN_CORR_EXCL_TIME} -- a dynamics exclusion time window,
|
||||
* also known as Theiler window (see Kantz and Schreiber);
|
||||
* default is 0 which means no dynamic exclusion window.</li>
|
||||
* <li>{@link #PROP_ADD_NOISE} -- a standard deviation for an amount of
|
||||
* random Gaussian noise to add to
|
||||
* each variable, to avoid having neighbourhoods with artificially
|
||||
* large counts. (We also accept "false" to indicate "0".)
|
||||
* The amount is added in after any normalisation,
|
||||
* so can be considered as a number of standard deviations of the data.
|
||||
* (Recommended by Kraskov. MILCA uses 1e-8; but adds in
|
||||
* a random amount of noise in [0,noiseLevel) ).
|
||||
* Default 1e-8 to match the noise order in MILCA toolkit..</li>
|
||||
* </ul>
|
||||
*
|
||||
* <p>Unknown property values are ignored.</p>
|
||||
*
|
||||
* @param propertyName name of the property
|
||||
* @param propertyValue value of the property
|
||||
* @throws Exception for invalid property values
|
||||
*/
|
||||
public void setProperty(String propertyName, String propertyValue) throws Exception {
|
||||
boolean propertySet = true;
|
||||
if (propertyName.equalsIgnoreCase(PROP_K)) {
|
||||
k = Integer.parseInt(propertyValue);
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
|
||||
normType = KdTree.validateNormType(propertyValue);
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_DYN_CORR_EXCL_TIME)) {
|
||||
dynCorrExclTime = Integer.parseInt(propertyValue);
|
||||
dynCorrExcl = (dynCorrExclTime > 0);
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
|
||||
if (propertyValue.equals("0") ||
|
||||
propertyValue.equalsIgnoreCase("false")) {
|
||||
addNoise = false;
|
||||
noiseLevel = 0;
|
||||
} else {
|
||||
addNoise = true;
|
||||
noiseLevel = Double.parseDouble(propertyValue);
|
||||
}
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_NUM_THREADS)) {
|
||||
if (propertyValue.equalsIgnoreCase(USE_ALL_THREADS)) {
|
||||
numThreads = Runtime.getRuntime().availableProcessors();
|
||||
} else { // otherwise the user has passed in an integer:
|
||||
numThreads = Integer.parseInt(propertyValue);
|
||||
}
|
||||
} else {
|
||||
// No property was set here
|
||||
propertySet = false;
|
||||
// try the superclass:
|
||||
super.setProperty(propertyName, propertyValue);
|
||||
}
|
||||
if (debug && propertySet) {
|
||||
System.out.println(this.getClass().getSimpleName() + ": Set property " + propertyName +
|
||||
" to " + propertyValue);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finaliseAddObservations() throws Exception {
|
||||
super.finaliseAddObservations();
|
||||
|
||||
if ((observations == null) || (observations[0].length == 0)) {
|
||||
throw new Exception("Computing measure with a null set of data");
|
||||
}
|
||||
if (observations.length <= k + 2*dynCorrExclTime) {
|
||||
throw new Exception("There are less observations provided (" +
|
||||
observations.length +
|
||||
") than required for the number of nearest neighbours parameter (" +
|
||||
k + ") and any dynamic correlation exclusion (" + dynCorrExclTime + ")");
|
||||
}
|
||||
|
||||
// Normalise the data if required
|
||||
if (normalise) {
|
||||
// We can overwrite these since they're already
|
||||
// a copy of the users' data.
|
||||
MatrixUtils.normalise(observations);
|
||||
}
|
||||
|
||||
// Add small random noise if required
|
||||
if (addNoise) {
|
||||
Random random = new Random();
|
||||
// Add Gaussian noise of std dev noiseLevel to the data
|
||||
for (int r = 0; r < observations.length; r++) {
|
||||
for (int c = 0; c < dimensions; c++) {
|
||||
observations[r][c] += random.nextGaussian()*noiseLevel;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Set the constants:
|
||||
digammaK = MathsUtils.digamma(k);
|
||||
digammaN = MathsUtils.digamma(totalObservations);
|
||||
}
|
||||
|
||||
/**
|
||||
* Internal method to ensure that the Kd-tree data structures to represent the
|
||||
* observational data have been constructed (should be called prior to attempting
|
||||
* to use these data structures)
|
||||
*/
|
||||
protected void ensureKdTreesConstructed() throws Exception {
|
||||
|
||||
// We need to construct the k-d trees for use by the child
|
||||
// classes. We check each tree for existence separately
|
||||
// since source can be used across original and surrogate data
|
||||
// TODO can parallelise these -- best done within the kdTree --
|
||||
// though it's unclear if there's much point given that
|
||||
// the tree construction itself afterwards can't really be well parallelised.
|
||||
if (kdTreeJoint == null) {
|
||||
kdTreeJoint = new KdTree(observations);
|
||||
kdTreeJoint.setNormType(normType);
|
||||
}
|
||||
if (rangeSearchersInSmallMarginals == null) {
|
||||
rangeSearchersInSmallMarginals = new UnivariateNearestNeighbourSearcher[dimensions];
|
||||
for (int d = 0; d < dimensions; d++) {
|
||||
rangeSearchersInSmallMarginals[d] = new UnivariateNearestNeighbourSearcher(
|
||||
MatrixUtils.selectColumn(observations, d));
|
||||
rangeSearchersInSmallMarginals[d].setNormType(normType);
|
||||
}
|
||||
}
|
||||
if (rangeSearchersInBigMarginals == null) {
|
||||
rangeSearchersInBigMarginals = new KdTree[dimensions];
|
||||
for (int d = 0; d < dimensions; d++) {
|
||||
rangeSearchersInBigMarginals[d] = new KdTree(
|
||||
MatrixUtils.selectColumns(observations, allExcept(d, dimensions)));
|
||||
rangeSearchersInBigMarginals[d].setNormType(normType);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the average measure in nats (not bits!)
|
||||
*/
|
||||
public double computeAverageLocalOfObservations() throws Exception {
|
||||
// Compute the measure
|
||||
double startTime = Calendar.getInstance().getTimeInMillis();
|
||||
lastAverage = computeFromObservations(false)[0];
|
||||
isComputed = true;
|
||||
if (debug) {
|
||||
Calendar rightNow2 = Calendar.getInstance();
|
||||
long endTime = rightNow2.getTimeInMillis();
|
||||
System.out.println("Calculation time: " + ((endTime - startTime)/1000.0) + " sec" );
|
||||
}
|
||||
return lastAverage;
|
||||
}
|
||||
|
||||
/**
|
||||
* {@inheritDoc}
|
||||
*
|
||||
* @return the "time-series" of local measure values in nats (not bits!)
|
||||
* @throws Exception
|
||||
*/
|
||||
public double[] computeLocalOfPreviousObservations() throws Exception {
|
||||
double[] localValues = computeFromObservations(true);
|
||||
lastAverage = MatrixUtils.mean(localValues);
|
||||
isComputed = true;
|
||||
return localValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* This method, specified in {@link MultiVariateInfoMeasureCalculatorCommon}
|
||||
* is not implemented yet here.
|
||||
*/
|
||||
public double[] computeLocalUsingPreviousObservations(double[][] states) throws Exception {
|
||||
// TODO If this is implemented, will need to normalise the incoming
|
||||
// observations the same way that previously supplied ones were
|
||||
// normalised (if they were normalised, that is)
|
||||
throw new Exception("Local method not implemented yet");
|
||||
}
|
||||
|
||||
/**
 * This protected method handles the multiple threads which
 * compute either the average or local measure (over parts of the total
 * observations), computing the
 * distances between all tuples in time.
 *
 * <p>The method returns:<ol>
 *  <li>for (returnLocals == false), an array of size 1,
 *      containing the average measure </li>
 *  <li>for (returnLocals == true), the array of local
 *      measure values</li>
 * </ol>
 *
 * @param returnLocals whether to return an array of local values, or else
 *  sums of these values
 * @return either the average measure (length-1 array), or the array of local
 *  measure values, in nats not bits
 * @throws Exception if the search-structure construction fails, or if any
 *  worker thread encountered an exception (rethrown via getReturnValues())
 */
protected double[] computeFromObservations(boolean returnLocals) throws Exception {

    double[] returnValues = null;

    // Build the shared (read-only across threads) search structures first:
    ensureKdTreesConstructed();

    if (numThreads == 1) {
        // Single-threaded implementation:
        returnValues = partialComputeFromObservations(0, totalObservations, returnLocals);

    } else {
        // We're going multithreaded:
        if (returnLocals) {
            // We're computing locals: one entry per observation
            returnValues = new double[totalObservations];
        } else {
            // We're computing the average: a single accumulator
            returnValues = new double[1];
        }

        // Distribute the observations to the threads for the parallel processing
        int lTimesteps = totalObservations / numThreads; // each thread gets the same amount of data
        int res = totalObservations % numThreads; // the first thread gets the residual data
        if (debug) {
            System.out.printf("Computing Kraskov Multi-Info with %d threads (%d timesteps each, plus %d residual)%n",
                    numThreads, lTimesteps, res);
        }
        Thread[] tCalculators = new Thread[numThreads];
        KraskovThreadRunner[] runners = new KraskovThreadRunner[numThreads];
        for (int t = 0; t < numThreads; t++) {
            // Thread 0 takes the residual on top of its regular share;
            //  later threads start offset by that residual:
            int startTime = (t == 0) ? 0 : lTimesteps * t + res;
            int numTimesteps = (t == 0) ? lTimesteps + res : lTimesteps;
            if (debug) {
                System.out.println(t + ".Thread: from " + startTime +
                        " to " + (startTime + numTimesteps)); // Trace Message
            }
            runners[t] = new KraskovThreadRunner(this, startTime, numTimesteps, returnLocals);
            tCalculators[t] = new Thread(runners[t]);
            tCalculators[t].start();
        }

        // Wait for the termination of all threads
        //  and collect their results
        for (int t = 0; t < numThreads; t++) {
            // NOTE(review): tCalculators[t] is always non-null here since every
            //  slot is assigned above; this guard looks defensive only -- confirm.
            if (tCalculators[t] != null) { // TODO Ipek: can you comment on why we're checking for null here?
                tCalculators[t].join();
            }
            // Now we add in the data from this completed thread
            //  (getReturnValues() rethrows any exception the thread hit):
            if (returnLocals) {
                // We're computing local measure; copy these local values
                //  into the full array of locals
                System.arraycopy(runners[t].getReturnValues(), 0,
                        returnValues, runners[t].myStartTimePoint, runners[t].numberOfTimePoints);
            } else {
                // We're computing the average measure, keep the running sums of digammas and counts
                MatrixUtils.addInPlace(returnValues, runners[t].getReturnValues());
            }
        }
    }

    return returnValues;

}
|
||||
|
||||
/**
 * Protected method to be used internally for threaded implementations. This
 * method implements the guts of each Kraskov algorithm, computing the number
 * of nearest neighbours in each dimension for a sub-set of the data points.
 * It is intended to be called by one thread to work on that specific sub-set
 * of the data.
 *
 * <p>Child classes should implement the computation of any specific measure
 * using this method.</p>
 *
 * <p>The method returns:<ol>
 *  <li>for average measures (returnLocals == false), the relevant sums of
 *  digamma(n_x+1) in each marginal
 *  for a partial set of the observations</li>
 *  <li>for local measures (returnLocals == true), the array of local values</li>
 * </ol>
 *
 * @param startTimePoint index of the first time point of the partial set we examine
 * @param numTimePoints number of time points (starting from, and including,
 *  startTimePoint) to examine
 * @param returnLocals whether to return an array of local values, or else
 *  sums of these values
 * @return an array of sum of digamma(n_x+1) for each marginal x, then
 *  sum of n_x for each marginal x (these latter ones are for debugging purposes).
 * @throws Exception if the nearest-neighbour searches fail
 */
protected abstract double[] partialComputeFromObservations(
        int startTimePoint, int numTimePoints, boolean returnLocals) throws Exception;
|
||||
|
||||
/**
|
||||
* Private class to handle multi-threading of the Kraskov algorithms.
|
||||
* Each instance calls partialComputeFromObservations()
|
||||
* to compute nearest neighbours for a part of the data.
|
||||
*
|
||||
*
|
||||
* @author Joseph Lizier (<a href="joseph.lizier at gmail.com">email</a>,
|
||||
* <a href="http://lizier.me/joseph/">www</a>)
|
||||
* @author Ipek Özdemir
|
||||
*/
|
||||
private class KraskovThreadRunner implements Runnable {
|
||||
protected MultiVariateInfoMeasureCalculatorKraskov calc;
|
||||
protected int myStartTimePoint;
|
||||
protected int numberOfTimePoints;
|
||||
protected boolean computeLocals;
|
||||
|
||||
protected double[] returnValues = null;
|
||||
protected Exception problem = null;
|
||||
|
||||
public static final int INDEX_SUM_DIGAMMAS = 0;
|
||||
|
||||
public KraskovThreadRunner(
|
||||
MultiVariateInfoMeasureCalculatorKraskov calc,
|
||||
int myStartTimePoint, int numberOfTimePoints,
|
||||
boolean computeLocals) {
|
||||
this.calc = calc;
|
||||
this.myStartTimePoint = myStartTimePoint;
|
||||
this.numberOfTimePoints = numberOfTimePoints;
|
||||
this.computeLocals = computeLocals;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the values from this part of the data,
|
||||
* or throw any exception that was encountered by the
|
||||
* thread.
|
||||
*
|
||||
* @return an exception previously encountered by this thread.
|
||||
* @throws Exception
|
||||
*/
|
||||
public double[] getReturnValues() throws Exception {
|
||||
if (problem != null) {
|
||||
throw problem;
|
||||
}
|
||||
return returnValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* Start the thread for the given parameters
|
||||
*/
|
||||
public void run() {
|
||||
try {
|
||||
returnValues = calc.partialComputeFromObservations(
|
||||
myStartTimePoint, numberOfTimePoints, computeLocals);
|
||||
} catch (Exception e) {
|
||||
// Store the exception for later retrieval
|
||||
problem = e;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// end class KraskovThreadRunner
|
||||
|
||||
}
|
157
java/source/infodynamics/measures/continuous/kraskov/OInfoCalculatorKraskov.java
Executable file
157
java/source/infodynamics/measures/continuous/kraskov/OInfoCalculatorKraskov.java
Executable file
|
@ -0,0 +1,157 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
import infodynamics.utils.EuclideanUtils;
|
||||
import infodynamics.utils.NeighbourNodeData;
|
||||
import infodynamics.utils.KdTree;
|
||||
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Calendar;
|
||||
import java.util.Random;
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* <p>Computes the differential O-information of a given multivariate
|
||||
* set of observations, using Kraskov-Stoegbauer-Grassberger (KSG) estimation
|
||||
* (see Kraskov et al., below).</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for
|
||||
* {@link MultiVariateInfoMeasureCalculatorCommon}.</p>
|
||||
*
|
||||
* <p>Finally, note that {@link Cloneable} is implemented allowing clone()
|
||||
* to produce only an automatic shallow copy, which is fine
|
||||
* for the statistical significance calculation it is intended for
|
||||
* (none of the array
|
||||
* data will be changed there).
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M., Jensen, H.,
|
||||
* "Quantifying high-order effects via multivariate extensions of the
|
||||
* mutual information".</li>
|
||||
*
|
||||
* <li>Kraskov, A., Stoegbauer, H., Grassberger, P.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.69.066138">"Estimating mutual information"</a>,
|
||||
* Physical Review E 69, (2004) 066138.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class OInfoCalculatorKraskov
|
||||
extends MultiVariateInfoMeasureCalculatorKraskov
|
||||
implements Cloneable { // See comments on clonability above
|
||||
|
||||
|
||||
protected double[] partialComputeFromObservations(
|
||||
int startTimePoint, int numTimePoints, boolean returnLocals) throws Exception {
|
||||
|
||||
// If data is 2D, return 0 before doing any computation
|
||||
if (dimensions == 2) {
|
||||
if (returnLocals) {
|
||||
double[] localMi = new double[numTimePoints];
|
||||
Arrays.fill(localMi, 0);
|
||||
return localMi;
|
||||
} else {
|
||||
return new double[] {0};
|
||||
}
|
||||
}
|
||||
|
||||
double startTime = Calendar.getInstance().getTimeInMillis();
|
||||
|
||||
double[] localMi = null;
|
||||
if (returnLocals) {
|
||||
localMi = new double[numTimePoints];
|
||||
}
|
||||
|
||||
// Constants:
|
||||
double dimensionsMinus1TimesDiGammaN = (double) (dimensions - 1) * digammaN;
|
||||
|
||||
// Count the average number of points within eps_x for each marginal x of each point
|
||||
double totalSumF = 0.0;
|
||||
|
||||
for (int t = startTimePoint; t < startTimePoint + numTimePoints; t++) {
|
||||
// Compute eps for this time step by
|
||||
// finding the kth closest neighbour for point t:
|
||||
PriorityQueue<NeighbourNodeData> nnPQ =
|
||||
kdTreeJoint.findKNearestNeighbours(k, t, dynCorrExclTime);
|
||||
// First element in the PQ is the kth NN,
|
||||
// and epsilon = kthNnData.distance
|
||||
NeighbourNodeData kthNnData = nnPQ.poll();
|
||||
|
||||
// Distance to kth neighbour in joint space
|
||||
double eps = kthNnData.distance;
|
||||
|
||||
double sumF = 0.0;
|
||||
|
||||
sumF += (digammaK - digammaN);
|
||||
|
||||
for (int d = 0; d < dimensions; d++) {
|
||||
int n_small = rangeSearchersInSmallMarginals[d].countPointsStrictlyWithinR(
|
||||
t, eps, dynCorrExclTime);
|
||||
|
||||
int n_big = rangeSearchersInBigMarginals[d].countPointsStrictlyWithinR(
|
||||
t, eps, dynCorrExclTime);
|
||||
|
||||
|
||||
sumF -= (MathsUtils.digamma(n_big + 1) - digammaN)/(dimensions - 2);
|
||||
sumF += (MathsUtils.digamma(n_small + 1) - digammaN)/(dimensions - 2);
|
||||
|
||||
if (debug) {
|
||||
// Only tracking this for debugging purposes:
|
||||
System.out.printf("t=%d, d=%d, n_small=%d, n_big=%d, sumF=%.3f%n",
|
||||
t, d, n_small, n_big, sumF);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
sumF *= (2 - dimensions);
|
||||
|
||||
totalSumF += sumF;
|
||||
|
||||
if (returnLocals) {
|
||||
localMi[t-startTimePoint] = sumF;
|
||||
}
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
Calendar rightNow2 = Calendar.getInstance();
|
||||
long endTime = rightNow2.getTimeInMillis();
|
||||
System.out.println("Subset " + startTimePoint + ":" +
|
||||
(startTimePoint + numTimePoints) + " Calculation time: " +
|
||||
((endTime - startTime)/1000.0) + " sec" );
|
||||
}
|
||||
|
||||
// Select what to return:
|
||||
if (returnLocals) {
|
||||
return localMi;
|
||||
} else {
|
||||
double[] returnArray = new double[] {totalSumF/((double) totalObservations)};
|
||||
return returnArray;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
146
java/source/infodynamics/measures/continuous/kraskov/SInfoCalculatorKraskov.java
Executable file
146
java/source/infodynamics/measures/continuous/kraskov/SInfoCalculatorKraskov.java
Executable file
|
@ -0,0 +1,146 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
import infodynamics.utils.EuclideanUtils;
|
||||
import infodynamics.utils.NeighbourNodeData;
|
||||
import infodynamics.utils.KdTree;
|
||||
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Calendar;
|
||||
import java.util.Random;
|
||||
import java.util.Arrays;
|
||||
|
||||
/**
|
||||
* <p>Computes the differential S-information of a given multivariate
|
||||
* set of observations, using Kraskov-Stoegbauer-Grassberger (KSG) estimation
|
||||
* (see Kraskov et al., below).</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for
|
||||
* {@link MultiVariateInfoMeasureCalculatorCommon}.</p>
|
||||
*
|
||||
* <p>Finally, note that {@link Cloneable} is implemented allowing clone()
|
||||
* to produce only an automatic shallow copy, which is fine
|
||||
* for the statistical significance calculation it is intended for
|
||||
* (none of the array
|
||||
* data will be changed there).
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M., Jensen, H.,
|
||||
* "Quantifying high-order effects via multivariate extensions of the
|
||||
* mutual information".</li>
|
||||
*
|
||||
* <li>Kraskov, A., Stoegbauer, H., Grassberger, P.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.69.066138">"Estimating mutual information"</a>,
|
||||
* Physical Review E 69, (2004) 066138.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class SInfoCalculatorKraskov
|
||||
extends MultiVariateInfoMeasureCalculatorKraskov
|
||||
implements Cloneable { // See comments on clonability above
|
||||
|
||||
|
||||
protected double[] partialComputeFromObservations(
|
||||
int startTimePoint, int numTimePoints, boolean returnLocals) throws Exception {
|
||||
|
||||
double startTime = Calendar.getInstance().getTimeInMillis();
|
||||
|
||||
double[] localMi = null;
|
||||
if (returnLocals) {
|
||||
localMi = new double[numTimePoints];
|
||||
}
|
||||
|
||||
// Constants:
|
||||
double dimensionsMinus1TimesDiGammaN = (double) (dimensions - 1) * digammaN;
|
||||
|
||||
// Count the average number of points within eps_x for each marginal x of each point
|
||||
double totalSumF = 0.0;
|
||||
|
||||
for (int t = startTimePoint; t < startTimePoint + numTimePoints; t++) {
|
||||
// Compute eps for this time step by
|
||||
// finding the kth closest neighbour for point t:
|
||||
PriorityQueue<NeighbourNodeData> nnPQ =
|
||||
kdTreeJoint.findKNearestNeighbours(k, t, dynCorrExclTime);
|
||||
// First element in the PQ is the kth NN,
|
||||
// and epsilon = kthNnData.distance
|
||||
NeighbourNodeData kthNnData = nnPQ.poll();
|
||||
|
||||
// Distance to kth neighbour in joint space
|
||||
double eps = kthNnData.distance;
|
||||
|
||||
double sumF = 0.0;
|
||||
|
||||
sumF += (digammaK - digammaN);
|
||||
|
||||
for (int d = 0; d < dimensions; d++) {
|
||||
int n_small = rangeSearchersInSmallMarginals[d].countPointsStrictlyWithinR(
|
||||
t, eps, dynCorrExclTime);
|
||||
|
||||
int n_big = rangeSearchersInBigMarginals[d].countPointsStrictlyWithinR(
|
||||
t, eps, dynCorrExclTime);
|
||||
|
||||
|
||||
sumF -= (MathsUtils.digamma(n_big + 1) - digammaN)/dimensions;
|
||||
sumF -= (MathsUtils.digamma(n_small + 1) - digammaN)/dimensions;
|
||||
|
||||
if (debug) {
|
||||
// Only tracking this for debugging purposes:
|
||||
System.out.printf("t=%d, d=%d, n_small=%d, n_big=%d, sumF=%.3f%n",
|
||||
t, d, n_small, n_big, sumF);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
sumF *= dimensions;
|
||||
|
||||
totalSumF += sumF;
|
||||
|
||||
if (returnLocals) {
|
||||
localMi[t-startTimePoint] = sumF;
|
||||
}
|
||||
}
|
||||
|
||||
if (debug) {
|
||||
Calendar rightNow2 = Calendar.getInstance();
|
||||
long endTime = rightNow2.getTimeInMillis();
|
||||
System.out.println("Subset " + startTimePoint + ":" +
|
||||
(startTimePoint + numTimePoints) + " Calculation time: " +
|
||||
((endTime - startTime)/1000.0) + " sec" );
|
||||
}
|
||||
|
||||
// Select what to return:
|
||||
if (returnLocals) {
|
||||
return localMi;
|
||||
} else {
|
||||
double[] returnArray = new double[] {totalSumF/((double) totalObservations)};
|
||||
return returnArray;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.discrete;
|
||||
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
|
||||
* <p>Computes the dual total correlation (DTC) of a given multivariate
|
||||
* <code>int[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorDiscrete}).</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorDiscrete}.
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class DualTotalCorrelationCalculatorDiscrete
|
||||
extends MultiVariateInfoMeasureCalculatorDiscrete {
|
||||
|
||||
/**
|
||||
* Construct an instance.
|
||||
*
|
||||
* @param base number of symbols for each variable.
|
||||
* E.g. binary variables are in base-2.
|
||||
* @param numVars numbers of joint variables that DTC
|
||||
* will be computed over.
|
||||
*/
|
||||
public DualTotalCorrelationCalculatorDiscrete(int base, int numVars) {
|
||||
super(base, numVars);
|
||||
}
|
||||
|
||||
protected double computeLocalValueForTuple(int[] tuple, int jointValue) {
|
||||
|
||||
if (jointCount[jointValue] == 0) {
|
||||
// This joint state does not occur, so it makes no contribution here
|
||||
return 0;
|
||||
}
|
||||
|
||||
double jointProb = (double) jointCount[jointValue] / (double) observations;
|
||||
double logValue = (numVars - 1) * Math.log(jointProb);
|
||||
|
||||
for (int i = 0; i < numVars; i++) {
|
||||
int marginalState = computeBigMarginalState(jointValue, i, tuple[i]);
|
||||
double marginalProb = (double) bigMarginalCounts[i][marginalState] / (double) observations;
|
||||
logValue -= Math.log(marginalProb);
|
||||
}
|
||||
|
||||
double localValue = logValue / log_2;
|
||||
|
||||
if (jointProb > 0.0) {
|
||||
checkLocals(localValue);
|
||||
}
|
||||
|
||||
return localValue;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.discrete;
|
||||
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
 * Implements a base class with common functionality for child class
 * implementations of multivariate information measures.
 *
 * <p>Multivariate information measures are functionals of probability
 * distributions over <code>R^n</code>, and typical examples include multi-information
 * (a.k.a. total correlation), dual total correlation, O-information, and connected
 * information.</p>
 *
 * <p>Usage of child classes is intended to follow this paradigm:</p>
 * <ol>
 *  <li>Construct the calculator;</li>
 *  <li>Initialise the calculator using {@link #initialise()};</li>
 *  <li>Provide the observations/samples for the calculator
 *      to set up the PDFs, using one or more calls to
 *      sets of {@link #addObservations(int[][], int[])} methods, then</li>
 *  <li>Compute the required quantities, being one or more of:
 *   <ul>
 *    <li>the average measure: {@link #computeAverageLocalOfObservations()};</li>
 *   </ul>
 *  </li>
 *  <li>
 *   Return to step 2 to re-use the calculator on a new data set.
 *  </li>
 * </ol>
 *
 * <p><b>References:</b><br/>
 * <ul>
 *  <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
 *  <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
 *  interdependencies via multivariate extensions of the mutual information"</a>,
 *  Physical Review E 100, (2019) 032305.</li>
 * </ul>
 *
 * @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
 * <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
 */
public abstract class MultiVariateInfoMeasureCalculatorDiscrete
    extends InfoMeasureCalculatorDiscrete {

    /**
     * Count of occurrences of each joint state in the provided observations.
     */
    protected int[] jointCount = null;

    /**
     * Count of occurrences of each state of each variable in the provided
     * observations.
     *
     * For a given variable <code>v</code> and state <code>i</code>,
     * <code>smallMarginalCounts[v][i]</code> counts how many times variable
     * <code>v</code> was observed in state <code>i</code>,
     */
    protected int[][] smallMarginalCounts = null; // marginalCounts[marginalIndex][state]

    /**
     * Count of occurrences of each state of each (D-1)-dimensional marginal in
     * the provided observations.
     *
     * For a given variable <code>v</code> and state <code>i</code>,
     * <code>bigMarginalCounts[v][i]</code> counts how many times the _rest_ of
     * the system, excluding variable <code>v</code>, was observed in state <code>i</code>,
     */
    protected int[][] bigMarginalCounts = null;

    /**
     * Number of variables in the system.
     */
    protected int numVars;

    /**
     * Number of possible states of the whole system.
     */
    protected int jointStates;

    /**
     * Whether the first local value has been checked. (Used to initialise some variables
     * related to computation of local values: see {@link #checkLocals(double)}.)
     */
    protected boolean checkedFirst = false;

    /**
     * Abstract constructor (to be called by child classes).
     *
     * @param base number of symbols for each variable.
     *        E.g. binary variables are in base-2.
     * @param numVars numbers of joint variables that the measure
     *        will be computed over.
     * @throws RuntimeException if the count arrays (of size base^numVars)
     *        are too large for the JVM's available memory
     */
    protected MultiVariateInfoMeasureCalculatorDiscrete(int base, int numVars) {
        super(base);
        this.numVars = numVars;
        // base^numVars possible joint states for the whole system:
        jointStates = MathsUtils.power(base, numVars);
        try {
            jointCount = new int[jointStates];
            smallMarginalCounts = new int[numVars][base];
            bigMarginalCounts = new int[numVars][jointStates];
        } catch (OutOfMemoryError e) {
            // Allow any Exceptions to be thrown, but catch and wrap
            //  Error as a RuntimeException
            throw new RuntimeException("Requested memory for the base " +
                    base + " with " + numVars +
                    " variables is too large for the JVM at this time", e);
        }
    }

    @Override
    public void initialise(){
        super.initialise();
        // Reset all count tables for a fresh set of observations:
        MatrixUtils.fill(jointCount, 0);
        MatrixUtils.fill(smallMarginalCounts, 0);
        MatrixUtils.fill(bigMarginalCounts, 0);
    }

    /**
     * Given multiple time samples of a homogeneous array of variables (states),
     * add the observations of all sets of numVars of these
     * Do this for every time point
     *
     * @param states 2D array of values of an array of variables
     *  at many observations (first index is time, second is variable index)
     * @throws Exception if the combined-value computation fails (e.g. a state
     *  outside the declared base)
     */
    public void addObservations(int[][] states) throws Exception {
        // NOTE: this local shadows the jointStates field (which holds the
        //  *number* of possible joint states); here it is the per-timestep
        //  combined joint value of each observation.
        int[] jointStates = MatrixUtils.computeCombinedValues(states, base);
        for (int t = 0; t < states.length; t++) {
            for (int i = 0; i < numVars; i++) {
                // Extract values of the 1D and the (N-1)D marginals
                int thisValue = states[t][i];
                int bigMarginalState = computeBigMarginalState(jointStates[t], i, thisValue);

                // Update counts
                bigMarginalCounts[i][bigMarginalState]++;
                smallMarginalCounts[i][thisValue]++;
            }
            jointCount[jointStates[t]]++;
            observations++;
        }
    }

    @Override
    public double computeAverageLocalOfObservations() {

        int[] jointTuple = new int[numVars];
        checkedFirst = false;
        try {
            // Enumerate all base^numVars tuples recursively, accumulating
            //  each tuple's probability-weighted local value:
            average = computeForGivenTupleFromVarIndex(jointTuple, 0);
        } catch (Exception e) {
            // NOTE(review): this swallows the exception, printing to stdout and
            //  returning the sentinel -1 rather than propagating -- callers
            //  cannot distinguish -1 from a genuine result. Left as-is since
            //  the override cannot add a throws clause.
            System.out.println("Something went wrong during the calculation.");
            average = -1;
        }

        return average;
    }

    /**
     * Utility to compute the contribution to the measure for all tuples
     * starting with tuple[0..(fromIndex-1)].
     *
     * @param tuple partially-filled tuple of variable states (entries at
     *  indices &gt;= fromIndex are filled in by the recursion)
     * @param fromIndex index of the first not-yet-fixed variable in tuple
     * @return the probability-weighted sum of local values over all tuples
     *  sharing the given prefix
     * @throws Exception if the local-value computation fails
     */
    public double computeForGivenTupleFromVarIndex(int[] tuple, int fromIndex) throws Exception {
        double miCont = 0;
        if (fromIndex == numVars) {
            // The whole tuple is filled in, so compute the contribution to the MI from this tuple
            int jointValue = MatrixUtils.computeCombinedValues(new int[][] {tuple}, base)[0];

            if (jointCount[jointValue] == 0) {
                // This joint state does not occur, so it makes no contribution here
                return 0;
            }

            double jointProb = (double) jointCount[jointValue] / (double) observations;
            double localValue = computeLocalValueForTuple(tuple, jointValue);
            miCont = jointProb * localValue;

        } else {
            // Fill out the next part of the tuple and make the recursive calls
            for (int v = 0; v < base; v++) {
                tuple[fromIndex] = v;
                miCont += computeForGivenTupleFromVarIndex(tuple, fromIndex + 1);
            }
        }
        return miCont;
    }

    /**
     * Shortcut method to initialise the calculator, add observations and compute
     * the average measure in one line.
     *
     * @param states series of multivariate observations
     *  (first index is time or observation index, second is variable number)
     * @return the average measure over the supplied observations
     * @throws Exception if adding the observations fails
     */
    public double compute(int[][] states) throws Exception {
        initialise();
        addObservations(states);
        return computeAverageLocalOfObservations();
    }

    /**
     * Internal method to update maximum and minimum values of local information
     * measures.
     *
     * @param localValue instance of computed local information measure
     */
    protected void checkLocals(double localValue) {
        if (!checkedFirst) {
            // First local seen: it is both the max and min so far
            max = localValue;
            min = localValue;
            checkedFirst = true;
        } else {
            if (localValue > max) {
                max = localValue;
            }
            if (localValue < min) {
                min = localValue;
            }
        }
    }

    /**
     * Method to be implemented by all child classes to compute the local value
     * of the measure for a given tuple.
     *
     * @param tuple state of the system at a given time (index is variable number)
     * @param jointValue <code>int</code> representing the state of the system
     * @return the local value of the measure for this tuple
     * @throws Exception if the computation fails
     */
    protected abstract double computeLocalValueForTuple(int[] tuple, int jointValue)
        throws Exception;

    /**
     * Convenience overload: computes the combined joint value for the tuple,
     * then delegates to {@link #computeLocalValueForTuple(int[], int)}.
     *
     * @param tuple state of the system at a given time (index is variable number)
     * @return the local value of the measure for this tuple
     * @throws Exception if the computation fails
     */
    protected double computeLocalValueForTuple(int[] tuple) throws Exception {
        int jointValue = MatrixUtils.computeCombinedValues(new int[][] {tuple}, base)[0];
        return computeLocalValueForTuple(tuple, jointValue);
    }

    /**
     * Small utility function to compute the state of the system excluding one variable.
     *
     * @param jointState state of the full system
     * @param varIdx index of the variable to be excluded
     * @param varValue value of the variable in question in the system state
     * @return the combined value of the remaining (numVars - 1) variables
     */
    protected int computeBigMarginalState(int jointState, int varIdx, int varValue) {
        // Subtract off this variable's contribution to the positional encoding
        //  (variable varIdx carries weight base^(numVars - varIdx - 1)):
        int bigMarginalState = jointState - varValue*MathsUtils.power(base, numVars - varIdx - 1);
        return bigMarginalState;
    }

}
|
||||
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.discrete;
|
||||
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
|
||||
* <p>Computes the O-information of a given multivariate
|
||||
* <code>int[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorDiscrete}).</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorDiscrete}.
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class OInfoCalculatorDiscrete
|
||||
extends MultiVariateInfoMeasureCalculatorDiscrete {
|
||||
|
||||
/**
|
||||
* Construct an instance.
|
||||
*
|
||||
* @param base number of symbols for each variable.
|
||||
* E.g. binary variables are in base-2.
|
||||
* @param numVars numbers of joint variables that DTC
|
||||
* will be computed over.
|
||||
*/
|
||||
public OInfoCalculatorDiscrete(int base, int numVars) {
|
||||
super(base, numVars);
|
||||
}
|
||||
|
||||
protected double computeLocalValueForTuple(int[] tuple, int jointValue) {
|
||||
|
||||
if (jointCount[jointValue] == 0) {
|
||||
// This joint state does not occur, so it makes no contribution here
|
||||
return 0;
|
||||
}
|
||||
|
||||
double jointProb = (double) jointCount[jointValue] / (double) observations;
|
||||
double logValue = (2 - numVars) * Math.log(jointProb);
|
||||
|
||||
for (int i = 0; i < numVars; i++) {
|
||||
int bigMarginalState = computeBigMarginalState(jointValue, i, tuple[i]);
|
||||
double bigMarginalProb = (double) bigMarginalCounts[i][bigMarginalState] / (double) observations;
|
||||
double smallMarginalProb = (double) smallMarginalCounts[i][tuple[i]] / (double) observations;
|
||||
logValue += Math.log(bigMarginalProb) - Math.log(smallMarginalProb);
|
||||
}
|
||||
|
||||
double localValue = logValue / log_2;
|
||||
|
||||
if (jointProb > 0.0) {
|
||||
checkLocals(localValue);
|
||||
}
|
||||
|
||||
return localValue;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.discrete;
|
||||
|
||||
import infodynamics.utils.MathsUtils;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
|
||||
/**
|
||||
* <p>Computes the S-information of a given multivariate
|
||||
* <code>int[][]</code> set of
|
||||
* observations (extending {@link MultiVariateInfoMeasureCalculatorDiscrete}).</p>
|
||||
*
|
||||
* <p>Usage is as per the paradigm outlined for {@link MultiVariateInfoMeasureCalculatorDiscrete}.
|
||||
* </p>
|
||||
*
|
||||
* <p><b>References:</b><br/>
|
||||
* <ul>
|
||||
* <li>Rosas, F., Mediano, P., Gastpar, M, Jensen, H.,
|
||||
* <a href="http://dx.doi.org/10.1103/PhysRevE.100.032305">"Quantifying high-order
|
||||
* interdependencies via multivariate extensions of the mutual information"</a>,
|
||||
* Physical Review E 100, (2019) 032305.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @author Pedro A.M. Mediano (<a href="pmediano at pm.me">email</a>,
|
||||
* <a href="http://www.doc.ic.ac.uk/~pam213">www</a>)
|
||||
*/
|
||||
public class SInfoCalculatorDiscrete
|
||||
extends MultiVariateInfoMeasureCalculatorDiscrete {
|
||||
|
||||
/**
|
||||
* Construct an instance.
|
||||
*
|
||||
* @param base number of symbols for each variable.
|
||||
* E.g. binary variables are in base-2.
|
||||
* @param numVars numbers of joint variables that DTC
|
||||
* will be computed over.
|
||||
*/
|
||||
public SInfoCalculatorDiscrete(int base, int numVars) {
|
||||
super(base, numVars);
|
||||
}
|
||||
|
||||
protected double computeLocalValueForTuple(int[] tuple, int jointValue) {
|
||||
|
||||
if (jointCount[jointValue] == 0) {
|
||||
// This joint state does not occur, so it makes no contribution here
|
||||
return 0;
|
||||
}
|
||||
|
||||
double jointProb = (double) jointCount[jointValue] / (double) observations;
|
||||
|
||||
// Local TC value
|
||||
double localTC = Math.log(jointProb);
|
||||
for (int i = 0; i < numVars; i++) {
|
||||
int marginalState = tuple[i];
|
||||
double marginalProb = (double) smallMarginalCounts[i][marginalState] / (double) observations;
|
||||
localTC -= Math.log(marginalProb);
|
||||
}
|
||||
|
||||
// Local DTC value
|
||||
double localDTC = (numVars - 1) * Math.log(jointProb);
|
||||
for (int i = 0; i < numVars; i++) {
|
||||
int marginalState = computeBigMarginalState(jointValue, i, tuple[i]);
|
||||
double marginalProb = (double) bigMarginalCounts[i][marginalState] / (double) observations;
|
||||
localDTC -= Math.log(marginalProb);
|
||||
}
|
||||
|
||||
// Combine local TC and DTC into S-info
|
||||
double logValue = localTC + localDTC;
|
||||
double localValue = logValue / log_2;
|
||||
|
||||
if (jointProb > 0.0) {
|
||||
checkLocals(localValue);
|
||||
}
|
||||
|
||||
return localValue;
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -819,6 +819,51 @@ public class MatrixUtils {
|
|||
return returnValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiplies all items in an array times a constant value
|
||||
*
|
||||
* @param array
|
||||
* @param value
|
||||
* @return array * constant value
|
||||
*/
|
||||
public static int[] multiply(int[] array, int value) throws Exception {
|
||||
int[] returnValues = new int[array.length];
|
||||
for (int i = 0; i < returnValues.length; i++) {
|
||||
returnValues[i] = array[i] * value;
|
||||
}
|
||||
return returnValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiplies all items in an array times a constant value
|
||||
*
|
||||
* @param array
|
||||
* @param value
|
||||
* @return array * constant value
|
||||
*/
|
||||
public static double[] multiply(int[] array, double value) throws Exception {
|
||||
double[] returnValues = new double[array.length];
|
||||
for (int i = 0; i < returnValues.length; i++) {
|
||||
returnValues[i] = array[i] * value;
|
||||
}
|
||||
return returnValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* Multiplies all items in an array times a constant value
|
||||
*
|
||||
* @param array
|
||||
* @param value
|
||||
* @return array * constant value
|
||||
*/
|
||||
public static double[] multiply(double[] array, double value) throws Exception {
|
||||
double[] returnValues = new double[array.length];
|
||||
for (int i = 0; i < returnValues.length; i++) {
|
||||
returnValues[i] = array[i] * value;
|
||||
}
|
||||
return returnValues;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the matrix product A x B
|
||||
*
|
||||
|
@ -1535,6 +1580,84 @@ public class MatrixUtils {
|
|||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the required columns from the matrix
|
||||
*
|
||||
* @param matrix
|
||||
* @param fromCol
|
||||
* @param cols
|
||||
* @return
|
||||
*/
|
||||
public static int[][] selectColumns(int matrix[][],
|
||||
int fromCol, int cols) {
|
||||
int[][] data = new int[matrix.length][cols];
|
||||
for (int r = 0; r < matrix.length; r++) {
|
||||
for (int cIndex = 0; cIndex < cols; cIndex++) {
|
||||
data[r][cIndex] = matrix[r][cIndex + fromCol];
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the required columns from the matrix
|
||||
*
|
||||
* @param matrix
|
||||
* @param columns
|
||||
* @return
|
||||
*/
|
||||
public static int[][] selectColumns(int matrix[][], int columns[]) {
|
||||
int[][] data = new int[matrix.length][columns.length];
|
||||
for (int r = 0; r < matrix.length; r++) {
|
||||
for (int cIndex = 0; cIndex < columns.length; cIndex++) {
|
||||
data[r][cIndex] = matrix[r][columns[cIndex]];
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the required columns from the matrix
|
||||
*
|
||||
* @param matrix
|
||||
* @param includeColumnFlags
|
||||
* @return
|
||||
*/
|
||||
public static int[][] selectColumns(int matrix[][], boolean includeColumnFlags[]) {
|
||||
Vector<Integer> v = new Vector<Integer>();
|
||||
|
||||
for (int i = 0; i < includeColumnFlags.length; i++) {
|
||||
if (includeColumnFlags[i]) {
|
||||
v.add(i);
|
||||
}
|
||||
}
|
||||
int[][] data = new int[matrix.length][v.size()];
|
||||
for (int r = 0; r < matrix.length; r++) {
|
||||
for (int outputColumnIndex = 0; outputColumnIndex < v.size(); outputColumnIndex++) {
|
||||
int outputColumn = v.get(outputColumnIndex);
|
||||
data[r][outputColumnIndex] = matrix[r][outputColumn];
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the required columns from the matrix
|
||||
*
|
||||
* @param matrix
|
||||
* @param columns
|
||||
* @return
|
||||
*/
|
||||
public static int[][] selectColumns(int matrix[][], List<Integer> columns) {
|
||||
int[][] data = new int[matrix.length][columns.size()];
|
||||
for (int r = 0; r < matrix.length; r++) {
|
||||
for (int cIndex = 0; cIndex < columns.size(); cIndex++) {
|
||||
data[r][cIndex] = matrix[r][columns.get(cIndex)];
|
||||
}
|
||||
}
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the required rows from the matrix
|
||||
*
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public class DualTotalCorrelationCalculatorGaussianTester extends TestCase {
|
||||
|
||||
/**
|
||||
* For two variables, DTC is equal to mutual information.
|
||||
*/
|
||||
public void testTwoVariables() throws Exception {
|
||||
double[][] cov = new double[][] {{1, 0.5}, {0.5, 1}};
|
||||
|
||||
DualTotalCorrelationCalculatorGaussian dtcCalc = new DualTotalCorrelationCalculatorGaussian();
|
||||
dtcCalc.initialise(2);
|
||||
dtcCalc.setCovariance(cov);
|
||||
double dtc = dtcCalc.computeAverageLocalOfObservations();
|
||||
|
||||
MutualInfoCalculatorMultiVariateGaussian miCalc = new MutualInfoCalculatorMultiVariateGaussian();
|
||||
miCalc.initialise(1,1);
|
||||
miCalc.setCovariance(cov, 1);
|
||||
double mi = miCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(dtc, mi, 1e-6);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare against the direct calculation of DTC as a sum of entropies using the
|
||||
* entropy calculator.
|
||||
*/
|
||||
public void testCompareWithEntropy() throws Exception {
|
||||
double[][] cov = new double[][] {{1, 0.4, 0.3}, {0.4, 1, 0.2}, {0.3, 0.2, 1}};
|
||||
|
||||
DualTotalCorrelationCalculatorGaussian dtcCalc = new DualTotalCorrelationCalculatorGaussian();
|
||||
dtcCalc.initialise(3);
|
||||
dtcCalc.setCovariance(cov);
|
||||
double dtc = dtcCalc.computeAverageLocalOfObservations();
|
||||
|
||||
// Calculate using an entropy calculator and picking submatrices manually
|
||||
EntropyCalculatorMultiVariateGaussian hCalc = new EntropyCalculatorMultiVariateGaussian();
|
||||
hCalc.initialise(3);
|
||||
hCalc.setCovariance(cov);
|
||||
double dtc_hCalc = -2 * hCalc.computeAverageLocalOfObservations();
|
||||
|
||||
hCalc.initialise(2);
|
||||
hCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {0,1}, new int[] {0,1}));
|
||||
dtc_hCalc += hCalc.computeAverageLocalOfObservations();
|
||||
hCalc.initialise(2);
|
||||
hCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {0,2}, new int[] {0,2}));
|
||||
dtc_hCalc += hCalc.computeAverageLocalOfObservations();
|
||||
hCalc.initialise(2);
|
||||
hCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {1,2}, new int[] {1,2}));
|
||||
dtc_hCalc += hCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(dtc, dtc_hCalc, 1e-6);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that the local values average correctly back to the average value
|
||||
*/
|
||||
public void testLocalsAverageCorrectly() throws Exception {
|
||||
|
||||
int dimensions = 4;
|
||||
int timeSteps = 1000;
|
||||
DualTotalCorrelationCalculatorGaussian dtcCalc = new DualTotalCorrelationCalculatorGaussian();
|
||||
dtcCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
dtcCalc.setObservations(data);
|
||||
|
||||
double dtc = dtcCalc.computeAverageLocalOfObservations();
|
||||
double[] dtcLocal = dtcCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
System.out.printf("Average was %.5f%n", dtc);
|
||||
|
||||
assertEquals(dtc, MatrixUtils.mean(dtcLocal), 0.00001);
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that for 2D the local values equal the local MI values
|
||||
*
|
||||
*/
|
||||
public void testLocalsEqualMI() throws Exception {
|
||||
|
||||
int dimensions = 2;
|
||||
int timeSteps = 100;
|
||||
DualTotalCorrelationCalculatorGaussian dtcCalc = new DualTotalCorrelationCalculatorGaussian();
|
||||
dtcCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
dtcCalc.setObservations(data);
|
||||
double[] dtcLocal = dtcCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
MutualInfoCalculatorMultiVariateGaussian miCalc = new MutualInfoCalculatorMultiVariateGaussian();
|
||||
miCalc.initialise(1, 1);
|
||||
miCalc.setObservations(MatrixUtils.selectColumn(data, 0), MatrixUtils.selectColumn(data, 1));
|
||||
double[] miLocal = miCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
for (int t = 0; t < timeSteps; t++) {
|
||||
assertEquals(dtcLocal[t], miLocal[t], 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,150 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public class OInfoCalculatorGaussianTester extends TestCase {
|
||||
|
||||
/**
|
||||
* For two variables, O-info is zero
|
||||
*/
|
||||
public void testTwoVariables() throws Exception {
|
||||
double[][] cov = new double[][] {{1, 0.5}, {0.5, 1}};
|
||||
|
||||
OInfoCalculatorGaussian oCalc = new OInfoCalculatorGaussian();
|
||||
oCalc.initialise(2);
|
||||
oCalc.setCovariance(cov);
|
||||
double oinfo = oCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(oinfo, 0, 1e-6);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* For factorisable pairwise interactions, O-info is zero
|
||||
*/
|
||||
public void testPairwise() throws Exception {
|
||||
double[][] cov = new double[][] {{ 1, 0.5, 0, 0},
|
||||
{0.5, 1, 0, 0},
|
||||
{ 0, 0, 1, 0.5},
|
||||
{ 0, 0, 0.5, 1}};
|
||||
|
||||
OInfoCalculatorGaussian oCalc = new OInfoCalculatorGaussian();
|
||||
oCalc.initialise(4);
|
||||
oCalc.setCovariance(cov);
|
||||
double oinfo = oCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(oinfo, 0, 1e-6);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare against the direct calculation of O-info as a sum of entropies using the
|
||||
* entropy calculator.
|
||||
*/
|
||||
public void testCompareWithEntropy() throws Exception {
|
||||
double[][] cov = new double[][] {{1, 0.4, 0.3}, {0.4, 1, 0.2}, {0.3, 0.2, 1}};
|
||||
|
||||
OInfoCalculatorGaussian oCalc = new OInfoCalculatorGaussian();
|
||||
oCalc.initialise(3);
|
||||
oCalc.setCovariance(cov);
|
||||
double oinfo = oCalc.computeAverageLocalOfObservations();
|
||||
|
||||
// Calculate using an entropy calculator and picking submatrices manually
|
||||
EntropyCalculatorMultiVariateGaussian hCalc = new EntropyCalculatorMultiVariateGaussian();
|
||||
hCalc.initialise(3);
|
||||
hCalc.setCovariance(cov);
|
||||
double oinfo_hCalc = hCalc.computeAverageLocalOfObservations();
|
||||
|
||||
hCalc.initialise(2);
|
||||
hCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {0,1}, new int[] {0,1}));
|
||||
oinfo_hCalc -= hCalc.computeAverageLocalOfObservations();
|
||||
hCalc.initialise(2);
|
||||
hCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {0,2}, new int[] {0,2}));
|
||||
oinfo_hCalc -= hCalc.computeAverageLocalOfObservations();
|
||||
hCalc.initialise(2);
|
||||
hCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {1,2}, new int[] {1,2}));
|
||||
oinfo_hCalc -= hCalc.computeAverageLocalOfObservations();
|
||||
|
||||
hCalc.initialise(1);
|
||||
hCalc.setCovariance(new double[][] {{cov[0][0]}});
|
||||
oinfo_hCalc += hCalc.computeAverageLocalOfObservations();
|
||||
hCalc.initialise(1);
|
||||
hCalc.setCovariance(new double[][] {{cov[1][1]}});
|
||||
oinfo_hCalc += hCalc.computeAverageLocalOfObservations();
|
||||
hCalc.initialise(1);
|
||||
hCalc.setCovariance(new double[][] {{cov[2][2]}});
|
||||
oinfo_hCalc += hCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(oinfo, oinfo_hCalc, 1e-6);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that the local values average correctly back to the average value
|
||||
*/
|
||||
public void testLocalsAverageCorrectly() throws Exception {
|
||||
|
||||
int dimensions = 4;
|
||||
int timeSteps = 1000;
|
||||
OInfoCalculatorGaussian oCalc = new OInfoCalculatorGaussian();
|
||||
oCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
oCalc.setObservations(data);
|
||||
|
||||
double oinfo = oCalc.computeAverageLocalOfObservations();
|
||||
double[] oLocal = oCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
System.out.printf("Average was %.5f%n", oinfo);
|
||||
|
||||
assertEquals(oinfo, MatrixUtils.mean(oLocal), 0.00001);
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that for 2D all local values equal zero
|
||||
*/
|
||||
public void testLocalsEqualZero() throws Exception {
|
||||
|
||||
int dimensions = 2;
|
||||
int timeSteps = 100;
|
||||
OInfoCalculatorGaussian oCalc = new OInfoCalculatorGaussian();
|
||||
oCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
oCalc.setObservations(data);
|
||||
double[] oLocal = oCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
for (int t = 0; t < timeSteps; t++) {
|
||||
assertEquals(oLocal[t], 0, 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,129 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.gaussian;
|
||||
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public class SInfoCalculatorGaussianTester extends TestCase {
|
||||
|
||||
/**
|
||||
* For two variables, S-info is twice the MI between them
|
||||
*/
|
||||
public void testTwoVariables() throws Exception {
|
||||
double[][] cov = new double[][] {{1, 0.5}, {0.5, 1}};
|
||||
|
||||
SInfoCalculatorGaussian sCalc = new SInfoCalculatorGaussian();
|
||||
sCalc.initialise(2);
|
||||
sCalc.setCovariance(cov);
|
||||
double sinfo = sCalc.computeAverageLocalOfObservations();
|
||||
|
||||
MutualInfoCalculatorMultiVariateGaussian miCalc = new MutualInfoCalculatorMultiVariateGaussian();
|
||||
miCalc.initialise(1,1);
|
||||
miCalc.setCovariance(cov, false);
|
||||
double mi = miCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(sinfo, 2*mi, 1e-6);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compare against the direct calculation of S-info as a sum of mutual informations.
|
||||
*/
|
||||
public void testCompareWithEntropy() throws Exception {
|
||||
double[][] cov = new double[][] {{1, 0.4, 0.3}, {0.4, 1, 0.2}, {0.3, 0.2, 1}};
|
||||
|
||||
SInfoCalculatorGaussian sCalc = new SInfoCalculatorGaussian();
|
||||
sCalc.initialise(3);
|
||||
sCalc.setCovariance(cov);
|
||||
double sinfo = sCalc.computeAverageLocalOfObservations();
|
||||
|
||||
// Calculate using a mutual info calculator and picking submatrices manually
|
||||
MutualInfoCalculatorMultiVariateGaussian miCalc = new MutualInfoCalculatorMultiVariateGaussian();
|
||||
miCalc.initialise(2,1);
|
||||
miCalc.setCovariance(cov, false);
|
||||
double sinfo_miCalc = miCalc.computeAverageLocalOfObservations();
|
||||
|
||||
miCalc.initialise(2,1);
|
||||
miCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {2,0,1}, new int[] {2,0,1}), false);
|
||||
sinfo_miCalc += miCalc.computeAverageLocalOfObservations();
|
||||
miCalc.initialise(2,1);
|
||||
miCalc.setCovariance(MatrixUtils.selectRowsAndColumns(cov, new int[] {1,2,0}, new int[] {1,2,0}), false);
|
||||
sinfo_miCalc += miCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(sinfo, sinfo_miCalc, 1e-6);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that the local values average correctly back to the average value
|
||||
*/
|
||||
public void testLocalsAverageCorrectly() throws Exception {
|
||||
|
||||
int dimensions = 4;
|
||||
int timeSteps = 1000;
|
||||
SInfoCalculatorGaussian sCalc = new SInfoCalculatorGaussian();
|
||||
sCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
sCalc.setObservations(data);
|
||||
|
||||
double sinfo = sCalc.computeAverageLocalOfObservations();
|
||||
double[] sLocal = sCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
System.out.printf("Average was %.5f%n", sinfo);
|
||||
|
||||
assertEquals(sinfo, MatrixUtils.mean(sLocal), 0.00001);
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that for 2D all local values equal zero
|
||||
*/
|
||||
public void testLocalsEqualMI() throws Exception {
|
||||
|
||||
int dimensions = 2;
|
||||
int timeSteps = 100;
|
||||
SInfoCalculatorGaussian sCalc = new SInfoCalculatorGaussian();
|
||||
sCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
sCalc.setObservations(data);
|
||||
double[] sLocal = sCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
MutualInfoCalculatorMultiVariateGaussian miCalc = new MutualInfoCalculatorMultiVariateGaussian();
|
||||
miCalc.initialise(1, 1);
|
||||
miCalc.setObservations(MatrixUtils.selectColumn(data, 0), MatrixUtils.selectColumn(data, 1));
|
||||
double[] miLocal = miCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
for (int t = 0; t < timeSteps; t++) {
|
||||
assertEquals(sLocal[t], 2*miLocal[t], 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
import infodynamics.measures.continuous.gaussian.DualTotalCorrelationCalculatorGaussian;
|
||||
|
||||
import infodynamics.utils.ArrayFileReader;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public class DualTotalCorrelationCalculatorKraskovTester extends TestCase {
|
||||
|
||||
/**
|
||||
* For two variables, DTC is equal to mutual information.
|
||||
*/
|
||||
public void testTwoVariables() throws Exception {
|
||||
|
||||
double[][] data;
|
||||
ArrayFileReader afr = new ArrayFileReader("demos/data/2randomCols-1.txt");
|
||||
data = afr.getDouble2DMatrix();
|
||||
|
||||
DualTotalCorrelationCalculatorKraskov dtcCalc = new DualTotalCorrelationCalculatorKraskov();
|
||||
dtcCalc.setProperty("NOISE_LEVEL_TO_ADD", "0");
|
||||
dtcCalc.initialise(2);
|
||||
dtcCalc.setObservations(data);
|
||||
double dtc = dtcCalc.computeAverageLocalOfObservations();
|
||||
|
||||
|
||||
MutualInfoCalculatorMultiVariateKraskov1 miCalc = new MutualInfoCalculatorMultiVariateKraskov1();
|
||||
miCalc.setProperty("NOISE_LEVEL_TO_ADD", "0");
|
||||
miCalc.initialise(1,1);
|
||||
miCalc.setObservations(MatrixUtils.selectColumn(data, 0),
|
||||
MatrixUtils.selectColumn(data, 1));
|
||||
double mi = miCalc.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(dtc, mi, 1e-6);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compare against the values obtained by the Gaussian DTC calculator when the data
|
||||
* is actually Gaussian.
|
||||
*/
|
||||
public void testCompareWithGaussian() throws Exception {
|
||||
|
||||
int N = 100000, D = 3;
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
rg.setSeed(1);
|
||||
double[][] data = rg.generateNormalData(N, D, 0, 1);
|
||||
for (int i = 0; i < N; i++) {
|
||||
data[i][2] = data[i][2] + 0.1*data[i][0];
|
||||
data[i][1] = data[i][1] + 0.5*data[i][0];
|
||||
}
|
||||
|
||||
DualTotalCorrelationCalculatorKraskov dtcCalc_ksg = new DualTotalCorrelationCalculatorKraskov();
|
||||
dtcCalc_ksg.setProperty("NOISE_LEVEL_TO_ADD", "0");
|
||||
dtcCalc_ksg.initialise(3);
|
||||
dtcCalc_ksg.setObservations(data);
|
||||
double dtc_ksg = dtcCalc_ksg.computeAverageLocalOfObservations();
|
||||
|
||||
DualTotalCorrelationCalculatorGaussian dtcCalc_gau = new DualTotalCorrelationCalculatorGaussian();
|
||||
dtcCalc_gau.initialise(3);
|
||||
dtcCalc_gau.setObservations(data);
|
||||
double dtc_gau = dtcCalc_gau.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(dtc_ksg, dtc_gau, 0.001);
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that the local values average correctly back to the average value
|
||||
*
|
||||
*/
|
||||
public void testLocalsAverageCorrectly() throws Exception {
|
||||
|
||||
int dimensions = 4;
|
||||
int timeSteps = 1000;
|
||||
DualTotalCorrelationCalculatorKraskov dtcCalc = new DualTotalCorrelationCalculatorKraskov();
|
||||
dtcCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
dtcCalc.setObservations(data);
|
||||
|
||||
double dtc = dtcCalc.computeAverageLocalOfObservations();
|
||||
double[] dtcLocal = dtcCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
System.out.printf("Average was %.5f%n", dtc);
|
||||
|
||||
assertEquals(dtc, MatrixUtils.mean(dtcLocal), 0.00001);
|
||||
}
|
||||
|
||||
/**
|
||||
* Confirm that for 2D the local values equal the local MI values
|
||||
*
|
||||
*/
|
||||
public void testLocalsEqualMI() throws Exception {
|
||||
|
||||
int dimensions = 2;
|
||||
int timeSteps = 100;
|
||||
DualTotalCorrelationCalculatorKraskov dtcCalc = new DualTotalCorrelationCalculatorKraskov();
|
||||
dtcCalc.initialise(dimensions);
|
||||
|
||||
// generate some random data
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
double[][] data = rg.generateNormalData(timeSteps, dimensions,
|
||||
0, 1);
|
||||
|
||||
dtcCalc.setObservations(data);
|
||||
double[] dtcLocal = dtcCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
MutualInfoCalculatorMultiVariateKraskov1 miCalc = new MutualInfoCalculatorMultiVariateKraskov1();
|
||||
miCalc.initialise(1, 1);
|
||||
miCalc.setObservations(MatrixUtils.selectColumn(data, 0), MatrixUtils.selectColumn(data, 1));
|
||||
double[] miLocal = miCalc.computeLocalOfPreviousObservations();
|
||||
|
||||
for (int t = 0; t < timeSteps; t++) {
|
||||
assertEquals(dtcLocal[t], miLocal[t], 0.00001);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
import infodynamics.measures.continuous.gaussian.OInfoCalculatorGaussian;
|
||||
|
||||
import infodynamics.utils.ArrayFileReader;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public class OInfoCalculatorKraskovTester extends TestCase {
|
||||
|
||||
/**
|
||||
* Compare against the values obtained by the Gaussian O-info calculator when the data
|
||||
* is actually Gaussian.
|
||||
*/
|
||||
public void testCompareWithGaussian() throws Exception {
|
||||
|
||||
int N = 100000, D = 3;
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
rg.setSeed(3);
|
||||
double[][] data = rg.generateNormalData(N, D, 0, 1);
|
||||
for (int i = 0; i < N; i++) {
|
||||
data[i][2] = data[i][2] + 0.4*data[i][0];
|
||||
data[i][1] = data[i][1] + 0.5*data[i][0];
|
||||
}
|
||||
|
||||
OInfoCalculatorKraskov oCalc_ksg = new OInfoCalculatorKraskov();
|
||||
oCalc_ksg.setProperty("NOISE_LEVEL_TO_ADD", "0");
|
||||
oCalc_ksg.initialise(3);
|
||||
oCalc_ksg.setObservations(data);
|
||||
double o_ksg = oCalc_ksg.computeAverageLocalOfObservations();
|
||||
|
||||
OInfoCalculatorGaussian oCalc_gau = new OInfoCalculatorGaussian();
|
||||
oCalc_gau.initialise(3);
|
||||
oCalc_gau.setObservations(data);
|
||||
double o_gau = oCalc_gau.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(o_ksg, o_gau, 0.001);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.continuous.kraskov;
|
||||
|
||||
import infodynamics.measures.continuous.gaussian.SInfoCalculatorGaussian;
|
||||
|
||||
import infodynamics.utils.ArrayFileReader;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public class SInfoCalculatorKraskovTester extends TestCase {
|
||||
|
||||
/**
|
||||
* Compare against the values obtained by the Gaussian S-info calculator when the data
|
||||
* is actually Gaussian.
|
||||
*/
|
||||
public void testCompareWithGaussian() throws Exception {
|
||||
|
||||
int N = 100000, D = 3;
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
rg.setSeed(3);
|
||||
double[][] data = rg.generateNormalData(N, D, 0, 1);
|
||||
for (int i = 0; i < N; i++) {
|
||||
data[i][2] = data[i][2] + 0.4*data[i][0];
|
||||
data[i][1] = data[i][1] + 0.5*data[i][0];
|
||||
}
|
||||
|
||||
SInfoCalculatorKraskov sCalc_ksg = new SInfoCalculatorKraskov();
|
||||
sCalc_ksg.setProperty("NOISE_LEVEL_TO_ADD", "0");
|
||||
sCalc_ksg.initialise(3);
|
||||
sCalc_ksg.setObservations(data);
|
||||
double s_ksg = sCalc_ksg.computeAverageLocalOfObservations();
|
||||
|
||||
SInfoCalculatorGaussian sCalc_gau = new SInfoCalculatorGaussian();
|
||||
sCalc_gau.initialise(3);
|
||||
sCalc_gau.setObservations(data);
|
||||
double s_gau = sCalc_gau.computeAverageLocalOfObservations();
|
||||
|
||||
assertEquals(s_ksg, s_gau, 0.001);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,108 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.discrete;
|
||||
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
import infodynamics.utils.MatrixUtils;
|
||||
import infodynamics.utils.MathsUtils;
|
||||
|
||||
import java.util.Arrays;
|
||||
import junit.framework.TestCase;
|
||||
|
||||
|
||||
public class DualTotalCorrelationTester extends TestCase {
|
||||
|
||||
public void testIndependent() throws Exception {
|
||||
DualTotalCorrelationCalculatorDiscrete dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, 3);
|
||||
double dtc = dtcCalc.compute(new int[][] {{0,0,0},{0,0,1},{0,1,0},{0,1,1},{1,0,0},{1,0,1},{1,1,0},{1,1,1}});
|
||||
assertEquals(0.0, dtc, 0.000001);
|
||||
}
|
||||
|
||||
public void testXor() throws Exception {
|
||||
// 3 variables
|
||||
DualTotalCorrelationCalculatorDiscrete dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, 3);
|
||||
double dtc = dtcCalc.compute(new int[][] {{0,0,1},{0,1,0},{1,0,0},{1,1,1}});
|
||||
assertEquals(2.0, dtc, 0.000001);
|
||||
|
||||
// 4 variables
|
||||
dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, 4);
|
||||
dtc = dtcCalc.compute(new int[][] {{0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,1,1,0},{1,1,0,1},{1,0,1,1},{0,1,1,1}});
|
||||
assertEquals(3.0, dtc, 0.000001);
|
||||
}
|
||||
|
||||
public void testCopy() throws Exception {
|
||||
// 3 variables
|
||||
DualTotalCorrelationCalculatorDiscrete dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, 3);
|
||||
double dtc = dtcCalc.compute(new int[][] {{0,0,0},{1,1,1}});
|
||||
assertEquals(1.0, dtc, 0.000001);
|
||||
|
||||
// 4 variables
|
||||
dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, 4);
|
||||
dtc = dtcCalc.compute(new int[][] {{0,0,0,0},{1,1,1,1}});
|
||||
assertEquals(1.0, dtc, 0.000001);
|
||||
}
|
||||
|
||||
public void testCompareEntropy() throws Exception {
|
||||
// Generate random data and check that it matches the explicit computation
|
||||
// using entropy calculators
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
int D = 4;
|
||||
int[][] data = rg.generateRandomInts(10, D, 2);
|
||||
|
||||
// DTC calculator
|
||||
DualTotalCorrelationCalculatorDiscrete dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, D);
|
||||
double dtc_direct = dtcCalc.compute(data);
|
||||
|
||||
// Entropy calculators
|
||||
EntropyCalculatorDiscrete hCalc = new EntropyCalculatorDiscrete(MathsUtils.power(2, D));
|
||||
hCalc.initialise();
|
||||
hCalc.addObservations(MatrixUtils.computeCombinedValues(data, 2));
|
||||
double dtc_test = (1 - D) * hCalc.computeAverageLocalOfObservations();
|
||||
|
||||
hCalc = new EntropyCalculatorDiscrete(MathsUtils.power(2, D-1));
|
||||
for (int i = 0; i < 4; i++) {
|
||||
hCalc.initialise();
|
||||
hCalc.addObservations(MatrixUtils.computeCombinedValues(MatrixUtils.selectColumns(data, allExcept(i, D)), 2));
|
||||
dtc_test += hCalc.computeAverageLocalOfObservations();
|
||||
}
|
||||
|
||||
assertEquals(dtc_direct, dtc_test, 0.000001);
|
||||
|
||||
}
|
||||
|
||||
protected int[] allExcept(int idx, int N) {
|
||||
boolean[] v = new boolean[N];
|
||||
Arrays.fill(v, true);
|
||||
v[idx] = false;
|
||||
|
||||
int[] v2 = new int[N - 1];
|
||||
int counter = 0;
|
||||
for (int i = 0; i < N; i++) {
|
||||
if (v[i]) {
|
||||
v2[counter] = i;
|
||||
counter++;
|
||||
}
|
||||
}
|
||||
|
||||
return v2;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.discrete;
|
||||
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
|
||||
public class OInfoTester extends TestCase {
|
||||
|
||||
public void testIndependent() throws Exception {
|
||||
OInfoCalculatorDiscrete oCalc = new OInfoCalculatorDiscrete(2, 3);
|
||||
double oinfo = oCalc.compute(new int[][] {{0,0,0},{0,0,1},{0,1,0},{0,1,1},{1,0,0},{1,0,1},{1,1,0},{1,1,1}});
|
||||
assertEquals(0.0, oinfo, 0.000001);
|
||||
}
|
||||
|
||||
public void testXor() throws Exception {
|
||||
// 3 variables
|
||||
OInfoCalculatorDiscrete oCalc = new OInfoCalculatorDiscrete(2, 3);
|
||||
double oinfo = oCalc.compute(new int[][] {{0,0,1},{0,1,0},{1,0,0},{1,1,1}});
|
||||
assertEquals(-1.0, oinfo, 0.000001);
|
||||
|
||||
// 4 variables
|
||||
oCalc = new OInfoCalculatorDiscrete(2, 4);
|
||||
oinfo = oCalc.compute(new int[][] {{0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,1,1,0},{1,1,0,1},{1,0,1,1},{0,1,1,1}});
|
||||
assertEquals(-2.0, oinfo, 0.000001);
|
||||
}
|
||||
|
||||
public void testCopy() throws Exception {
|
||||
// 3 variables
|
||||
OInfoCalculatorDiscrete oCalc = new OInfoCalculatorDiscrete(2, 3);
|
||||
double oinfo = oCalc.compute(new int[][] {{0,0,0},{1,1,1}});
|
||||
assertEquals(1.0, oinfo, 0.000001);
|
||||
|
||||
// 4 variables
|
||||
oCalc = new OInfoCalculatorDiscrete(2, 4);
|
||||
oinfo = oCalc.compute(new int[][] {{0,0,0,0},{1,1,1,1}});
|
||||
assertEquals(2.0, oinfo, 0.000001);
|
||||
}
|
||||
|
||||
public void testPairwise() throws Exception {
|
||||
// Variables 0 and 1 are correlated and independent from 2 and 3, that are
|
||||
// also correlated
|
||||
OInfoCalculatorDiscrete oCalc = new OInfoCalculatorDiscrete(2, 4);
|
||||
double oinfo = oCalc.compute(new int[][] {{0,0,0,0},{0,0,1,1},{1,1,0,0},{1,1,1,1}});
|
||||
assertEquals(0.0, oinfo, 0.000001);
|
||||
}
|
||||
|
||||
public void testCompareTCAndDTC() throws Exception {
|
||||
// Generate random data and check that it matches the explicit computation
|
||||
// using TC and DTC calculators
|
||||
RandomGenerator rg = new RandomGenerator();
|
||||
int[][] data = rg.generateRandomInts(10, 4, 2);
|
||||
|
||||
// O-info calculator
|
||||
OInfoCalculatorDiscrete oCalc = new OInfoCalculatorDiscrete(2, 4);
|
||||
double oinfo_direct = oCalc.compute(data);
|
||||
|
||||
// TC and DTC calculators
|
||||
MultiInformationCalculatorDiscrete tcCalc = new MultiInformationCalculatorDiscrete(2, 4);
|
||||
DualTotalCorrelationCalculatorDiscrete dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, 4);
|
||||
tcCalc.initialise();
|
||||
tcCalc.addObservations(data);
|
||||
double oinfo_test = tcCalc.computeAverageLocalOfObservations() - dtcCalc.compute(data);
|
||||
|
||||
assertEquals(oinfo_direct, oinfo_test, 0.000001);
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Java Information Dynamics Toolkit (JIDT)
|
||||
* Copyright (C) 2012, Joseph T. Lizier
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU General Public License as published by
|
||||
* the Free Software Foundation, either version 3 of the License, or
|
||||
* (at your option) any later version.
|
||||
*
|
||||
* This program is distributed in the hope that it will be useful,
|
||||
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
* GNU General Public License for more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package infodynamics.measures.discrete;
|
||||
|
||||
import infodynamics.utils.RandomGenerator;
|
||||
|
||||
import junit.framework.TestCase;
|
||||
|
||||
/**
 * Unit tests for the discrete S-information calculator.
 */
public class SInfoTester extends TestCase {

	/**
	 * S-information of jointly independent uniform binary variables
	 * should be zero.
	 */
	public void testIndependent() throws Exception {
		SInfoCalculatorDiscrete sCalc = new SInfoCalculatorDiscrete(2, 3);
		double sinfo = sCalc.compute(new int[][] {{0,0,0},{0,0,1},{0,1,0},{0,1,1},{1,0,0},{1,0,1},{1,1,0},{1,1,1}});
		assertEquals(0.0, sinfo, 0.000001);
	}

	/**
	 * S-information of parity (xor) data: expected 3 bits for 3 variables,
	 * 4 bits for 4 variables.
	 */
	public void testXor() throws Exception {
		// 3 variables
		SInfoCalculatorDiscrete sCalc = new SInfoCalculatorDiscrete(2, 3);
		double sinfo = sCalc.compute(new int[][] {{0,0,1},{0,1,0},{1,0,0},{1,1,1}});
		assertEquals(3.0, sinfo, 0.000001);

		// 4 variables
		sCalc = new SInfoCalculatorDiscrete(2, 4);
		sinfo = sCalc.compute(new int[][] {{0,0,0,1},{0,0,1,0},{0,1,0,0},{1,0,0,0},{1,1,1,0},{1,1,0,1},{1,0,1,1},{0,1,1,1}});
		assertEquals(4.0, sinfo, 0.000001);
	}

	/**
	 * S-information of fully-correlated (copied) variables: expected 3 bits
	 * for 3 variables, 4 bits for 4 variables.
	 */
	public void testCopy() throws Exception {
		// 3 variables
		SInfoCalculatorDiscrete sCalc = new SInfoCalculatorDiscrete(2, 3);
		double sinfo = sCalc.compute(new int[][] {{0,0,0},{1,1,1}});
		assertEquals(3.0, sinfo, 0.000001);

		// 4 variables
		sCalc = new SInfoCalculatorDiscrete(2, 4);
		sinfo = sCalc.compute(new int[][] {{0,0,0,0},{1,1,1,1}});
		assertEquals(4.0, sinfo, 0.000001);
	}

	/**
	 * Check the direct S-information computation against the explicit
	 * form used below: TC + DTC.
	 */
	public void testCompareTCAndDTC() throws Exception {
		// Generate random data and check that it matches the explicit computation
		// using TC and DTC calculators
		RandomGenerator rg = new RandomGenerator();
		int[][] data = rg.generateRandomInts(10, 4, 2);

		// S-info calculator (comment previously said "O-info" — copy-paste slip)
		SInfoCalculatorDiscrete sCalc = new SInfoCalculatorDiscrete(2, 4);
		double sinfo_direct = sCalc.compute(data);

		// TC and DTC calculators
		MultiInformationCalculatorDiscrete tcCalc = new MultiInformationCalculatorDiscrete(2, 4);
		DualTotalCorrelationCalculatorDiscrete dtcCalc = new DualTotalCorrelationCalculatorDiscrete(2, 4);
		tcCalc.initialise();
		tcCalc.addObservations(data);
		double sinfo_test = tcCalc.computeAverageLocalOfObservations() + dtcCalc.compute(data);

		assertEquals(sinfo_direct, sinfo_test, 0.000001);
	}

}
|
||||
|
Loading…
Reference in New Issue