refactoring + conditional processes + radius sharing + euclidean norm working

This commit is contained in:
David Shorten 2021-08-02 22:06:07 +10:00
parent 9a17361cca
commit dc85001ff7
3 changed files with 564 additions and 213 deletions

View File: infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java

@ -7,6 +7,7 @@ import java.util.PriorityQueue;
import java.util.Random;
import java.util.Vector;
//import infodynamics.measures.continuous.kraskov.EuclideanUtils;
import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking;
import infodynamics.utils.EmpiricalMeasurementDistribution;
import infodynamics.utils.KdTree;
@ -15,48 +16,71 @@ import infodynamics.utils.MatrixUtils;
import infodynamics.utils.NeighbourNodeData;
import infodynamics.utils.FirstIndexComparatorDouble;
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
import infodynamics.utils.EuclideanUtils;
import infodynamics.utils.ParsedProperties;
/**
* Computes the transfer entropy between a pair of spike trains,
* using an integration-based measure in order to match the theoretical
* form of TE between such spike trains.
* Computes the transfer entropy between a pair of spike trains, using an
* integration-based measure in order to match the theoretical form of TE
* between such spike trains.
*
* <p>Usage paradigm is as per the interface {@link TransferEntropyCalculatorSpiking} </p>
* <p>
* Usage paradigm is as per the interface
* {@link TransferEntropyCalculatorSpiking}
* </p>
*
* @author Joseph Lizier (<a href="joseph.lizier at gmail.com">email</a>,
* <a href="http://lizier.me/joseph/">www</a>)
*/
public class TransferEntropyCalculatorSpikingIntegration implements
TransferEntropyCalculatorSpiking {
public class TransferEntropyCalculatorSpikingIntegration implements TransferEntropyCalculatorSpiking {
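For orientation, a minimal sketch of the usage paradigm with the conditional-process support added in this commit. The property names are the ones exercised by the Python demo later in this diff; the wrapper class and the toy uniform spike data are illustrative assumptions only, not part of the toolkit:

import java.util.Arrays;
import java.util.Random;
import infodynamics.measures.spiking.integration.TransferEntropyCalculatorSpikingIntegration;

public class SpikingTeUsageSketch {
    public static void main(String[] args) throws Exception {
        // Toy data: uniformly random (then sorted) spike times for source,
        // target and two conditional processes, as in the Poisson demo below.
        Random rng = new Random(0);
        int n = 1000;
        double[] src = new double[n], dst = new double[n];
        double[][] cond = new double[2][n];
        for (int i = 0; i < n; i++) {
            src[i] = n * rng.nextDouble();
            dst[i] = n * rng.nextDouble();
            cond[0][i] = n * rng.nextDouble();
            cond[1][i] = n * rng.nextDouble();
        }
        Arrays.sort(src);
        Arrays.sort(dst);
        Arrays.sort(cond[0]);
        Arrays.sort(cond[1]);

        TransferEntropyCalculatorSpikingIntegration teCalc = new TransferEntropyCalculatorSpikingIntegration();
        teCalc.setProperty("k_HISTORY", "2");
        teCalc.setProperty("l_HISTORY", "2");
        teCalc.setProperty("COND_EMBED_LENGTHS", "2,2");
        teCalc.setProperty("Knns", "4");
        teCalc.startAddObservations();
        teCalc.addObservations(src, dst, cond);
        teCalc.finaliseAddObservations();
        System.out.printf("TE = %.4f nats%n", teCalc.computeAverageLocalOfObservations());
    }
}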
/**
* Number of past destination spikes to consider (akin to embedding length)
* Number of past destination interspike intervals to consider (akin to embedding length)
*/
protected int k = 1;
/**
* Number of past source spikes to consider (akin to embedding length)
* Number of past source interspike intervals to consider (akin to embedding length)
*/
protected int l = 1;
/**
* Property name for number of interspike intervals for the conditional variables
*/
public static final String COND_EMBED_LENGTHS_PROP_NAME = "COND_EMBED_LENGTHS";
/**
* Array of history interspike interval embedding lengths for the conditional variables.
* Can be an empty array or null if there are no conditional variables.
*/
protected int[] condEmbedDims = new int[] {};
/**
* Number of nearest neighbours to search for in the full joint space
*/
protected int Knns = 4;
/**
* Storage for source observations supplied via {@link #addObservations(double[], double[])} etc.
* Storage for source observations supplied via
* {@link #addObservations(double[], double[])} etc.
*/
protected Vector<double[]> vectorOfSourceSpikeTimes = null;
/**
* Storage for destination observations supplied via {@link #addObservations(double[], double[])} etc.
* Storage for destination observations supplied via
* {@link #addObservations(double[], double[])} etc.
*/
protected Vector<double[]> vectorOfDestinationSpikeTimes = null;
Vector<double[]> targetEmbeddingsFromSpikes = null;
/**
* Storage for conditional observations supplied via
* {@link #addObservations(double[], double[])} etc.
*/
protected Vector<double[][]> vectorOfConditionalSpikeTimes = null;
Vector<double[]> ConditioningEmbeddingsFromSpikes = null;
Vector<double[]> jointEmbeddingsFromSpikes = null;
Vector<double[]> targetEmbeddingsFromSamples = null;
Vector<double[]> ConditioningEmbeddingsFromSamples = null;
Vector<double[]> jointEmbeddingsFromSamples = null;
Vector<Double> processTimeLengths = null;
protected KdTree kdTreeJointAtSpikes = null;
protected KdTree kdTreeJointAtSamples = null;
@ -66,10 +90,11 @@ public class TransferEntropyCalculatorSpikingIntegration implements
public static final String KNNS_PROP_NAME = "Knns";
/**
* Property name for an amount of random Gaussian noise to be
* added to the data (default is 1e-8, matching the MILCA toolkit).
* Property name for an amount of random Gaussian noise to be added to the data
* (default is 1e-8, matching the MILCA toolkit).
*/
public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
/**
* Whether to add an amount of random noise to the incoming data
*/
@ -79,35 +104,61 @@ public class TransferEntropyCalculatorSpikingIntegration implements
*/
protected double noiseLevel = (double) 1e-8;
protected boolean trimToPosNextSpikeTimes = false;
/**
* Stores whether we are in debug mode
*/
protected boolean debug = false;
/**
* Property name for the number of random sample points to use as a multiple
* of the number of target spikes.
*/
public static final String PROP_SAMPLE_MULTIPLIER = "NUM_SAMPLES_MULTIPLIER";
protected double num_samples_multiplier = 1.0;
/**
* Property name for what type of norm to use between data points
* for each marginal variable -- Options are defined by
* {@link KdTree#setNormType(String)} and the
* default is {@link EuclideanUtils#NORM_EUCLIDEAN}.
*/
public final static String PROP_NORM_TYPE = "NORM_TYPE";
protected int normType = EuclideanUtils.NORM_EUCLIDEAN;
public TransferEntropyCalculatorSpikingIntegration() {
super();
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int)
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
* int)
*/
@Override
public void initialise() throws Exception {
initialise(k, l);
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int)
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
* int)
*/
@Override
public void initialise(int k) throws Exception {
initialise(k, this.l);
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int, int)
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
* int, int)
*/
@Override
public void initialise(int k, int l) throws Exception {
@ -120,40 +171,74 @@ public class TransferEntropyCalculatorSpikingIntegration implements
vectorOfDestinationSpikeTimes = null;
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setProperty(java.lang.String, java.lang.String)
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setProperty(
* java.lang.String, java.lang.String)
*/
@Override
public void setProperty(String propertyName, String propertyValue)
throws Exception {
public void setProperty(String propertyName, String propertyValue) throws Exception {
boolean propertySet = true;
if (propertyName.equalsIgnoreCase(K_PROP_NAME)) {
k = Integer.parseInt(propertyValue);
int k_temp = Integer.parseInt(propertyValue);
if (k_temp < 1) {
throw new Exception ("Invalid k value less than 1.");
} else {
k = k_temp;
}
} else if (propertyName.equalsIgnoreCase(L_PROP_NAME)) {
l = Integer.parseInt(propertyValue);
int l_temp = Integer.parseInt(propertyValue);
if (l_temp < 1) {
throw new Exception ("Invalid l value less than 1.");
} else {
l = l_temp;
}
} else if (propertyName.equalsIgnoreCase(COND_EMBED_LENGTHS_PROP_NAME)) {
int[] condEmbedDims_temp = ParsedProperties.parseStringArrayOfInts(propertyValue);
for (int dim : condEmbedDims_temp) {
if (dim < 1) {
throw new Exception ("Invalid conditional embedding value less than 1.");
}
}
condEmbedDims = condEmbedDims_temp;
} else if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) {
Knns = Integer.parseInt(propertyValue);
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
if (propertyValue.equals("0") ||
propertyValue.equalsIgnoreCase("false")) {
if (propertyValue.equals("0") || propertyValue.equalsIgnoreCase("false")) {
addNoise = false;
noiseLevel = 0;
} else {
addNoise = true;
noiseLevel = Double.parseDouble(propertyValue);
}
} else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
double temp_num_samples_multiplier = Double.parseDouble(propertyValue);
if (temp_num_samples_multiplier <= 0) {
throw new Exception ("Num samples multiplier must be greater than 0.");
} else {
num_samples_multiplier = temp_num_samples_multiplier;
}
} else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
normType = KdTree.validateNormType(propertyValue);
} else {
// No property was set on this class
propertySet = false;
}
if (debug && propertySet) {
System.out.println(this.getClass().getSimpleName() + ": Set property " + propertyName +
" to " + propertyValue);
System.out.println(
this.getClass().getSimpleName() + ": Set property " + propertyName + " to " + propertyValue);
}
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getProperty(java.lang.String)
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getProperty(
* java.lang.String)
*/
@Override
public String getProperty(String propertyName) throws Exception {
@ -165,112 +250,143 @@ public class TransferEntropyCalculatorSpikingIntegration implements
return Integer.toString(Knns);
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
return Double.toString(noiseLevel);
} else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
return Double.toString(num_samples_multiplier);
} else {
// No property matches for this class
return null;
}
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setObservations(double[], double[])
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* setObservations(double[], double[])
*/
@Override
public void setObservations(double[] source, double[] destination)
throws Exception {
public void setObservations(double[] source, double[] destination) throws Exception {
startAddObservations();
addObservations(source, destination);
finaliseAddObservations();
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#startAddObservations()
public void setObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
startAddObservations();
addObservations(source, destination, conditionals);
finaliseAddObservations();
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* startAddObservations()
*/
@Override
public void startAddObservations() {
vectorOfSourceSpikeTimes = new Vector<double[]>();
vectorOfDestinationSpikeTimes = new Vector<double[]>();
vectorOfConditionalSpikeTimes = new Vector<double[][]>();
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#addObservations(double[], double[])
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* addObservations(double[], double[])
*/
@Override
public void addObservations(double[] source, double[] destination)
throws Exception {
public void addObservations(double[] source, double[] destination) throws Exception {
// Store these observations in our vector for now
vectorOfSourceSpikeTimes.add(source);
vectorOfDestinationSpikeTimes.add(destination);
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#finaliseAddObservations()
public void addObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
// Store these observations in our vector for now
vectorOfSourceSpikeTimes.add(source);
vectorOfDestinationSpikeTimes.add(destination);
vectorOfConditionalSpikeTimes.add(conditionals);
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* finaliseAddObservations()
*/
@Override
public void finaliseAddObservations() throws Exception {
targetEmbeddingsFromSpikes = new Vector<double[]>();
ConditioningEmbeddingsFromSpikes = new Vector<double[]>();
jointEmbeddingsFromSpikes = new Vector<double[]>();
targetEmbeddingsFromSamples = new Vector<double[]>();
ConditioningEmbeddingsFromSamples = new Vector<double[]>();
jointEmbeddingsFromSamples = new Vector<double[]>();
processTimeLengths = new Vector<Double>();
// Send all of the observations through:
Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
int timeSeriesIndex = 0;
if (vectorOfConditionalSpikeTimes.size() > 0) {
Iterator<double[][]> conditionalIterator = vectorOfConditionalSpikeTimes.iterator();
for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
double[] sourceSpikeTimes = sourceIterator.next();
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes,
targetEmbeddingsFromSpikes, jointEmbeddingsFromSpikes,
targetEmbeddingsFromSamples, jointEmbeddingsFromSamples);
double[][] conditionalSpikeTimes = conditionalIterator.next();
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, ConditioningEmbeddingsFromSpikes,
jointEmbeddingsFromSpikes, ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples,
processTimeLengths);
}
} else {
for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
double[] sourceSpikeTimes = sourceIterator.next();
double[][] conditionalSpikeTimes = new double[][] {};
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, ConditioningEmbeddingsFromSpikes,
jointEmbeddingsFromSpikes, ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples,
processTimeLengths);
}
}
// Convert the vectors to arrays so that they can be put in the trees
double[][] arrayedTargetEmbeddingsFromSpikes = new double[targetEmbeddingsFromSpikes.size()][k];
double[][] arrayedJointEmbeddingsFromSpikes = new double[targetEmbeddingsFromSpikes.size()][k + l];
for (int i = 0; i < targetEmbeddingsFromSpikes.size(); i++) {
arrayedTargetEmbeddingsFromSpikes[i] = targetEmbeddingsFromSpikes.elementAt(i);
double[][] arrayedConditioningEmbeddingsFromSpikes = new double[ConditioningEmbeddingsFromSpikes.size()][];
double[][] arrayedJointEmbeddingsFromSpikes = new double[ConditioningEmbeddingsFromSpikes.size()][];
for (int i = 0; i < ConditioningEmbeddingsFromSpikes.size(); i++) {
arrayedConditioningEmbeddingsFromSpikes[i] = ConditioningEmbeddingsFromSpikes.elementAt(i);
arrayedJointEmbeddingsFromSpikes[i] = jointEmbeddingsFromSpikes.elementAt(i);
}
double[][] arrayedTargetEmbeddingsFromSamples = new double[targetEmbeddingsFromSamples.size()][k];
double[][] arrayedJointEmbeddingsFromSamples = new double[targetEmbeddingsFromSamples.size()][k + l];
for (int i = 0; i < targetEmbeddingsFromSamples.size(); i++) {
arrayedTargetEmbeddingsFromSamples[i] = targetEmbeddingsFromSamples.elementAt(i);
double[][] arrayedConditioningEmbeddingsFromSamples = new double[ConditioningEmbeddingsFromSamples.size()][];
double[][] arrayedJointEmbeddingsFromSamples = new double[ConditioningEmbeddingsFromSamples.size()][];
for (int i = 0; i < ConditioningEmbeddingsFromSamples.size(); i++) {
arrayedConditioningEmbeddingsFromSamples[i] = ConditioningEmbeddingsFromSamples.elementAt(i);
arrayedJointEmbeddingsFromSamples[i] = jointEmbeddingsFromSamples.elementAt(i);
}
kdTreeJointAtSpikes = new KdTree(
new int[] {k + l},
new double[][][] {arrayedJointEmbeddingsFromSpikes});
kdTreeJointAtSamples = new KdTree(
new int[] {k + l},
new double[][][] {arrayedJointEmbeddingsFromSamples});
kdTreeConditioningAtSpikes = new KdTree(
new int[] {k},
new double[][][] {arrayedTargetEmbeddingsFromSpikes});
kdTreeConditioningAtSamples = new KdTree(
new int[] {k},
new double[][][] {arrayedTargetEmbeddingsFromSamples});
kdTreeJointAtSpikes = new KdTree(arrayedJointEmbeddingsFromSpikes);
kdTreeJointAtSamples = new KdTree(arrayedJointEmbeddingsFromSamples);
kdTreeConditioningAtSpikes = new KdTree(arrayedConditioningEmbeddingsFromSpikes);
kdTreeConditioningAtSamples = new KdTree(arrayedConditioningEmbeddingsFromSamples);
kdTreeJointAtSpikes.setNormType(normType);
kdTreeJointAtSamples.setNormType(normType);
kdTreeConditioningAtSpikes.setNormType(normType);
kdTreeConditioningAtSamples.setNormType(normType);
}
protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, double[] sourceSpikeTimes, double[] destSpikeTimes,
Vector<double[]> targetEmbeddings, Vector<double[]> jointEmbeddings) {
//System.out.println("foo");
protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, int index_of_first_point_to_use,
double[] sourceSpikeTimes, double[] destSpikeTimes,
double[][] conditionalSpikeTimes,
Vector<double[]> ConditioningEmbeddings,
Vector<double[]> jointEmbeddings) {
Random random = new Random();
int embedding_point_index = 0;
int embedding_point_index = index_of_first_point_to_use;
int most_recent_dest_index = k;
int most_recent_source_index = l;
// Make sure that the first point at which an embedding is made has enough preceding spikes in both source and
// target for embeddings to be made.
while (pointsAtWhichToMakeEmbeddings[embedding_point_index] <= destSpikeTimes[most_recent_dest_index] |
pointsAtWhichToMakeEmbeddings[embedding_point_index] <= sourceSpikeTimes[most_recent_source_index]) {
embedding_point_index++;
int[] most_recent_conditioning_indices = Arrays.copyOf(condEmbedDims, condEmbedDims.length);
int total_length_of_conditioning_embeddings = 0;
for (int i = 0; i < condEmbedDims.length; i++) {
total_length_of_conditioning_embeddings += condEmbedDims[i];
}
// Loop through the points at which embeddings need to be made
@ -286,162 +402,327 @@ public class TransferEntropyCalculatorSpikingIntegration implements
}
// Do the same for the most recent source index
while (most_recent_source_index < (sourceSpikeTimes.length - 1)) {
if (sourceSpikeTimes[most_recent_source_index + 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) {
if (sourceSpikeTimes[most_recent_source_index
+ 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) {
most_recent_source_index++;
} else {
break;
}
}
// Now advance the trackers for the most recent conditioning indices
for (int j = 0; j < most_recent_conditioning_indices.length; j++) {
while (most_recent_conditioning_indices[j] < (conditionalSpikeTimes[j].length - 1)) {
if (conditionalSpikeTimes[j][most_recent_conditioning_indices[j] + 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) {
most_recent_conditioning_indices[j]++;
} else {
break;
}
}
}
double[] destPast = new double[k];
double[] jointPast = new double[k + l];
destPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] -
destSpikeTimes[most_recent_dest_index];
jointPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] -
destSpikeTimes[most_recent_dest_index];
jointPast[k] = pointsAtWhichToMakeEmbeddings[embedding_point_index] -
sourceSpikeTimes[most_recent_source_index];
double[] conditioningPast = new double[k + total_length_of_conditioning_embeddings];
double[] jointPast = new double[k + total_length_of_conditioning_embeddings + l];
// Add the embedding intervals from the target process
conditioningPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] - destSpikeTimes[most_recent_dest_index];
jointPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index]
- destSpikeTimes[most_recent_dest_index];
for (int i = 1; i < k; i++) {
destPast[i] = destSpikeTimes[most_recent_dest_index - i + 1] -
destSpikeTimes[most_recent_dest_index - i];
jointPast[i] = destSpikeTimes[most_recent_dest_index - i + 1] -
destSpikeTimes[most_recent_dest_index - i];
}
for (int i = 1; i < l; i++) {
jointPast[k + i] = sourceSpikeTimes[most_recent_source_index - i + 1] -
sourceSpikeTimes[most_recent_source_index - i];
conditioningPast[i] = destSpikeTimes[most_recent_dest_index - i + 1]
- destSpikeTimes[most_recent_dest_index - i];
jointPast[i] = destSpikeTimes[most_recent_dest_index - i + 1]
- destSpikeTimes[most_recent_dest_index - i];
}
if (addNoise) {
for (int i = 0; i < k; i++) {
destPast[i] += random.nextGaussian()*noiseLevel;
// Add the embedding intervals from the conditional processes
int index_of_next_embedding_interval = k;
for (int i = 0; i < condEmbedDims.length; i++) {
conditioningPast[index_of_next_embedding_interval] =
pointsAtWhichToMakeEmbeddings[embedding_point_index] - conditionalSpikeTimes[i][most_recent_conditioning_indices[i]];
jointPast[index_of_next_embedding_interval] =
pointsAtWhichToMakeEmbeddings[embedding_point_index] - conditionalSpikeTimes[i][most_recent_conditioning_indices[i]];
index_of_next_embedding_interval += 1;
for (int j = 1; j < condEmbedDims[i]; j++) {
conditioningPast[index_of_next_embedding_interval] =
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j + 1] -
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j];
jointPast[index_of_next_embedding_interval] =
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j + 1] -
conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j];
index_of_next_embedding_interval += 1;
}
for (int i = 0; i < l; i++) {
}
// Add the embedding intervals from the source process (this only gets added to the joint embeddings)
jointPast[k + total_length_of_conditioning_embeddings] = pointsAtWhichToMakeEmbeddings[embedding_point_index]
- sourceSpikeTimes[most_recent_source_index];
for (int i = 1; i < l; i++) {
jointPast[k + total_length_of_conditioning_embeddings + i] = sourceSpikeTimes[most_recent_source_index - i + 1]
- sourceSpikeTimes[most_recent_source_index - i];
}
// Add Gaussian noise, if necessary
if (addNoise) {
for (int i = 0; i < conditioningPast.length; i++) {
conditioningPast[i] += random.nextGaussian() * noiseLevel;
}
for (int i = 0; i < jointPast.length; i++) {
jointPast[i] += random.nextGaussian() * noiseLevel;
}
}
targetEmbeddings.add(destPast);
ConditioningEmbeddings.add(conditioningPast);
jointEmbeddings.add(jointPast);
}
}
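To make the layout of the vectors built above concrete: for hypothetical parameters k = 2, l = 2 and condEmbedDims = {2}, writing t for the point at which the embedding is made, y1, y2 for the two most recent target spikes, c1, c2 for the two most recent spikes of the single conditional process, and x1, x2 for the two most recent source spikes, the loops above produce:

// conditioningPast = { t - y1, y1 - y2,     // k target intervals
//                      t - c1, c1 - c2 }    // condEmbedDims[0] intervals
//
// jointPast        = { t - y1, y1 - y2,     // k target intervals
//                      t - c1, c1 - c2,     // conditional intervals
//                      t - x1, x1 - x2 }    // l source intervals (joint only)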
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes,
Vector<double[]> targetEmbeddingsFromSpikes, Vector<double[]> jointEmbeddingsFromSpikes,
Vector<double[]> targetEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples)
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes,
Vector<double[]> ConditioningEmbeddingsFromSpikes, Vector<double[]> jointEmbeddingsFromSpikes,
Vector<double[]> ConditioningEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples,
Vector<Double> processTimeLengths)
throws Exception {
// First sort the spike times in case they were not properly in ascending order:
Arrays.sort(sourceSpikeTimes);
Arrays.sort(destSpikeTimes);
double sample_lower_bound = Arrays.stream(sourceSpikeTimes).min().getAsDouble();
double sample_upper_bound = Arrays.stream(sourceSpikeTimes).max().getAsDouble();
double[] randomSampleTimes = new double[sourceSpikeTimes.length];
int first_target_index_of_embedding = k;
while (destSpikeTimes[first_target_index_of_embedding] <= sourceSpikeTimes[l - 1]) {
first_target_index_of_embedding++;
}
if (conditionalSpikeTimes.length != condEmbedDims.length) {
throw new Exception("Number of conditional embedding lengths does not match the number of conditional processes");
}
for (int i = 0; i < conditionalSpikeTimes.length; i++) {
while (destSpikeTimes[first_target_index_of_embedding] <= conditionalSpikeTimes[i][condEmbedDims[i]]) {
first_target_index_of_embedding++;
}
}
processTimeLengths.add(destSpikeTimes[destSpikeTimes.length - 1] - destSpikeTimes[first_target_index_of_embedding]);
double sample_lower_bound = destSpikeTimes[first_target_index_of_embedding];
double sample_upper_bound = destSpikeTimes[destSpikeTimes.length - 1];
int num_samples = (int) Math.round(num_samples_multiplier * (destSpikeTimes.length - first_target_index_of_embedding + 1));
double[] randomSampleTimes = new double[num_samples];
Random rand = new Random();
for (int i = 0; i < randomSampleTimes.length; i++) {
randomSampleTimes[i] = sample_lower_bound + rand.nextDouble() * (sample_upper_bound - sample_lower_bound);
}
Arrays.sort(randomSampleTimes);
makeEmbeddingsAtPoints(destSpikeTimes, sourceSpikeTimes, destSpikeTimes, targetEmbeddingsFromSpikes, jointEmbeddingsFromSpikes);
makeEmbeddingsAtPoints(randomSampleTimes, sourceSpikeTimes, destSpikeTimes, targetEmbeddingsFromSamples, jointEmbeddingsFromSamples);
makeEmbeddingsAtPoints(destSpikeTimes, first_target_index_of_embedding, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
ConditioningEmbeddingsFromSpikes, jointEmbeddingsFromSpikes);
makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples);
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getAddedMoreThanOneObservationSet()
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* getAddedMoreThanOneObservationSet()
*/
@Override
public boolean getAddedMoreThanOneObservationSet() {
return (vectorOfDestinationSpikeTimes != null) &&
(vectorOfDestinationSpikeTimes.size() > 1);
return (vectorOfDestinationSpikeTimes != null) && (vectorOfDestinationSpikeTimes.size() > 1);
}
private double max_neighbour_distance(PriorityQueue<NeighbourNodeData> nnPQ) {
double max_val = -1e9;
while (nnPQ.peek() != null) {
NeighbourNodeData nnData = nnPQ.poll();
if (nnData.norms[0] > max_val) {
max_val = nnData.norms[0];
// Class to allow returning two values in the subsequent method
private static class DistanceAndNumPoints {
public double distance;
public int numPoints;
public DistanceAndNumPoints(double distance, int numPoints) {
this.distance = distance;
this.numPoints = numPoints;
}
}
return max_val;
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeAverageLocalOfObservations()
private DistanceAndNumPoints findMaxDistanceAndNumPointsFromIndices(double[] point, int[] indices, Vector<double[]> setOfPoints) {
double maxDistance = 0;
int i = 0;
// Guard against a fully populated indices array (no -1 terminator)
for (; i < indices.length && indices[i] != -1; i++) {
double distance = KdTree.norm(point, setOfPoints.elementAt(indices[i]), normType);
if (distance > maxDistance) {
maxDistance = distance;
}
}
return new DistanceAndNumPoints(maxDistance, i);
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* computeAverageLocalOfObservations()
*/
@Override
public double computeAverageLocalOfObservations() throws Exception {
double currentSum = 0;
for (int i = 0; i < targetEmbeddingsFromSpikes.size(); i++) {
for (int i = 0; i < ConditioningEmbeddingsFromSpikes.size(); i++) {
PriorityQueue<NeighbourNodeData> nnPQJointSpikes =
kdTreeJointAtSpikes.findKNearestNeighbours(Knns + 1, new double[][] {jointEmbeddingsFromSpikes.elementAt(i)});
PriorityQueue<NeighbourNodeData> nnPQJointSamples =
kdTreeJointAtSamples.findKNearestNeighbours(Knns, new double[][] {jointEmbeddingsFromSpikes.elementAt(i)});
PriorityQueue<NeighbourNodeData> nnPQConditioningSpikes =
kdTreeConditioningAtSpikes.findKNearestNeighbours(Knns + 1, new double[][] {targetEmbeddingsFromSpikes.elementAt(i)});
PriorityQueue<NeighbourNodeData> nnPQConditioningSamples =
kdTreeConditioningAtSamples.findKNearestNeighbours(Knns, new double[][] {targetEmbeddingsFromSpikes.elementAt(i)});
double radiusJointSpikes = kdTreeJointAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
double radiusJointSamples = kdTreeJointAtSamples.findKNearestNeighbours(Knns,
new double[][] { jointEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];
double radiusJointSpikes = max_neighbour_distance(nnPQJointSpikes);
double radiusJointSamples = max_neighbour_distance(nnPQJointSamples);
double radiusConditioningSpikes = max_neighbour_distance(nnPQConditioningSpikes);
double radiusConditioningSamples = max_neighbour_distance(nnPQConditioningSamples);
/*
The algorithm specified in box 1 of doi.org/10.1371/journal.pcbi.1008054 specifies finding the maximum of the two radii
just calculated and then redoing the searches in both sets at this radius. In this implementation, however, we make use
of the fact that one radius is equal to the maximum, and so only one search needs to be redone.
*/
double eps = 0.01;
// Need explicit neighbour counts, as the number of neighbours found within the maximum radius is now variable
int kJointSpikes = 0;
int kJointSamples = 0;
if (radiusJointSpikes >= radiusJointSamples) {
/*
The maximum was the radius in the set of embeddings at spikes, so redo search in the set of embeddings at randomly
sampled points, using this larger radius.
*/
kJointSpikes = Knns;
int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()];
boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()];
kdTreeJointAtSamples.findPointsWithinR(radiusJointSpikes + eps,
new double[][] { jointEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
DistanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(jointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
jointEmbeddingsFromSamples);
kJointSamples = temp.numPoints;
radiusJointSamples = temp.distance;
} else {
/*
The maximum was the radius in the set of embeddings at randomly sampled points, so redo search in the set of embeddings
at spikes, using this larger radius.
*/
kJointSamples = Knns;
int[] indicesWithinR = new int[jointEmbeddingsFromSpikes.size()];
boolean[] isWithinR = new boolean[jointEmbeddingsFromSpikes.size()];
kdTreeJointAtSpikes.findPointsWithinR(radiusJointSamples + eps,
new double[][] { jointEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
DistanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(jointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
jointEmbeddingsFromSpikes);
// -1 due to the point itself being in the set
kJointSpikes = temp.numPoints - 1;
radiusJointSpikes = temp.distance;
}
currentSum += ((k + l) * (- Math.log(radiusJointSpikes) + Math.log(radiusJointSamples))
+ k * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples)));
// Repeat the above steps, but in the conditioning (rather than joint) space.
double radiusConditioningSpikes = kdTreeConditioningAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
double radiusConditioningSamples = kdTreeConditioningAtSamples.findKNearestNeighbours(Knns,
new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];
int kConditioningSpikes = 0;
int kConditioningSamples = 0;
if (radiusConditioningSpikes >= radiusConditioningSamples) {
kConditioningSpikes = Knns;
int[] indicesWithinR = new int[ConditioningEmbeddingsFromSamples.size()];
boolean[] isWithinR = new boolean[ConditioningEmbeddingsFromSamples.size()];
kdTreeConditioningAtSamples.findPointsWithinR(radiusConditioningSpikes + eps,
new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
DistanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(ConditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
ConditioningEmbeddingsFromSamples);
kConditioningSamples = temp.numPoints;
radiusConditioningSamples = temp.distance;
} else {
kConditioningSamples = Knns;
int[] indicesWithinR = new int[ConditioningEmbeddingsFromSpikes.size()];
boolean[] isWithinR = new boolean[ConditioningEmbeddingsFromSpikes.size()];
kdTreeConditioningAtSpikes.findPointsWithinR(radiusConditioningSamples + eps,
new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
DistanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(ConditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
ConditioningEmbeddingsFromSpikes);
// -1 due to the point itself being in the set
kConditioningSpikes = temp.numPoints - 1;
radiusConditioningSpikes = temp.distance;
}
currentSum += (MathsUtils.digamma(kJointSpikes) - MathsUtils.digamma(kJointSamples) +
((k + l) * (-Math.log(radiusJointSpikes) + Math.log(radiusJointSamples))) -
MathsUtils.digamma(kConditioningSpikes) + MathsUtils.digamma(kConditioningSamples) +
(k * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples))));
if (Double.isNaN(currentSum)) {
throw new Exception(kJointSpikes + " " + kJointSamples + " " + kConditioningSpikes + " " + kConditioningSamples + "\n" +
radiusJointSpikes + " " + radiusJointSamples + " " + radiusConditioningSpikes + " " + radiusConditioningSamples);
}
}
// Normalise by time
currentSum /= (vectorOfDestinationSpikeTimes.elementAt(0)[vectorOfDestinationSpikeTimes.elementAt(0).length - 1]
- vectorOfDestinationSpikeTimes.elementAt(0)[0]);
double time_sum = 0;
for (Double time : processTimeLengths) {
time_sum += time;
}
currentSum /= time_sum;
return currentSum;
}
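Transcribing the sum accumulated above into symbols (with $\psi$ the digamma function, spk/smp indexing the embeddings at target spikes and at random sample points, and, as in the multipliers above, $d_J = k + l$ and $d_C = k$), the method returns

$$\hat{T} = \frac{1}{T} \sum_{i \in \mathrm{spikes}} \Big[ \psi(k^J_{\mathrm{spk},i}) - \psi(k^J_{\mathrm{smp},i}) + d_J \big( \ln r^J_{\mathrm{smp},i} - \ln r^J_{\mathrm{spk},i} \big) - \psi(k^C_{\mathrm{spk},i}) + \psi(k^C_{\mathrm{smp},i}) + d_C \big( \ln r^C_{\mathrm{spk},i} - \ln r^C_{\mathrm{smp},i} \big) \Big],$$

where $T$ is the sum of the observation time lengths collected in processTimeLengths.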
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeLocalOfPreviousObservations()
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* computeLocalOfPreviousObservations()
*/
@Override
public SpikingLocalInformationValues computeLocalOfPreviousObservations()
throws Exception {
public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception {
// TODO Auto-generated method stub
return null;
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeSignificance(int)
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* computeSignificance(int)
*/
@Override
public EmpiricalMeasurementDistribution computeSignificance(
int numPermutationsToCheck) throws Exception {
public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck) throws Exception {
// TODO Auto-generated method stub
return null;
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeSignificance(int[][])
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* computeSignificance(int[][])
*/
@Override
public EmpiricalMeasurementDistribution computeSignificance(
int[][] newOrderings) throws Exception {
public EmpiricalMeasurementDistribution computeSignificance(int[][] newOrderings) throws Exception {
// TODO Auto-generated method stub
return null;
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setDebug(boolean)
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setDebug(
* boolean)
*/
@Override
public void setDebug(boolean debug) {
this.debug = debug;
}
/* (non-Javadoc)
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getLastAverage()
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getLastAverage
* ()
*/
@Override
public double getLastAverage() {

View File: infodynamics/utils/KdTree.java

@ -337,7 +337,7 @@ public class KdTree extends NearestNeighbourSearcher {
double difference = x1[d] - x2[d];
distance += difference * difference;
}
return distance;
return Math.sqrt(distance);
}
}
@ -359,6 +359,7 @@ public class KdTree extends NearestNeighbourSearcher {
*/
public final static double normWithAbort(double[] x1, double[] x2,
double limit, int normToUse) {
double distance = 0.0;
switch (normToUse) {
case EuclideanUtils.NORM_MAX_NORM:
@ -379,6 +380,8 @@ public class KdTree extends NearestNeighbourSearcher {
return distance;
// case EuclideanUtils.NORM_EUCLIDEAN_SQUARED:
default:
// The limit is supplied as a norm (a radius r), while we accumulate squared distances, so square it for comparison
limit = limit * limit;
// Inlined from {@link EuclideanUtils}:
for (int d = 0; d < x1.length; d++) {
double difference = x1[d] - x2[d];
@ -387,7 +390,7 @@ public class KdTree extends NearestNeighbourSearcher {
return Double.POSITIVE_INFINITY;
}
}
return distance;
return Math.sqrt(distance);
}
}
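A small worked check of the convention established by this change, assuming the static KdTree.norm and normWithAbort signatures used elsewhere in this diff:

double[] x1 = {0.0, 0.0};
double[] x2 = {3.0, 4.0};
// Squared differences sum to 25, so with this change the Euclidean
// norm 5.0 is returned rather than the squared distance 25.0.
double d = KdTree.norm(x1, x2, EuclideanUtils.NORM_EUCLIDEAN);
// normWithAbort squares the supplied limit (4.0 -> 16.0) to compare it
// against the running squared sum; 9 + 16 = 25 exceeds 16, so the loop
// aborts and Double.POSITIVE_INFINITY is returned.
double dAborted = KdTree.normWithAbort(x1, x2, 4.0, EuclideanUtils.NORM_EUCLIDEAN);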
@ -1305,9 +1308,10 @@ public class KdTree extends NearestNeighbourSearcher {
if (normTypeToUse == EuclideanUtils.NORM_MAX_NORM) {
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
} else {
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
// norm type is Euclidean or Euclidean squared:
// track the squared distance
absDistOnThisDim = distOnThisDim * distOnThisDim;
}
if ((node.indexOfThisPoint != sampleIndex) &&
@ -2053,7 +2057,7 @@ public class KdTree extends NearestNeighbourSearcher {
KdTreeNode node, int level, double r, boolean allowEqualToR,
boolean[] isWithinR, int[] indicesWithinR,
int nextIndexInIndicesWithinR) {
//System.out.println("foo");
// Point to the correct array for the data at this level
int currentDim = level % totalDimensions;
double[][] data = dimensionToArray[currentDim];
@ -2068,9 +2072,9 @@ public class KdTree extends NearestNeighbourSearcher {
if (normTypeToUse == EuclideanUtils.NORM_MAX_NORM) {
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
} else {
absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
// norm type is Euclidean or Euclidean squared:
// track the squared distance
absDistOnThisDim = distOnThisDim * distOnThisDim;
}
if ((absDistOnThisDim < r) ||

View File: Python demo script (spiking TE examples)

@ -27,15 +27,14 @@ import os
import numpy as np
NUM_REPS = 20
NUM_SPIKES = int(1e4)
NUM_REPS = 5
NUM_SPIKES = int(5e3)
NUM_OBSERVATIONS = 2
# Params for canonical example generation
RATE_Y = 1.0
RATE_X_MAX = 10
def generate_canonical_example_processes(num_y_events):
event_train_x = []
event_train_x.append(0)
@ -76,30 +75,97 @@ if (not(os.path.isfile(jarLocation))):
# Start the JVM (add the "-Xmx" option with say 1024M if you get crashes due to not enough memory space)
startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntropyCalculatorSpikingIntegration
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Independent Poisson Processes")
teCalc.setProperty("k_HISTORY", "1")
teCalc.setProperty("l_HISTORY", "1")
teCalc.setProperty("COND_EMBED_LENGTHS", "2,2")
teCalc.setProperty("k_HISTORY", "2")
teCalc.setProperty("l_HISTORY", "2")
teCalc.setProperty("NORM_TYPE", "MAX_NORM")
results_poisson = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
teCalc.startAddObservations()
for j in range(NUM_OBSERVATIONS):
sourceArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
sourceArray.sort()
destArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
destArray.sort()
teCalc.setObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray))
condArray = NUM_SPIKES*np.random.random((2, NUM_SPIKES))
condArray.sort(axis = 1)
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
teCalc.finaliseAddObservations()
result = teCalc.computeAverageLocalOfObservations()
print("TE result %.4f nats" % (result,))
results_poisson[i] = result
print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Noisy copy zero TE")
teCalc.setProperty("COND_EMBED_LENGTHS", "2")
teCalc.setProperty("k_HISTORY", "2")
teCalc.setProperty("l_HISTORY", "2")
results_noisy_zero = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
teCalc.startAddObservations()
for j in range(NUM_OBSERVATIONS):
condArray = NUM_SPIKES*np.random.random((1, NUM_SPIKES))
condArray.sort(axis = 1)
sourceArray = condArray[0, :] + 0.25 + 0.1 * np.random.normal(size = condArray.shape[1])
sourceArray.sort()
destArray = condArray[0, :] + 0.5 + 0.1 * np.random.normal(size = condArray.shape[1])
destArray.sort()
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
teCalc.finaliseAddObservations()
result = teCalc.computeAverageLocalOfObservations()
print("TE result %.4f nats" % (result,))
results_noisy_zero[i] = result
print("Summary: mean ", np.mean(results_noisy_zero), " std dev ", np.std(results_noisy_zero))
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Noisy copy non-zero TE")
teCalc.setProperty("COND_EMBED_LENGTHS", "2")
teCalc.setProperty("k_HISTORY", "2")
teCalc.setProperty("l_HISTORY", "2")
results_noisy_nonzero = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
teCalc.startAddObservations()
for j in range(NUM_OBSERVATIONS):
sourceArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
sourceArray.sort()
condArray = sourceArray + 0.25 + 0.1 * np.random.normal(size = sourceArray.shape)
condArray.sort()
condArray = np.expand_dims(condArray, 0)
destArray = sourceArray + 0.5 + 0.1 * np.random.normal(size = sourceArray.shape)
destArray.sort()
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
teCalc.finaliseAddObservations()
result = teCalc.computeAverageLocalOfObservations()
print("TE result %.4f nats" % (result,))
results_noisy_nonzero[i] = result
print("Summary: mean ", np.mean(results_noisy_nonzero), " std dev ", np.std(results_noisy_nonzero))
print("Canonical example")
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
teCalc.setProperty("k_HISTORY", "2")
teCalc.setProperty("l_HISTORY", "1")
#teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", "1")
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
results_canonical = np.zeros(NUM_REPS)
for i in range(NUM_REPS):