mirror of https://github.com/jlizier/jidt
refactoring + conditional processes + radius sharing + euclidean norm working
This commit is contained in:
parent
9a17361cca
commit
dc85001ff7
|
@ -7,6 +7,7 @@ import java.util.PriorityQueue;
|
|||
import java.util.Random;
|
||||
import java.util.Vector;
|
||||
|
||||
//import infodynamics.measures.continuous.kraskov.EuclideanUtils;
|
||||
import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking;
|
||||
import infodynamics.utils.EmpiricalMeasurementDistribution;
|
||||
import infodynamics.utils.KdTree;
|
||||
|
@ -15,48 +16,71 @@ import infodynamics.utils.MatrixUtils;
|
|||
import infodynamics.utils.NeighbourNodeData;
|
||||
import infodynamics.utils.FirstIndexComparatorDouble;
|
||||
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
|
||||
import infodynamics.utils.EuclideanUtils;
|
||||
import infodynamics.utils.ParsedProperties;
|
||||
|
||||
/**
|
||||
* Computes the transfer entropy between a pair of spike trains,
|
||||
* using an integration-based measure in order to match the theoretical
|
||||
* form of TE between such spike trains.
|
||||
* Computes the transfer entropy between a pair of spike trains, using an
|
||||
* integration-based measure in order to match the theoretical form of TE
|
||||
* between such spike trains.
|
||||
*
|
||||
* <p>Usage paradigm is as per the interface {@link TransferEntropyCalculatorSpiking} </p>
|
||||
* <p>
|
||||
* Usage paradigm is as per the interface
|
||||
* {@link TransferEntropyCalculatorSpiking}
|
||||
* </p>
|
||||
*
|
||||
* @author Joseph Lizier (<a href="joseph.lizier at gmail.com">email</a>,
|
||||
* <a href="http://lizier.me/joseph/">www</a>)
|
||||
*/
|
||||
public class TransferEntropyCalculatorSpikingIntegration implements
|
||||
TransferEntropyCalculatorSpiking {
|
||||
public class TransferEntropyCalculatorSpikingIntegration implements TransferEntropyCalculatorSpiking {
|
||||
|
||||
/**
|
||||
* Number of past destination spikes to consider (akin to embedding length)
|
||||
* Number of past destination interspike intervals to consider (akin to embedding length)
|
||||
*/
|
||||
protected int k = 1;
|
||||
/**
|
||||
* Number of past source spikes to consider (akin to embedding length)
|
||||
* Number of past source interspike intervals to consider (akin to embedding length)
|
||||
*/
|
||||
protected int l = 1;
|
||||
|
||||
/**
|
||||
* Property name for number of interspike intervals for the conditional variables
|
||||
*/
|
||||
public static final String COND_EMBED_LENGTHS_PROP_NAME = "COND_EMBED_LENGTHS";
|
||||
/**
|
||||
* Array of history interspike interval embedding lengths for the conditional variables.
|
||||
* Can be an empty array or null if there are no conditional variables.
|
||||
*/
|
||||
protected int[] condEmbedDims = new int[] {};
|
||||
|
||||
/**
|
||||
* Number of nearest neighbours to search for in the full joint space
|
||||
*/
|
||||
protected int Knns = 4;
|
||||
|
||||
/**
|
||||
* Storage for source observations supplied via {@link #addObservations(double[], double[])} etc.
|
||||
* Storage for source observations supplied via
|
||||
* {@link #addObservations(double[], double[])} etc.
|
||||
*/
|
||||
protected Vector<double[]> vectorOfSourceSpikeTimes = null;
|
||||
|
||||
/**
|
||||
* Storage for destination observations supplied via {@link #addObservations(double[], double[])} etc.
|
||||
* Storage for destination observations supplied via
|
||||
* {@link #addObservations(double[], double[])} etc.
|
||||
*/
|
||||
protected Vector<double[]> vectorOfDestinationSpikeTimes = null;
|
||||
|
||||
Vector<double[]> targetEmbeddingsFromSpikes = null;
|
||||
/**
|
||||
* Storage for conditional observations supplied via
|
||||
* {@link #addObservations(double[], double[])} etc.
|
||||
*/
|
||||
protected Vector<double[][]> vectorOfConditionalSpikeTimes = null;
|
||||
|
||||
Vector<double[]> ConditioningEmbeddingsFromSpikes = null;
|
||||
Vector<double[]> jointEmbeddingsFromSpikes = null;
|
||||
Vector<double[]> targetEmbeddingsFromSamples = null;
|
||||
Vector<double[]> ConditioningEmbeddingsFromSamples = null;
|
||||
Vector<double[]> jointEmbeddingsFromSamples = null;
|
||||
Vector<Double> processTimeLengths = null;
|
||||
|
||||
protected KdTree kdTreeJointAtSpikes = null;
|
||||
protected KdTree kdTreeJointAtSamples = null;
|
||||
|
@ -66,10 +90,11 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
public static final String KNNS_PROP_NAME = "Knns";
|
||||
|
||||
/**
|
||||
* Property name for an amount of random Gaussian noise to be
|
||||
* added to the data (default is 1e-8, matching the MILCA toolkit).
|
||||
* Property name for an amount of random Gaussian noise to be added to the data
|
||||
* (default is 1e-8, matching the MILCA toolkit).
|
||||
*/
|
||||
public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
|
||||
|
||||
/**
|
||||
* Whether to add an amount of random noise to the incoming data
|
||||
*/
|
||||
|
@ -79,35 +104,61 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
*/
|
||||
protected double noiseLevel = (double) 1e-8;
|
||||
|
||||
protected boolean trimToPosNextSpikeTimes = false;
|
||||
|
||||
/**
|
||||
* Stores whether we are in debug mode
|
||||
*/
|
||||
protected boolean debug = false;
|
||||
|
||||
/**
|
||||
* Property name for the number of random sample points to use as a multiple
|
||||
* of the number of target spikes.
|
||||
*/
|
||||
public static final String PROP_SAMPLE_MULTIPLIER = "NUM_SAMPLES_MULTIPLIER";
|
||||
protected double num_samples_multiplier = 1.0;
|
||||
/**
|
||||
* Property name for what type of norm to use between data points
|
||||
* for each marginal variable -- Options are defined by
|
||||
* {@link KdTree#setNormType(String)} and the
|
||||
* default is {@link EuclideanUtils#NORM_EUCLIDEAN}.
|
||||
*/
|
||||
public final static String PROP_NORM_TYPE = "NORM_TYPE";
|
||||
protected int normType = EuclideanUtils.NORM_EUCLIDEAN;
|
||||
|
||||
public TransferEntropyCalculatorSpikingIntegration() {
|
||||
super();
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int)
|
||||
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see
|
||||
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
|
||||
* int)
|
||||
*/
|
||||
@Override
|
||||
public void initialise() throws Exception {
|
||||
initialise(k, l);
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int)
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see
|
||||
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
|
||||
* int)
|
||||
*/
|
||||
@Override
|
||||
public void initialise(int k) throws Exception {
|
||||
initialise(k, this.l);
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int, int)
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see
|
||||
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
|
||||
* int, int)
|
||||
*/
|
||||
@Override
|
||||
public void initialise(int k, int l) throws Exception {
|
||||
|
@ -120,40 +171,74 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
vectorOfDestinationSpikeTimes = null;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setProperty(java.lang.String, java.lang.String)
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see
|
||||
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setProperty(
|
||||
* java.lang.String, java.lang.String)
|
||||
*/
|
||||
@Override
|
||||
public void setProperty(String propertyName, String propertyValue)
|
||||
throws Exception {
|
||||
public void setProperty(String propertyName, String propertyValue) throws Exception {
|
||||
boolean propertySet = true;
|
||||
if (propertyName.equalsIgnoreCase(K_PROP_NAME)) {
|
||||
k = Integer.parseInt(propertyValue);
|
||||
int k_temp = Integer.parseInt(propertyValue);
|
||||
if (k_temp < 1) {
|
||||
throw new Exception ("Invalid k value less than 1.");
|
||||
} else {
|
||||
k = k_temp;
|
||||
}
|
||||
} else if (propertyName.equalsIgnoreCase(L_PROP_NAME)) {
|
||||
l = Integer.parseInt(propertyValue);
|
||||
int l_temp = Integer.parseInt(propertyValue);
|
||||
if (l_temp < 1) {
|
||||
throw new Exception ("Invalid l value less than 1.");
|
||||
} else {
|
||||
l = l_temp;
|
||||
}
|
||||
} else if (propertyName.equalsIgnoreCase(COND_EMBED_LENGTHS_PROP_NAME)) {
|
||||
int[] condEmbedDims_temp = ParsedProperties.parseStringArrayOfInts(propertyValue);
|
||||
for (int dim : condEmbedDims_temp) {
|
||||
if (dim < 1) {
|
||||
throw new Exception ("Invalid conditional embedding value less than 1.");
|
||||
}
|
||||
}
|
||||
condEmbedDims = condEmbedDims_temp;
|
||||
} else if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) {
|
||||
Knns = Integer.parseInt(propertyValue);
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
|
||||
if (propertyValue.equals("0") ||
|
||||
propertyValue.equalsIgnoreCase("false")) {
|
||||
if (propertyValue.equals("0") || propertyValue.equalsIgnoreCase("false")) {
|
||||
addNoise = false;
|
||||
noiseLevel = 0;
|
||||
} else {
|
||||
addNoise = true;
|
||||
noiseLevel = Double.parseDouble(propertyValue);
|
||||
}
|
||||
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
|
||||
double temp_num_samples_multiplier = Double.parseDouble(propertyValue);
|
||||
if (temp_num_samples_multiplier <= 0) {
|
||||
throw new Exception ("Num samples multiplier must be greater than 0.");
|
||||
} else {
|
||||
num_samples_multiplier = temp_num_samples_multiplier;
|
||||
}
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
|
||||
normType = KdTree.validateNormType(propertyValue);
|
||||
} else {
|
||||
// No property was set on this class
|
||||
propertySet = false;
|
||||
}
|
||||
if (debug && propertySet) {
|
||||
System.out.println(this.getClass().getSimpleName() + ": Set property " + propertyName +
|
||||
" to " + propertyValue);
|
||||
System.out.println(
|
||||
this.getClass().getSimpleName() + ": Set property " + propertyName + " to " + propertyValue);
|
||||
}
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getProperty(java.lang.String)
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see
|
||||
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getProperty(
|
||||
* java.lang.String)
|
||||
*/
|
||||
@Override
|
||||
public String getProperty(String propertyName) throws Exception {
|
||||
|
@ -165,112 +250,143 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
return Integer.toString(Knns);
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
|
||||
return Double.toString(noiseLevel);
|
||||
} else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
|
||||
return Double.toString(num_samples_multiplier);
|
||||
} else {
|
||||
// No property matches for this class
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setObservations(double[], double[])
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* setObservations(double[], double[])
|
||||
*/
|
||||
@Override
|
||||
public void setObservations(double[] source, double[] destination)
|
||||
throws Exception {
|
||||
public void setObservations(double[] source, double[] destination) throws Exception {
|
||||
startAddObservations();
|
||||
addObservations(source, destination);
|
||||
finaliseAddObservations();
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#startAddObservations()
|
||||
public void setObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
|
||||
startAddObservations();
|
||||
addObservations(source, destination, conditionals);
|
||||
finaliseAddObservations();
|
||||
}
|
||||
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* startAddObservations()
|
||||
*/
|
||||
@Override
|
||||
public void startAddObservations() {
|
||||
vectorOfSourceSpikeTimes = new Vector<double[]>();
|
||||
vectorOfDestinationSpikeTimes = new Vector<double[]>();
|
||||
vectorOfConditionalSpikeTimes = new Vector<double[][]>();
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#addObservations(double[], double[])
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* addObservations(double[], double[])
|
||||
*/
|
||||
@Override
|
||||
public void addObservations(double[] source, double[] destination)
|
||||
throws Exception {
|
||||
public void addObservations(double[] source, double[] destination) throws Exception {
|
||||
// Store these observations in our vector for now
|
||||
vectorOfSourceSpikeTimes.add(source);
|
||||
vectorOfDestinationSpikeTimes.add(destination);
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#finaliseAddObservations()
|
||||
public void addObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
|
||||
// Store these observations in our vector for now
|
||||
vectorOfSourceSpikeTimes.add(source);
|
||||
vectorOfDestinationSpikeTimes.add(destination);
|
||||
vectorOfConditionalSpikeTimes.add(conditionals);
|
||||
}
|
||||
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* finaliseAddObservations()
|
||||
*/
|
||||
@Override
|
||||
public void finaliseAddObservations() throws Exception {
|
||||
|
||||
targetEmbeddingsFromSpikes = new Vector<double[]>();
|
||||
ConditioningEmbeddingsFromSpikes = new Vector<double[]>();
|
||||
jointEmbeddingsFromSpikes = new Vector<double[]>();
|
||||
targetEmbeddingsFromSamples = new Vector<double[]>();
|
||||
ConditioningEmbeddingsFromSamples = new Vector<double[]>();
|
||||
jointEmbeddingsFromSamples = new Vector<double[]>();
|
||||
processTimeLengths = new Vector<Double>();
|
||||
|
||||
// Send all of the observations through:
|
||||
Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
|
||||
int timeSeriesIndex = 0;
|
||||
if (vectorOfConditionalSpikeTimes.size() > 0) {
|
||||
Iterator<double[][]> conditionalIterator = vectorOfConditionalSpikeTimes.iterator();
|
||||
for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
|
||||
double[] sourceSpikeTimes = sourceIterator.next();
|
||||
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes,
|
||||
targetEmbeddingsFromSpikes, jointEmbeddingsFromSpikes,
|
||||
targetEmbeddingsFromSamples, jointEmbeddingsFromSamples);
|
||||
double[][] conditionalSpikeTimes = conditionalIterator.next();
|
||||
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, ConditioningEmbeddingsFromSpikes,
|
||||
jointEmbeddingsFromSpikes, ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples,
|
||||
processTimeLengths);
|
||||
}
|
||||
} else {
|
||||
for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
|
||||
double[] sourceSpikeTimes = sourceIterator.next();
|
||||
double[][] conditionalSpikeTimes = new double[][] {};
|
||||
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, ConditioningEmbeddingsFromSpikes,
|
||||
jointEmbeddingsFromSpikes, ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples,
|
||||
processTimeLengths);
|
||||
}
|
||||
}
|
||||
|
||||
// Convert the vectors to arrays so that they can be put in the trees
|
||||
double[][] arrayedTargetEmbeddingsFromSpikes = new double[targetEmbeddingsFromSpikes.size()][k];
|
||||
double[][] arrayedJointEmbeddingsFromSpikes = new double[targetEmbeddingsFromSpikes.size()][k + l];
|
||||
for (int i = 0; i < targetEmbeddingsFromSpikes.size(); i++) {
|
||||
arrayedTargetEmbeddingsFromSpikes[i] = targetEmbeddingsFromSpikes.elementAt(i);
|
||||
double[][] arrayedTargetEmbeddingsFromSpikes = new double[ConditioningEmbeddingsFromSpikes.size()][k];
|
||||
double[][] arrayedJointEmbeddingsFromSpikes = new double[ConditioningEmbeddingsFromSpikes.size()][k + l];
|
||||
for (int i = 0; i < ConditioningEmbeddingsFromSpikes.size(); i++) {
|
||||
arrayedTargetEmbeddingsFromSpikes[i] = ConditioningEmbeddingsFromSpikes.elementAt(i);
|
||||
arrayedJointEmbeddingsFromSpikes[i] = jointEmbeddingsFromSpikes.elementAt(i);
|
||||
}
|
||||
double[][] arrayedTargetEmbeddingsFromSamples = new double[targetEmbeddingsFromSamples.size()][k];
|
||||
double[][] arrayedJointEmbeddingsFromSamples = new double[targetEmbeddingsFromSamples.size()][k + l];
|
||||
for (int i = 0; i < targetEmbeddingsFromSamples.size(); i++) {
|
||||
arrayedTargetEmbeddingsFromSamples[i] = targetEmbeddingsFromSamples.elementAt(i);
|
||||
double[][] arrayedTargetEmbeddingsFromSamples = new double[ConditioningEmbeddingsFromSamples.size()][k];
|
||||
double[][] arrayedJointEmbeddingsFromSamples = new double[ConditioningEmbeddingsFromSamples.size()][k + l];
|
||||
for (int i = 0; i < ConditioningEmbeddingsFromSamples.size(); i++) {
|
||||
arrayedTargetEmbeddingsFromSamples[i] = ConditioningEmbeddingsFromSamples.elementAt(i);
|
||||
arrayedJointEmbeddingsFromSamples[i] = jointEmbeddingsFromSamples.elementAt(i);
|
||||
}
|
||||
|
||||
kdTreeJointAtSpikes = new KdTree(
|
||||
new int[] {k + l},
|
||||
new double[][][] {arrayedJointEmbeddingsFromSpikes});
|
||||
kdTreeJointAtSamples = new KdTree(
|
||||
new int[] {k + l},
|
||||
new double[][][] {arrayedJointEmbeddingsFromSamples});
|
||||
kdTreeConditioningAtSpikes = new KdTree(
|
||||
new int[] {k},
|
||||
new double[][][] {arrayedTargetEmbeddingsFromSpikes});
|
||||
kdTreeConditioningAtSamples = new KdTree(
|
||||
new int[] {k},
|
||||
new double[][][] {arrayedTargetEmbeddingsFromSamples});
|
||||
kdTreeJointAtSpikes = new KdTree(arrayedJointEmbeddingsFromSpikes);
|
||||
kdTreeJointAtSamples = new KdTree(arrayedJointEmbeddingsFromSamples);
|
||||
kdTreeConditioningAtSpikes = new KdTree(arrayedTargetEmbeddingsFromSpikes);
|
||||
kdTreeConditioningAtSamples = new KdTree(arrayedTargetEmbeddingsFromSamples);
|
||||
|
||||
/*kdTreeJointAtSpikes.setNormType("EUCLIDEAN");
|
||||
kdTreeJointAtSamples.setNormType("EUCLIDEAN");
|
||||
kdTreeConditioningAtSpikes.setNormType("EUCLIDEAN");
|
||||
kdTreeConditioningAtSamples.setNormType("EUCLIDEAN");*/
|
||||
kdTreeJointAtSpikes.setNormType(normType);
|
||||
kdTreeJointAtSamples.setNormType(normType);
|
||||
kdTreeConditioningAtSpikes.setNormType(normType);
|
||||
kdTreeConditioningAtSamples.setNormType(normType);
|
||||
}
|
||||
|
||||
protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, double[] sourceSpikeTimes, double[] destSpikeTimes,
|
||||
Vector<double[]> targetEmbeddings, Vector<double[]> jointEmbeddings) {
|
||||
//System.out.println("foo");
|
||||
protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, int index_of_first_point_to_use,
|
||||
double[] sourceSpikeTimes, double[] destSpikeTimes,
|
||||
double[][] conditionalSpikeTimes,
|
||||
Vector<double[]> ConditioningEmbeddings,
|
||||
Vector<double[]> jointEmbeddings) {
|
||||
|
||||
Random random = new Random();
|
||||
|
||||
int embedding_point_index = 0;
|
||||
int embedding_point_index = index_of_first_point_to_use;
|
||||
int most_recent_dest_index = k;
|
||||
int most_recent_source_index = l;
|
||||
|
||||
// Make sure that the first point at which an embedding is made has enough preceding spikes in both source and
|
||||
// target for embeddings to be made.
|
||||
while (pointsAtWhichToMakeEmbeddings[embedding_point_index] <= destSpikeTimes[most_recent_dest_index] |
|
||||
pointsAtWhichToMakeEmbeddings[embedding_point_index] <= sourceSpikeTimes[most_recent_source_index]) {
|
||||
embedding_point_index++;
|
||||
int[] most_recent_conditioning_indices = Arrays.copyOf(condEmbedDims, condEmbedDims.length);
|
||||
int total_length_of_conditioning_embeddings = 0;
|
||||
for (int i = 0; i < condEmbedDims.length; i++) {
|
||||
total_length_of_conditioning_embeddings += condEmbedDims[i];
|
||||
}
|
||||
|
||||
// Loop through the points at which embeddings need to be made
|
||||
|
@ -286,162 +402,327 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
}
|
||||
// Do the same for the most recent source index
|
||||
while (most_recent_source_index < (sourceSpikeTimes.length - 1)) {
|
||||
if (sourceSpikeTimes[most_recent_source_index + 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) {
|
||||
if (sourceSpikeTimes[most_recent_source_index
|
||||
+ 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) {
|
||||
most_recent_source_index++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Now advance the trackers for the most recent conditioning indices
|
||||
for (int j = 0; j < most_recent_conditioning_indices.length; j++) {
|
||||
while (most_recent_conditioning_indices[j] < (conditionalSpikeTimes[j].length - 1)) {
|
||||
if (conditionalSpikeTimes[j][most_recent_conditioning_indices[j] + 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) {
|
||||
most_recent_conditioning_indices[j]++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
double[] destPast = new double[k];
|
||||
double[] jointPast = new double[k + l];
|
||||
destPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] -
|
||||
destSpikeTimes[most_recent_dest_index];
|
||||
jointPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] -
|
||||
destSpikeTimes[most_recent_dest_index];
|
||||
jointPast[k] = pointsAtWhichToMakeEmbeddings[embedding_point_index] -
|
||||
sourceSpikeTimes[most_recent_source_index];
|
||||
|
||||
double[] conditioningPast = new double[k + total_length_of_conditioning_embeddings];
|
||||
double[] jointPast = new double[k + total_length_of_conditioning_embeddings + l];
|
||||
|
||||
// Add the embedding intervals from the target process
|
||||
conditioningPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] - destSpikeTimes[most_recent_dest_index];
|
||||
jointPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index]
|
||||
- destSpikeTimes[most_recent_dest_index];
|
||||
for (int i = 1; i < k; i++) {
|
||||
destPast[i] = destSpikeTimes[most_recent_dest_index - i + 1] -
|
||||
destSpikeTimes[most_recent_dest_index - i];
|
||||
jointPast[i] = destSpikeTimes[most_recent_dest_index - i + 1] -
|
||||
destSpikeTimes[most_recent_dest_index - i];
|
||||
}
|
||||
for (int i = 1; i < l; i++) {
|
||||
jointPast[k + i] = sourceSpikeTimes[most_recent_source_index - i + 1] -
|
||||
sourceSpikeTimes[most_recent_source_index - i];
|
||||
conditioningPast[i] = destSpikeTimes[most_recent_dest_index - i + 1]
|
||||
- destSpikeTimes[most_recent_dest_index - i];
|
||||
jointPast[i] = destSpikeTimes[most_recent_dest_index - i + 1]
|
||||
- destSpikeTimes[most_recent_dest_index - i];
|
||||
}
|
||||
|
||||
if (addNoise) {
|
||||
for (int i = 0; i < k; i++) {
|
||||
destPast[i] += random.nextGaussian()*noiseLevel;
|
||||
// Add the embeding intervals from the conditional processes
|
||||
int index_of_next_embedding_interval = k;
|
||||
for (int i = 0; i < condEmbedDims.length; i++) {
|
||||
conditioningPast[index_of_next_embedding_interval] =
|
||||
pointsAtWhichToMakeEmbeddings[embedding_point_index] - conditionalSpikeTimes[i][most_recent_conditioning_indices[i]];
|
||||
jointPast[index_of_next_embedding_interval] =
|
||||
pointsAtWhichToMakeEmbeddings[embedding_point_index] - conditionalSpikeTimes[i][most_recent_conditioning_indices[i]];
|
||||
index_of_next_embedding_interval += 1;
|
||||
for (int j = 1; j < condEmbedDims[i]; j++) {
|
||||
conditioningPast[index_of_next_embedding_interval] =
|
||||
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j + 1] -
|
||||
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j];
|
||||
jointPast[index_of_next_embedding_interval] =
|
||||
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j + 1] -
|
||||
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j];
|
||||
index_of_next_embedding_interval += 1;
|
||||
}
|
||||
for (int i = 0; i < l; i++) {
|
||||
}
|
||||
|
||||
// Add the embedding intervals from the source process (this only gets added to the joint embeddings)
|
||||
jointPast[k + total_length_of_conditioning_embeddings] = pointsAtWhichToMakeEmbeddings[embedding_point_index]
|
||||
- sourceSpikeTimes[most_recent_source_index];
|
||||
for (int i = 1; i < l; i++) {
|
||||
jointPast[k + total_length_of_conditioning_embeddings + i] = sourceSpikeTimes[most_recent_source_index - i + 1]
|
||||
- sourceSpikeTimes[most_recent_source_index - i];
|
||||
}
|
||||
|
||||
// Add Gaussian noise, if necessary
|
||||
if (addNoise) {
|
||||
for (int i = 0; i < conditioningPast.length; i++) {
|
||||
conditioningPast[i] += random.nextGaussian() * noiseLevel;
|
||||
}
|
||||
for (int i = 0; i < jointPast.length; i++) {
|
||||
jointPast[i] += random.nextGaussian() * noiseLevel;
|
||||
}
|
||||
}
|
||||
|
||||
targetEmbeddings.add(destPast);
|
||||
ConditioningEmbeddings.add(conditioningPast);
|
||||
jointEmbeddings.add(jointPast);
|
||||
}
|
||||
}
|
||||
|
||||
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes,
|
||||
Vector<double[]> targetEmbeddingsFromSpikes, Vector<double[]> jointEmbeddingsFromSpikes,
|
||||
Vector<double[]> targetEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples)
|
||||
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes,
|
||||
Vector<double[]> ConditioningEmbeddingsFromSpikes, Vector<double[]> jointEmbeddingsFromSpikes,
|
||||
Vector<double[]> ConditioningEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples,
|
||||
Vector<Double> processTimeLengths)
|
||||
throws Exception {
|
||||
// addObservationsAfterParamsDetermined(sourceSpikeTimes, destSpikeTimes);
|
||||
|
||||
// First sort the spike times in case they were not properly in ascending order:
|
||||
Arrays.sort(sourceSpikeTimes);
|
||||
Arrays.sort(destSpikeTimes);
|
||||
|
||||
double sample_lower_bound = Arrays.stream(sourceSpikeTimes).min().getAsDouble();
|
||||
double sample_upper_bound = Arrays.stream(sourceSpikeTimes).max().getAsDouble();
|
||||
double[] randomSampleTimes = new double[sourceSpikeTimes.length];
|
||||
int first_target_index_of_embedding = k;
|
||||
while (destSpikeTimes[first_target_index_of_embedding] <= sourceSpikeTimes[l - 1]) {
|
||||
first_target_index_of_embedding++;
|
||||
}
|
||||
if (conditionalSpikeTimes.length != condEmbedDims.length) {
|
||||
throw new Exception("Number of conditional embedding lengths does not match the number of conditional processes");
|
||||
}
|
||||
for (int i = 0; i < conditionalSpikeTimes.length; i++) {
|
||||
while (destSpikeTimes[first_target_index_of_embedding] <= conditionalSpikeTimes[i][condEmbedDims[i]]) {
|
||||
first_target_index_of_embedding++;
|
||||
}
|
||||
}
|
||||
|
||||
//processTimeLengths.add(destSpikeTimes[sourceSpikeTimes.length - 1] - destSpikeTimes[first_target_index_of_embedding]);
|
||||
processTimeLengths.add(destSpikeTimes[destSpikeTimes.length - 1] - destSpikeTimes[first_target_index_of_embedding]);
|
||||
|
||||
double sample_lower_bound = destSpikeTimes[first_target_index_of_embedding];
|
||||
double sample_upper_bound = destSpikeTimes[destSpikeTimes.length - 1];
|
||||
int num_samples = (int) Math.round(num_samples_multiplier * (destSpikeTimes.length - first_target_index_of_embedding + 1));
|
||||
double[] randomSampleTimes = new double[num_samples];
|
||||
Random rand = new Random();
|
||||
for (int i = 0; i < randomSampleTimes.length; i++) {
|
||||
randomSampleTimes[i] = sample_lower_bound + rand.nextDouble() * (sample_upper_bound - sample_lower_bound);
|
||||
}
|
||||
Arrays.sort(randomSampleTimes);
|
||||
|
||||
makeEmbeddingsAtPoints(destSpikeTimes, sourceSpikeTimes, destSpikeTimes, targetEmbeddingsFromSpikes, jointEmbeddingsFromSpikes);
|
||||
makeEmbeddingsAtPoints(randomSampleTimes, sourceSpikeTimes, destSpikeTimes, targetEmbeddingsFromSamples, jointEmbeddingsFromSamples);
|
||||
makeEmbeddingsAtPoints(destSpikeTimes, first_target_index_of_embedding, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
|
||||
ConditioningEmbeddingsFromSpikes, jointEmbeddingsFromSpikes);
|
||||
makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
|
||||
ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples);
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getAddedMoreThanOneObservationSet()
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* getAddedMoreThanOneObservationSet()
|
||||
*/
|
||||
@Override
|
||||
public boolean getAddedMoreThanOneObservationSet() {
|
||||
return (vectorOfDestinationSpikeTimes != null) &&
|
||||
(vectorOfDestinationSpikeTimes.size() > 1);
|
||||
return (vectorOfDestinationSpikeTimes != null) && (vectorOfDestinationSpikeTimes.size() > 1);
|
||||
}
|
||||
|
||||
private double max_neighbour_distance(PriorityQueue<NeighbourNodeData> nnPQ) {
|
||||
double max_val = -1e9;
|
||||
while (nnPQ.peek() != null) {
|
||||
NeighbourNodeData nnData = nnPQ.poll();
|
||||
if (nnData.norms[0] > max_val) {
|
||||
max_val = nnData.norms[0];
|
||||
// Class to allow returning two values in the subsequent method
|
||||
private static class distanceAndNumPoints {
|
||||
public double distance;
|
||||
public int numPoints;
|
||||
|
||||
public distanceAndNumPoints(double distance, int numPoints) {
|
||||
this.distance = distance;
|
||||
this.numPoints = numPoints;
|
||||
}
|
||||
}
|
||||
return max_val;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeAverageLocalOfObservations()
|
||||
private distanceAndNumPoints findMaxDistanceAndNumPointsFromIndices(double[] point, int[] indices, Vector<double[]> setOfPoints) {
|
||||
double maxDistance = 0;
|
||||
int i = 0;
|
||||
for (; indices[i] != -1; i++) {
|
||||
double distance = KdTree.norm(point, setOfPoints.elementAt(indices[i]), normType);
|
||||
if (distance > maxDistance) {
|
||||
maxDistance = distance;
|
||||
}
|
||||
}
|
||||
return new distanceAndNumPoints(maxDistance, i);
|
||||
}
|
||||
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* computeAverageLocalOfObservations()
|
||||
*/
|
||||
@Override
|
||||
public double computeAverageLocalOfObservations() throws Exception {
|
||||
|
||||
double currentSum = 0;
|
||||
for (int i = 0; i < targetEmbeddingsFromSpikes.size(); i++) {
|
||||
for (int i = 0; i < ConditioningEmbeddingsFromSpikes.size(); i++) {
|
||||
|
||||
PriorityQueue<NeighbourNodeData> nnPQJointSpikes =
|
||||
kdTreeJointAtSpikes.findKNearestNeighbours(Knns + 1, new double[][] {jointEmbeddingsFromSpikes.elementAt(i)});
|
||||
PriorityQueue<NeighbourNodeData> nnPQJointSamples =
|
||||
kdTreeJointAtSamples.findKNearestNeighbours(Knns, new double[][] {jointEmbeddingsFromSpikes.elementAt(i)});
|
||||
PriorityQueue<NeighbourNodeData> nnPQConditioningSpikes =
|
||||
kdTreeConditioningAtSpikes.findKNearestNeighbours(Knns + 1, new double[][] {targetEmbeddingsFromSpikes.elementAt(i)});
|
||||
PriorityQueue<NeighbourNodeData> nnPQConditioningSamples =
|
||||
kdTreeConditioningAtSamples.findKNearestNeighbours(Knns, new double[][] {targetEmbeddingsFromSpikes.elementAt(i)});
|
||||
double radiusJointSpikes = kdTreeJointAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
|
||||
double radiusJointSamples = kdTreeJointAtSamples.findKNearestNeighbours(Knns,
|
||||
new double[][] { jointEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];
|
||||
|
||||
double radiusJointSpikes = max_neighbour_distance(nnPQJointSpikes);
|
||||
double radiusJointSamples = max_neighbour_distance(nnPQJointSamples);
|
||||
double radiusConditioningSpikes = max_neighbour_distance(nnPQConditioningSpikes);
|
||||
double radiusConditioningSamples = max_neighbour_distance(nnPQConditioningSamples);
|
||||
/*
|
||||
The algorithm specified in box 1 of doi.org/10.1371/journal.pcbi.1008054 specifies finding the maximum of the two radii
|
||||
just calculated and then redoing the searches in both sets at this radius. In this implementation, however, we make use
|
||||
of the fact that one radius is equal to the maximum, and so only one search needs to be redone.
|
||||
*/
|
||||
double eps = 0.01;
|
||||
// Need variables for the number of neighbours as this is now variable within the maximum radius
|
||||
int kJointSpikes = 0;
|
||||
int kJointSamples = 0;
|
||||
if (radiusJointSpikes >= radiusJointSamples) {
|
||||
/*
|
||||
The maximum was the radius in the set of embeddings at spikes, so redo search in the set of embeddings at randomly
|
||||
sampled points, using this larger radius.
|
||||
*/
|
||||
kJointSpikes = Knns;
|
||||
int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()];
|
||||
boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()];
|
||||
kdTreeJointAtSamples.findPointsWithinR(radiusJointSpikes + eps,
|
||||
new double[][] { jointEmbeddingsFromSpikes.elementAt(i) },
|
||||
true,
|
||||
isWithinR,
|
||||
indicesWithinR);
|
||||
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(jointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
|
||||
jointEmbeddingsFromSamples);
|
||||
kJointSamples = temp.numPoints;
|
||||
radiusJointSamples = temp.distance;
|
||||
} else {
|
||||
/*
|
||||
The maximum was the radius in the set of embeddings at randomly sampled points, so redo search in the set of embeddings
|
||||
at spikes, using this larger radius.
|
||||
*/
|
||||
kJointSamples = Knns;
|
||||
int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()];
|
||||
boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()];
|
||||
kdTreeJointAtSpikes.findPointsWithinR(radiusJointSamples + eps,
|
||||
new double[][] { jointEmbeddingsFromSpikes.elementAt(i) },
|
||||
true,
|
||||
isWithinR,
|
||||
indicesWithinR);
|
||||
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(jointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
|
||||
jointEmbeddingsFromSpikes);
|
||||
// -1 due to the point itself being in the set
|
||||
kJointSpikes = temp.numPoints - 1;
|
||||
radiusJointSpikes = temp.distance;
|
||||
}
|
||||
|
||||
currentSum += ((k + l) * (- Math.log(radiusJointSpikes) + Math.log(radiusJointSamples))
|
||||
+ k * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples)));
|
||||
// Repeat the above steps, but in the conditioning (rather than joint) space.
|
||||
double radiusConditioningSpikes = kdTreeConditioningAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
|
||||
double radiusConditioningSamples = kdTreeConditioningAtSamples.findKNearestNeighbours(Knns,
|
||||
new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];
|
||||
int kConditioningSpikes = 0;
|
||||
int kConditioningSamples = 0;
|
||||
if (radiusConditioningSpikes >= radiusConditioningSamples) {
|
||||
kConditioningSpikes = Knns;
|
||||
int[] indicesWithinR = new int[ConditioningEmbeddingsFromSamples.size()];
|
||||
boolean[] isWithinR = new boolean[ConditioningEmbeddingsFromSamples.size()];
|
||||
kdTreeConditioningAtSamples.findPointsWithinR(radiusConditioningSpikes + eps,
|
||||
new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) },
|
||||
true,
|
||||
isWithinR,
|
||||
indicesWithinR);
|
||||
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(ConditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
|
||||
ConditioningEmbeddingsFromSamples);
|
||||
kConditioningSamples = temp.numPoints;
|
||||
radiusConditioningSamples = temp.distance;
|
||||
} else {
|
||||
kConditioningSamples = Knns;
|
||||
int[] indicesWithinR = new int[ConditioningEmbeddingsFromSamples.size()];
|
||||
boolean[] isWithinR = new boolean[ConditioningEmbeddingsFromSamples.size()];
|
||||
kdTreeConditioningAtSpikes.findPointsWithinR(radiusConditioningSamples + eps,
|
||||
new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) },
|
||||
true,
|
||||
isWithinR,
|
||||
indicesWithinR);
|
||||
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(ConditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
|
||||
ConditioningEmbeddingsFromSpikes);
|
||||
// -1 due to the point itself being in the set
|
||||
kConditioningSpikes = temp.numPoints - 1;
|
||||
radiusConditioningSpikes = temp.distance;
|
||||
}
|
||||
|
||||
|
||||
|
||||
currentSum += (MathsUtils.digamma(kJointSpikes) - MathsUtils.digamma(kJointSamples) +
|
||||
((k + l) * (-Math.log(radiusJointSpikes) + Math.log(radiusJointSamples))) -
|
||||
MathsUtils.digamma(kConditioningSpikes) + MathsUtils.digamma(kConditioningSamples) +
|
||||
+ (k * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples))));
|
||||
if (Double.isNaN(currentSum)) {
|
||||
throw new Exception(kJointSpikes + " " + kJointSamples + " " + kConditioningSpikes + " " + kConditioningSamples + "\n" +
|
||||
radiusJointSpikes + " " + radiusJointSamples + " " + radiusConditioningSpikes + " " + radiusConditioningSamples);
|
||||
}
|
||||
}
|
||||
// Normalise by time
|
||||
currentSum /= (vectorOfDestinationSpikeTimes.elementAt(0)[vectorOfDestinationSpikeTimes.elementAt(0).length - 1]
|
||||
- vectorOfDestinationSpikeTimes.elementAt(0)[0]);
|
||||
double time_sum = 0;
|
||||
for (Double time : processTimeLengths) {
|
||||
time_sum += time;
|
||||
}
|
||||
currentSum /= time_sum;
|
||||
return currentSum;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeLocalOfPreviousObservations()
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* computeLocalOfPreviousObservations()
|
||||
*/
|
||||
@Override
|
||||
public SpikingLocalInformationValues computeLocalOfPreviousObservations()
|
||||
throws Exception {
|
||||
public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception {
|
||||
// TODO Auto-generated method stub
|
||||
return null;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeSignificance(int)
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* computeSignificance(int)
|
||||
*/
|
||||
@Override
|
||||
public EmpiricalMeasurementDistribution computeSignificance(
|
||||
int numPermutationsToCheck) throws Exception {
|
||||
public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck) throws Exception {
|
||||
// TODO Auto-generated method stub
|
||||
return null;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeSignificance(int[][])
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||
* computeSignificance(int[][])
|
||||
*/
|
||||
@Override
|
||||
public EmpiricalMeasurementDistribution computeSignificance(
|
||||
int[][] newOrderings) throws Exception {
|
||||
public EmpiricalMeasurementDistribution computeSignificance(int[][] newOrderings) throws Exception {
|
||||
// TODO Auto-generated method stub
|
||||
return null;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setDebug(boolean)
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setDebug(
|
||||
* boolean)
|
||||
*/
|
||||
@Override
|
||||
public void setDebug(boolean debug) {
|
||||
this.debug = debug;
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getLastAverage()
|
||||
/*
|
||||
* (non-Javadoc)
|
||||
*
|
||||
* @see
|
||||
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getLastAverage
|
||||
* ()
|
||||
*/
|
||||
@Override
|
||||
public double getLastAverage() {
|
||||
|
|
|
@ -337,7 +337,7 @@ public class KdTree extends NearestNeighbourSearcher {
|
|||
double difference = x1[d] - x2[d];
|
||||
distance += difference * difference;
|
||||
}
|
||||
return distance;
|
||||
return Math.sqrt(distance);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -359,6 +359,7 @@ public class KdTree extends NearestNeighbourSearcher {
|
|||
*/
|
||||
public final static double normWithAbort(double[] x1, double[] x2,
|
||||
double limit, int normToUse) {
|
||||
|
||||
double distance = 0.0;
|
||||
switch (normToUse) {
|
||||
case EuclideanUtils.NORM_MAX_NORM:
|
||||
|
@ -379,6 +380,8 @@ public class KdTree extends NearestNeighbourSearcher {
|
|||
return distance;
|
||||
// case EuclideanUtils.NORM_EUCLIDEAN_SQUARED:
|
||||
default:
|
||||
// Limit is often r, so must square
|
||||
limit = limit * limit;
|
||||
// Inlined from {@link EuclideanUtils}:
|
||||
for (int d = 0; d < x1.length; d++) {
|
||||
double difference = x1[d] - x2[d];
|
||||
|
@ -387,7 +390,7 @@ public class KdTree extends NearestNeighbourSearcher {
|
|||
return Double.POSITIVE_INFINITY;
|
||||
}
|
||||
}
|
||||
return distance;
|
||||
return Math.sqrt(distance);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1305,9 +1308,10 @@ public class KdTree extends NearestNeighbourSearcher {
|
|||
if (normTypeToUse == EuclideanUtils.NORM_MAX_NORM) {
|
||||
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
|
||||
} else {
|
||||
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
|
||||
// norm type is EuclideanUtils#NORM_EUCLIDEAN_SQUARED
|
||||
// Track the square distance
|
||||
absDistOnThisDim = distOnThisDim * distOnThisDim;
|
||||
//absDistOnThisDim = distOnThisDim;
|
||||
}
|
||||
|
||||
if ((node.indexOfThisPoint != sampleIndex) &&
|
||||
|
@ -2053,7 +2057,7 @@ public class KdTree extends NearestNeighbourSearcher {
|
|||
KdTreeNode node, int level, double r, boolean allowEqualToR,
|
||||
boolean[] isWithinR, int[] indicesWithinR,
|
||||
int nextIndexInIndicesWithinR) {
|
||||
|
||||
//System.out.println("foo");
|
||||
// Point to the correct array for the data at this level
|
||||
int currentDim = level % totalDimensions;
|
||||
double[][] data = dimensionToArray[currentDim];
|
||||
|
@ -2068,9 +2072,9 @@ public class KdTree extends NearestNeighbourSearcher {
|
|||
if (normTypeToUse == EuclideanUtils.NORM_MAX_NORM) {
|
||||
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
|
||||
} else {
|
||||
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
|
||||
// norm type is EuclideanUtils#NORM_EUCLIDEAN_SQUARED
|
||||
// Track the square distance
|
||||
absDistOnThisDim = distOnThisDim * distOnThisDim;
|
||||
}
|
||||
|
||||
if ((absDistOnThisDim < r) ||
|
||||
|
|
84
tester.py
84
tester.py
|
@ -27,15 +27,14 @@ import os
|
|||
import numpy as np
|
||||
|
||||
|
||||
NUM_REPS = 20
|
||||
NUM_SPIKES = int(1e4)
|
||||
NUM_REPS = 5
|
||||
NUM_SPIKES = int(5e3)
|
||||
NUM_OBSERVATIONS = 2
|
||||
|
||||
# Params for canonical example generation
|
||||
RATE_Y = 1.0
|
||||
RATE_X_MAX = 10
|
||||
|
||||
|
||||
|
||||
def generate_canonical_example_processes(num_y_events):
|
||||
event_train_x = []
|
||||
event_train_x.append(0)
|
||||
|
@ -76,30 +75,97 @@ if (not(os.path.isfile(jarLocation))):
|
|||
# Start the JVM (add the "-Xmx" option with say 1024M if you get crashes due to not enough memory space)
|
||||
startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
|
||||
teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntropyCalculatorSpikingIntegration
|
||||
|
||||
|
||||
teCalc = teCalcClass()
|
||||
teCalc.setProperty("knns", "4")
|
||||
|
||||
print("Independent Poisson Processes")
|
||||
teCalc.setProperty("k_HISTORY", "1")
|
||||
teCalc.setProperty("l_HISTORY", "1")
|
||||
teCalc.setProperty("COND_EMBED_LENGTHS", "2,2")
|
||||
teCalc.setProperty("k_HISTORY", "2")
|
||||
teCalc.setProperty("l_HISTORY", "2")
|
||||
teCalc.setProperty("NORM_TYPE", "MAX_NORM")
|
||||
|
||||
results_poisson = np.zeros(NUM_REPS)
|
||||
for i in range(NUM_REPS):
|
||||
teCalc.startAddObservations()
|
||||
for j in range(NUM_OBSERVATIONS):
|
||||
sourceArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
|
||||
sourceArray.sort()
|
||||
destArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
|
||||
destArray.sort()
|
||||
|
||||
teCalc.setObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray))
|
||||
condArray = NUM_SPIKES*np.random.random((2, NUM_SPIKES))
|
||||
condArray.sort(axis = 1)
|
||||
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
|
||||
teCalc.finaliseAddObservations();
|
||||
result = teCalc.computeAverageLocalOfObservations()
|
||||
print("TE result %.4f nats" % (result,))
|
||||
results_poisson[i] = result
|
||||
print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
teCalc = teCalcClass()
|
||||
teCalc.setProperty("knns", "4")
|
||||
print("Noisy copy zero TE")
|
||||
teCalc.setProperty("COND_EMBED_LENGTHS", "2")
|
||||
teCalc.setProperty("k_HISTORY", "2")
|
||||
teCalc.setProperty("l_HISTORY", "2")
|
||||
|
||||
results_noisy_zero = np.zeros(NUM_REPS)
|
||||
for i in range(NUM_REPS):
|
||||
teCalc.startAddObservations()
|
||||
for j in range(NUM_OBSERVATIONS):
|
||||
condArray = NUM_SPIKES*np.random.random((1, NUM_SPIKES))
|
||||
condArray.sort(axis = 1)
|
||||
sourceArray = condArray[0, :] + 0.25 + 0.1 * np.random.normal(size = condArray.shape[1])
|
||||
sourceArray.sort()
|
||||
destArray = condArray[0, :] + 0.5 + 0.1 * np.random.normal(size = condArray.shape[1])
|
||||
destArray.sort()
|
||||
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
|
||||
teCalc.finaliseAddObservations();
|
||||
result = teCalc.computeAverageLocalOfObservations()
|
||||
print("TE result %.4f nats" % (result,))
|
||||
results_poisson[i] = result
|
||||
print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
|
||||
|
||||
|
||||
|
||||
teCalc = teCalcClass()
|
||||
teCalc.setProperty("knns", "4")
|
||||
print("Noisy copy non-zero TE")
|
||||
teCalc.setProperty("COND_EMBED_LENGTHS", "2")
|
||||
teCalc.setProperty("k_HISTORY", "2")
|
||||
teCalc.setProperty("l_HISTORY", "2")
|
||||
|
||||
results_noisy_zero = np.zeros(NUM_REPS)
|
||||
for i in range(NUM_REPS):
|
||||
teCalc.startAddObservations()
|
||||
for j in range(NUM_OBSERVATIONS):
|
||||
sourceArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
|
||||
sourceArray.sort()
|
||||
condArray = sourceArray + 0.25 + 0.1 * np.random.normal(size = sourceArray.shape)
|
||||
condArray.sort()
|
||||
condArray = np.expand_dims(condArray, 0)
|
||||
destArray = sourceArray + 0.5 + 0.1 * np.random.normal(size = sourceArray.shape)
|
||||
destArray.sort()
|
||||
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
|
||||
teCalc.finaliseAddObservations();
|
||||
result = teCalc.computeAverageLocalOfObservations()
|
||||
print("TE result %.4f nats" % (result,))
|
||||
results_poisson[i] = result
|
||||
print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
|
||||
|
||||
print("Canonical example")
|
||||
teCalc = teCalcClass()
|
||||
teCalc.setProperty("knns", "4")
|
||||
teCalc.setProperty("k_HISTORY", "2")
|
||||
teCalc.setProperty("l_HISTORY", "1")
|
||||
#teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", "1")
|
||||
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
|
||||
|
||||
results_canonical = np.zeros(NUM_REPS)
|
||||
for i in range(NUM_REPS):
|
||||
|
|
Loading…
Reference in New Issue