diff --git a/java/source/infodynamics/measures/spiking/TransferEntropyCalculatorSpiking.java b/java/source/infodynamics/measures/spiking/TransferEntropyCalculatorSpiking.java index 5aaaf9f..f9eb40c 100755 --- a/java/source/infodynamics/measures/spiking/TransferEntropyCalculatorSpiking.java +++ b/java/source/infodynamics/measures/spiking/TransferEntropyCalculatorSpiking.java @@ -282,68 +282,12 @@ public interface TransferEntropyCalculatorSpiking { * will vary depending on the underlying implementation */ public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception; - - /** - * Generate a bootstrapped distribution of what the TE would look like, - * under a null hypothesis that the source values of our - * samples had no temporal relation to the destination value. - * - *

- * See Section II.E "Statistical significance testing" of
- *  the JIDT paper below for a description of how this is done for MI and TE in general.
- *
- * Note that if several disjoint time-series have been added
- *  as observations using {@link #addObservations(double[])} etc.,
- *  then these separate "trials" will be mixed up in the generation
- *  of surrogates here.
- *
- * This method (in contrast to {@link #computeSignificance(int[][])})
- *  creates random shufflings of the source embedding vectors for the surrogate
- *  calculations.
- * - * @param numPermutationsToCheck number of surrogate samples to bootstrap - * to generate the distribution. - * @return the distribution of TE scores under this null hypothesis. - * @see "J.T. Lizier, 'JIDT: An information-theoretic - * toolkit for studying the dynamics of complex systems', 2014." - */ - public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck) throws Exception; - - /** - * Generate a bootstrapped distribution of what the TE would look like, - * under a null hypothesis that the source values of our - * samples had no relation to the destination value. - * - *

- * See Section II.E "Statistical significance testing" of
- *  the JIDT paper below for a description of how this is done for MI and TE.
- *
- * Note that if several disjoint time-series have been added
- *  as observations using {@link #addObservations(double[])} etc.,
- *  then these separate "trials" will be mixed up in the generation
- *  of surrogates here.
- *
- * This method (in contrast to {@link #computeSignificance(int)})
- *  allows the user to specify how to construct the surrogates,
- *  such that repeatable results may be obtained.
- * - * @param newOrderings a specification of how to shuffle the source embedding vectors - * between all of our samples to create the surrogates to generate the distribution with. The first - * index is the permutation number (i.e. newOrderings.length is the number - * of surrogate samples we use to bootstrap to generate the distribution here.) - * Each array newOrderings[i] should be an array of length N (where - * would be the value returned by {@link #getNumObservations()}), - * containing a permutation of the values in 0..(N-1). - * TODO Need to think this through a little more before implementing. - * @return the distribution of channel measure scores under this null hypothesis. - * @see "J.T. Lizier, 'JIDT: An information-theoretic - * toolkit for studying the dynamics of complex systems', 2014." - * @throws Exception where the length of each permutation in newOrderings - * is not equal to the number N samples that were previously supplied. - */ - public EmpiricalMeasurementDistribution computeSignificance( - int[][] newOrderings) throws Exception; + public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue) throws Exception; + + public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, + double estimatedValue, long randomSeed) throws Exception; + /** * Set or clear debug mode for extra debug printing to stdout * diff --git a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java b/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java index 103cd27..89b6a38 100644 --- a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java +++ b/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java @@ -6,6 +6,7 @@ import java.util.Iterator; import java.util.PriorityQueue; import java.util.Random; import java.util.Vector; +import java.util.ArrayList; //import infodynamics.measures.continuous.kraskov.EuclideanUtils; import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking; @@ -32,6 +33,14 @@ import infodynamics.utils.ParsedProperties; * @author Joseph Lizier (email, * www) */ + +/* + * TODO + * This implementation of the estimator does not implement dynamic exclusion windows. Such windows make sure + * that history embeddings that overlap are not considered in nearest-neighbour searches (as this breaks the + * independece assumption). Getting dynamic exclusion windows working will probably require modifications to the + * KdTree class. + */ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntropyCalculatorSpiking { /** @@ -76,9 +85,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr */ protected Vector vectorOfConditionalSpikeTimes = null; - Vector ConditioningEmbeddingsFromSpikes = null; + Vector conditioningEmbeddingsFromSpikes = null; Vector jointEmbeddingsFromSpikes = null; - Vector ConditioningEmbeddingsFromSamples = null; + Vector conditioningEmbeddingsFromSamples = null; Vector jointEmbeddingsFromSamples = null; Vector processTimeLengths = null; @@ -114,7 +123,21 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr * of the number of target spikes. 
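For orientation, here is a minimal usage sketch of the surrogate API introduced by this patch (the computeSignificance(int, double) / computeSignificance(int, double, long) overloads and the SURROGATE_NUM_SAMPLES_MULTIPLIER and K_PERM properties). The demo class name, the fixed seed and the Poisson spike trains are illustrative only and are not part of the patch:

import java.util.Random;

import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking;
import infodynamics.measures.spiking.integration.TransferEntropyCalculatorSpikingIntegration;
import infodynamics.utils.EmpiricalMeasurementDistribution;

public class SurrogateUsageSketch {
    public static void main(String[] args) throws Exception {
        // Two independent unit-rate Poisson spike trains, so the expected TE is close to zero.
        Random rng = new Random(0);
        double[] sourceSpikes = new double[2000];
        double[] destSpikes = new double[2000];
        double tSource = 0.0, tDest = 0.0;
        for (int i = 0; i < 2000; i++) {
            tSource += -Math.log(1.0 - rng.nextDouble());
            tDest += -Math.log(1.0 - rng.nextDouble());
            sourceSpikes[i] = tSource;
            destSpikes[i] = tDest;
        }

        TransferEntropyCalculatorSpiking teCalc = new TransferEntropyCalculatorSpikingIntegration();
        // New properties added in this patch for the surrogate construction.
        teCalc.setProperty(TransferEntropyCalculatorSpikingIntegration.PROP_K_PERM, "10");
        teCalc.setProperty(TransferEntropyCalculatorSpikingIntegration.PROP_SURROGATE_SAMPLE_MULTIPLIER, "2.0");
        teCalc.initialise();
        teCalc.setObservations(sourceSpikes, destSpikes);

        double te = teCalc.computeAverageLocalOfObservations();
        // The explicit seed uses the new three-argument overload, making the
        // surrogate distribution reproducible across runs.
        EmpiricalMeasurementDistribution dist = teCalc.computeSignificance(100, te, 42L);
        System.out.println("TE = " + te + ", p-value = " + dist.pValue);
    }
}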
*/ public static final String PROP_SAMPLE_MULTIPLIER = "NUM_SAMPLES_MULTIPLIER"; - protected double num_samples_multiplier = 1.0; + protected double numSamplesMultiplier = 1.0; + /** + * Property name for the number of random sample points to use in the construction of the surrogates as a multiple + * of the number of target spikes. + */ + public static final String PROP_SURROGATE_SAMPLE_MULTIPLIER = "SURROGATE_NUM_SAMPLES_MULTIPLIER"; + protected double surrogateNumSamplesMultiplier = 1.0; + + /** + * Property for the number of nearest neighbours to use in the construction of the surrogates + */ + public static final String PROP_K_PERM = "K_PERM"; + protected int kPerm = 10; + + /** * Property name for what type of norm to use between data points * for each marginal variable -- Options are defined by @@ -182,29 +205,31 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr public void setProperty(String propertyName, String propertyValue) throws Exception { boolean propertySet = true; if (propertyName.equalsIgnoreCase(K_PROP_NAME)) { - int k_temp = Integer.parseInt(propertyValue); - if (k_temp < 1) { + int kTemp = Integer.parseInt(propertyValue); + if (kTemp < 1) { throw new Exception ("Invalid k value less than 1."); } else { - k = k_temp; + k = kTemp; } } else if (propertyName.equalsIgnoreCase(L_PROP_NAME)) { - int l_temp = Integer.parseInt(propertyValue); - if (l_temp < 1) { + int lTemp = Integer.parseInt(propertyValue); + if (lTemp < 1) { throw new Exception ("Invalid l value less than 1."); } else { - l = l_temp; + l = lTemp; } } else if (propertyName.equalsIgnoreCase(COND_EMBED_LENGTHS_PROP_NAME)) { - int[] condEmbedDims_temp = ParsedProperties.parseStringArrayOfInts(propertyValue); - for (int dim : condEmbedDims_temp) { + int[] condEmbedDimsTemp = ParsedProperties.parseStringArrayOfInts(propertyValue); + for (int dim : condEmbedDimsTemp) { if (dim < 1) { throw new Exception ("Invalid conditional embedding value less than 1."); } } - condEmbedDims = condEmbedDims_temp; + condEmbedDims = condEmbedDimsTemp; } else if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) { Knns = Integer.parseInt(propertyValue); + } else if (propertyName.equalsIgnoreCase(PROP_K_PERM)) { + kPerm = Integer.parseInt(propertyValue); } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) { if (propertyValue.equals("0") || propertyValue.equalsIgnoreCase("false")) { addNoise = false; @@ -215,14 +240,21 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr } } else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) { - double temp_num_samples_multiplier = Double.parseDouble(propertyValue); - if (temp_num_samples_multiplier <= 0) { + double tempNumSamplesMultiplier = Double.parseDouble(propertyValue); + if (tempNumSamplesMultiplier <= 0) { throw new Exception ("Num samples multiplier must be greater than 0."); } else { - num_samples_multiplier = temp_num_samples_multiplier; + numSamplesMultiplier = tempNumSamplesMultiplier; } } else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) { normType = KdTree.validateNormType(propertyValue); + } else if (propertyName.equalsIgnoreCase(PROP_SURROGATE_SAMPLE_MULTIPLIER)) { + double tempSurrogateNumSamplesMultiplier = Double.parseDouble(propertyValue); + if (tempSurrogateNumSamplesMultiplier <= 0) { + throw new Exception ("Surrogate Num samples multiplier must be greater than 0."); + } else { + surrogateNumSamplesMultiplier = tempSurrogateNumSamplesMultiplier; + } } else { // No property was set on this class 
propertySet = false; @@ -251,7 +283,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) { return Double.toString(noiseLevel); } else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) { - return Double.toString(num_samples_multiplier); + return Double.toString(numSamplesMultiplier); } else { // No property matches for this class return null; @@ -319,45 +351,43 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr @Override public void finaliseAddObservations() throws Exception { - ConditioningEmbeddingsFromSpikes = new Vector(); + conditioningEmbeddingsFromSpikes = new Vector(); jointEmbeddingsFromSpikes = new Vector(); - ConditioningEmbeddingsFromSamples = new Vector(); + conditioningEmbeddingsFromSamples = new Vector(); jointEmbeddingsFromSamples = new Vector(); processTimeLengths = new Vector(); // Send all of the observations through: Iterator sourceIterator = vectorOfSourceSpikeTimes.iterator(); - int timeSeriesIndex = 0; + Iterator conditionalIterator = null; if (vectorOfConditionalSpikeTimes.size() > 0) { - Iterator conditionalIterator = vectorOfConditionalSpikeTimes.iterator(); - for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) { - double[] sourceSpikeTimes = sourceIterator.next(); - double[][] conditionalSpikeTimes = conditionalIterator.next(); - processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, ConditioningEmbeddingsFromSpikes, - jointEmbeddingsFromSpikes, ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples, - processTimeLengths); - } - } else { - for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) { - double[] sourceSpikeTimes = sourceIterator.next(); - double[][] conditionalSpikeTimes = new double[][] {}; - processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, ConditioningEmbeddingsFromSpikes, - jointEmbeddingsFromSpikes, ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples, - processTimeLengths); + conditionalIterator = vectorOfConditionalSpikeTimes.iterator(); + } + int timeSeriesIndex = 0; + for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) { + double[] sourceSpikeTimes = sourceIterator.next(); + double[][] conditionalSpikeTimes = null; + if (vectorOfConditionalSpikeTimes.size() > 0) { + conditionalSpikeTimes = conditionalIterator.next(); + } else { + conditionalSpikeTimes = new double[][] {}; } + processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, conditioningEmbeddingsFromSpikes, + jointEmbeddingsFromSpikes, conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples, + numSamplesMultiplier); } // Convert the vectors to arrays so that they can be put in the trees - double[][] arrayedTargetEmbeddingsFromSpikes = new double[ConditioningEmbeddingsFromSpikes.size()][k]; - double[][] arrayedJointEmbeddingsFromSpikes = new double[ConditioningEmbeddingsFromSpikes.size()][k + l]; - for (int i = 0; i < ConditioningEmbeddingsFromSpikes.size(); i++) { - arrayedTargetEmbeddingsFromSpikes[i] = ConditioningEmbeddingsFromSpikes.elementAt(i); + double[][] arrayedTargetEmbeddingsFromSpikes = new double[conditioningEmbeddingsFromSpikes.size()][k]; + double[][] arrayedJointEmbeddingsFromSpikes = new double[conditioningEmbeddingsFromSpikes.size()][k + l]; + for (int i = 0; i < conditioningEmbeddingsFromSpikes.size(); i++) { + arrayedTargetEmbeddingsFromSpikes[i] = 
conditioningEmbeddingsFromSpikes.elementAt(i); arrayedJointEmbeddingsFromSpikes[i] = jointEmbeddingsFromSpikes.elementAt(i); } - double[][] arrayedTargetEmbeddingsFromSamples = new double[ConditioningEmbeddingsFromSamples.size()][k]; - double[][] arrayedJointEmbeddingsFromSamples = new double[ConditioningEmbeddingsFromSamples.size()][k + l]; - for (int i = 0; i < ConditioningEmbeddingsFromSamples.size(); i++) { - arrayedTargetEmbeddingsFromSamples[i] = ConditioningEmbeddingsFromSamples.elementAt(i); + double[][] arrayedTargetEmbeddingsFromSamples = new double[conditioningEmbeddingsFromSamples.size()][k]; + double[][] arrayedJointEmbeddingsFromSamples = new double[conditioningEmbeddingsFromSamples.size()][k + l]; + for (int i = 0; i < conditioningEmbeddingsFromSamples.size(); i++) { + arrayedTargetEmbeddingsFromSamples[i] = conditioningEmbeddingsFromSamples.elementAt(i); arrayedJointEmbeddingsFromSamples[i] = jointEmbeddingsFromSamples.elementAt(i); } @@ -372,48 +402,48 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr kdTreeConditioningAtSamples.setNormType(normType); } - protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, int index_of_first_point_to_use, + protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, int indexOfFirstPointToUse, double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes, - Vector ConditioningEmbeddings, + Vector conditioningEmbeddings, Vector jointEmbeddings) { Random random = new Random(); - int embedding_point_index = index_of_first_point_to_use; - int most_recent_dest_index = k; - int most_recent_source_index = l; - int[] most_recent_conditioning_indices = Arrays.copyOf(condEmbedDims, condEmbedDims.length); - int total_length_of_conditioning_embeddings = 0; + int embeddingPointIndex = indexOfFirstPointToUse; + int mostRecentDestIndex = k; + int mostRecentSourceIndex = l; + int[] mostRecentConditioningIndices = Arrays.copyOf(condEmbedDims, condEmbedDims.length); + int totalLengthOfConditioningEmbeddings = 0; for (int i = 0; i < condEmbedDims.length; i++) { - total_length_of_conditioning_embeddings += condEmbedDims[i]; + totalLengthOfConditioningEmbeddings += condEmbedDims[i]; } // Loop through the points at which embeddings need to be made - for (; embedding_point_index < pointsAtWhichToMakeEmbeddings.length; embedding_point_index++) { + for (; embeddingPointIndex < pointsAtWhichToMakeEmbeddings.length; embeddingPointIndex++) { // Advance the tracker of the most recent dest index - while (most_recent_dest_index < (destSpikeTimes.length - 1)) { - if (destSpikeTimes[most_recent_dest_index + 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) { - most_recent_dest_index++; + while (mostRecentDestIndex < (destSpikeTimes.length - 1)) { + if (destSpikeTimes[mostRecentDestIndex + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) { + mostRecentDestIndex++; } else { break; } } // Do the same for the most recent source index - while (most_recent_source_index < (sourceSpikeTimes.length - 1)) { - if (sourceSpikeTimes[most_recent_source_index - + 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) { - most_recent_source_index++; + while (mostRecentSourceIndex < (sourceSpikeTimes.length - 1)) { + if (sourceSpikeTimes[mostRecentSourceIndex + + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) { + mostRecentSourceIndex++; } else { break; } } // Now advance the trackers for the most recent conditioning indices - for (int j = 0; j < 
most_recent_conditioning_indices.length; j++) { - while (most_recent_conditioning_indices[j] < (conditionalSpikeTimes[j].length - 1)) { - if (conditionalSpikeTimes[j][most_recent_conditioning_indices[j] + 1] < pointsAtWhichToMakeEmbeddings[embedding_point_index]) { - most_recent_conditioning_indices[j]++; + for (int j = 0; j < mostRecentConditioningIndices.length; j++) { + while (mostRecentConditioningIndices[j] < (conditionalSpikeTimes[j].length - 1)) { + if (conditionalSpikeTimes[j][mostRecentConditioningIndices[j] + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) { + mostRecentConditioningIndices[j]++; } else { break; } @@ -421,45 +451,45 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr } - double[] conditioningPast = new double[k + total_length_of_conditioning_embeddings]; - double[] jointPast = new double[k + total_length_of_conditioning_embeddings + l]; + double[] conditioningPast = new double[k + totalLengthOfConditioningEmbeddings]; + double[] jointPast = new double[k + totalLengthOfConditioningEmbeddings + l]; // Add the embedding intervals from the target process - conditioningPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] - destSpikeTimes[most_recent_dest_index]; - jointPast[0] = pointsAtWhichToMakeEmbeddings[embedding_point_index] - - destSpikeTimes[most_recent_dest_index]; + conditioningPast[0] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex] - destSpikeTimes[mostRecentDestIndex]; + jointPast[0] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex] + - destSpikeTimes[mostRecentDestIndex]; for (int i = 1; i < k; i++) { - conditioningPast[i] = destSpikeTimes[most_recent_dest_index - i + 1] - - destSpikeTimes[most_recent_dest_index - i]; - jointPast[i] = destSpikeTimes[most_recent_dest_index - i + 1] - - destSpikeTimes[most_recent_dest_index - i]; + conditioningPast[i] = destSpikeTimes[mostRecentDestIndex - i + 1] + - destSpikeTimes[mostRecentDestIndex - i]; + jointPast[i] = destSpikeTimes[mostRecentDestIndex - i + 1] + - destSpikeTimes[mostRecentDestIndex - i]; } // Add the embeding intervals from the conditional processes - int index_of_next_embedding_interval = k; + int indexOfNextEmbeddingInterval = k; for (int i = 0; i < condEmbedDims.length; i++) { - conditioningPast[index_of_next_embedding_interval] = - pointsAtWhichToMakeEmbeddings[embedding_point_index] - conditionalSpikeTimes[i][most_recent_conditioning_indices[i]]; - jointPast[index_of_next_embedding_interval] = - pointsAtWhichToMakeEmbeddings[embedding_point_index] - conditionalSpikeTimes[i][most_recent_conditioning_indices[i]]; - index_of_next_embedding_interval += 1; + conditioningPast[indexOfNextEmbeddingInterval] = + pointsAtWhichToMakeEmbeddings[embeddingPointIndex] - conditionalSpikeTimes[i][mostRecentConditioningIndices[i]]; + jointPast[indexOfNextEmbeddingInterval] = + pointsAtWhichToMakeEmbeddings[embeddingPointIndex] - conditionalSpikeTimes[i][mostRecentConditioningIndices[i]]; + indexOfNextEmbeddingInterval += 1; for (int j = 1; j < condEmbedDims[i]; j++) { - conditioningPast[index_of_next_embedding_interval] = - conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j + 1] - - conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j]; - jointPast[index_of_next_embedding_interval] = - conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j + 1] - - conditionalSpikeTimes[i][most_recent_conditioning_indices[i] - j]; - index_of_next_embedding_interval += 1; + conditioningPast[indexOfNextEmbeddingInterval] = + 
conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - j + 1] - + conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - j]; + jointPast[indexOfNextEmbeddingInterval] = + conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - j + 1] - + conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - j]; + indexOfNextEmbeddingInterval += 1; } } // Add the embedding intervals from the source process (this only gets added to the joint embeddings) - jointPast[k + total_length_of_conditioning_embeddings] = pointsAtWhichToMakeEmbeddings[embedding_point_index] - - sourceSpikeTimes[most_recent_source_index]; + jointPast[k + totalLengthOfConditioningEmbeddings] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex] + - sourceSpikeTimes[mostRecentSourceIndex]; for (int i = 1; i < l; i++) { - jointPast[k + total_length_of_conditioning_embeddings + i] = sourceSpikeTimes[most_recent_source_index - i + 1] - - sourceSpikeTimes[most_recent_source_index - i]; + jointPast[k + totalLengthOfConditioningEmbeddings + i] = sourceSpikeTimes[mostRecentSourceIndex - i + 1] + - sourceSpikeTimes[mostRecentSourceIndex - i]; } // Add Gaussian noise, if necessary @@ -472,51 +502,85 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr } } - ConditioningEmbeddings.add(conditioningPast); + conditioningEmbeddings.add(conditioningPast); jointEmbeddings.add(jointPast); } } - protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes, - Vector ConditioningEmbeddingsFromSpikes, Vector jointEmbeddingsFromSpikes, - Vector ConditioningEmbeddingsFromSamples, Vector jointEmbeddingsFromSamples, - Vector processTimeLengths) - throws Exception { + + + protected int getFirstDestIndex(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes, Boolean setProcessTimeLengths) + throws Exception{ // First sort the spike times in case they were not properly in ascending order: Arrays.sort(sourceSpikeTimes); Arrays.sort(destSpikeTimes); - int first_target_index_of_embedding = k; - while (destSpikeTimes[first_target_index_of_embedding] <= sourceSpikeTimes[l - 1]) { - first_target_index_of_embedding++; + int firstTargetIndexOfEMbedding = k; + while (destSpikeTimes[firstTargetIndexOfEMbedding] <= sourceSpikeTimes[l - 1]) { + firstTargetIndexOfEMbedding++; } if (conditionalSpikeTimes.length != condEmbedDims.length) { throw new Exception("Number of conditional embedding lengths does not match the number of conditional processes"); } for (int i = 0; i < conditionalSpikeTimes.length; i++) { - while (destSpikeTimes[first_target_index_of_embedding] <= conditionalSpikeTimes[i][condEmbedDims[i]]) { - first_target_index_of_embedding++; + while (destSpikeTimes[firstTargetIndexOfEMbedding] <= conditionalSpikeTimes[i][condEmbedDims[i]]) { + firstTargetIndexOfEMbedding++; } } - //processTimeLengths.add(destSpikeTimes[sourceSpikeTimes.length - 1] - destSpikeTimes[first_target_index_of_embedding]); - processTimeLengths.add(destSpikeTimes[destSpikeTimes.length - 1] - destSpikeTimes[first_target_index_of_embedding]); + // We don't want to reset these lengths when resampling for surrogates + if (setProcessTimeLengths) { + processTimeLengths.add(destSpikeTimes[destSpikeTimes.length - 1] - destSpikeTimes[firstTargetIndexOfEMbedding]); + } - double sample_lower_bound = destSpikeTimes[first_target_index_of_embedding]; - double sample_upper_bound = destSpikeTimes[destSpikeTimes.length - 1]; - int num_samples = (int) 
Math.round(num_samples_multiplier * (destSpikeTimes.length - first_target_index_of_embedding + 1)); + return firstTargetIndexOfEMbedding; + } + + protected double[] generateRandomSampleTimes(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes, + double actualNumSamplesMultiplier, int firstTargetIndexOfEMbedding) { + + double sampleLowerBound = destSpikeTimes[firstTargetIndexOfEMbedding]; + double sampleUpperBound = destSpikeTimes[destSpikeTimes.length - 1]; + int num_samples = (int) Math.round(actualNumSamplesMultiplier * (destSpikeTimes.length - firstTargetIndexOfEMbedding + 1)); double[] randomSampleTimes = new double[num_samples]; Random rand = new Random(); for (int i = 0; i < randomSampleTimes.length; i++) { - randomSampleTimes[i] = sample_lower_bound + rand.nextDouble() * (sample_upper_bound - sample_lower_bound); + randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound); } Arrays.sort(randomSampleTimes); - makeEmbeddingsAtPoints(destSpikeTimes, first_target_index_of_embedding, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, - ConditioningEmbeddingsFromSpikes, jointEmbeddingsFromSpikes); + return randomSampleTimes; + } + + protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes, + Vector conditioningEmbeddingsFromSpikes, Vector jointEmbeddingsFromSpikes, + Vector conditioningEmbeddingsFromSamples, Vector jointEmbeddingsFromSamples, + double actualNumSamplesMultiplier) + throws Exception { + + int firstTargetIndexOfEMbedding = getFirstDestIndex(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, true); + double[] randomSampleTimes = generateRandomSampleTimes(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, actualNumSamplesMultiplier, firstTargetIndexOfEMbedding); + + makeEmbeddingsAtPoints(destSpikeTimes, firstTargetIndexOfEMbedding, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, + conditioningEmbeddingsFromSpikes, jointEmbeddingsFromSpikes); makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, - ConditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples); + conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples); + } + + /* + * Method to do the + */ + protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes, + Vector conditioningEmbeddingsFromSamples, Vector jointEmbeddingsFromSamples, + double actualNumSamplesMultiplier) + throws Exception { + + int firstTargetIndexOfEMbedding = getFirstDestIndex(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, false); + double[] randomSampleTimes = generateRandomSampleTimes(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, actualNumSamplesMultiplier, firstTargetIndexOfEMbedding); + + makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, + conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples); } /* @@ -553,28 +617,31 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr return new distanceAndNumPoints(maxDistance, i); } - /* - * (non-Javadoc) - * - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking# - * computeAverageLocalOfObservations() - */ @Override public double computeAverageLocalOfObservations() throws Exception { + return computeAverageLocalOfObservations(kdTreeJointAtSpikes, jointEmbeddingsFromSpikes); 
+ } + + + /* + * We take the actual joint tree at spikes (along with the associated embeddings) as an argument, as we will need to swap these out when + * computing surrogates. + */ + public double computeAverageLocalOfObservations(KdTree actualKdTreeJointAtSpikes, Vector actualJointEmbeddingsFromSpikes) throws Exception { double currentSum = 0; - for (int i = 0; i < ConditioningEmbeddingsFromSpikes.size(); i++) { + for (int i = 0; i < conditioningEmbeddingsFromSpikes.size(); i++) { - double radiusJointSpikes = kdTreeJointAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0]; + double radiusJointSpikes = actualKdTreeJointAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0]; double radiusJointSamples = kdTreeJointAtSamples.findKNearestNeighbours(Knns, - new double[][] { jointEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0]; + new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0]; /* The algorithm specified in box 1 of doi.org/10.1371/journal.pcbi.1008054 specifies finding the maximum of the two radii just calculated and then redoing the searches in both sets at this radius. In this implementation, however, we make use of the fact that one radius is equal to the maximum, and so only one search needs to be redone. */ - double eps = 0.01; + double eps = 0.0; // Need variables for the number of neighbours as this is now variable within the maximum radius int kJointSpikes = 0; int kJointSamples = 0; @@ -587,11 +654,11 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()]; boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()]; kdTreeJointAtSamples.findPointsWithinR(radiusJointSpikes + eps, - new double[][] { jointEmbeddingsFromSpikes.elementAt(i) }, + new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) }, true, isWithinR, indicesWithinR); - distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(jointEmbeddingsFromSpikes.elementAt(i), indicesWithinR, + distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(actualJointEmbeddingsFromSpikes.elementAt(i), indicesWithinR, jointEmbeddingsFromSamples); kJointSamples = temp.numPoints; radiusJointSamples = temp.distance; @@ -603,13 +670,13 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr kJointSamples = Knns; int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()]; boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()]; - kdTreeJointAtSpikes.findPointsWithinR(radiusJointSamples + eps, - new double[][] { jointEmbeddingsFromSpikes.elementAt(i) }, + actualKdTreeJointAtSpikes.findPointsWithinR(radiusJointSamples + eps, + new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) }, true, isWithinR, indicesWithinR); - distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(jointEmbeddingsFromSpikes.elementAt(i), indicesWithinR, - jointEmbeddingsFromSpikes); + distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(actualJointEmbeddingsFromSpikes.elementAt(i), indicesWithinR, + actualJointEmbeddingsFromSpikes); // -1 due to the point itself being in the set kJointSpikes = temp.numPoints - 1; radiusJointSpikes = temp.distance; @@ -618,58 +685,160 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr // Repeat the above steps, but in the conditioning (rather than joint) space. 
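// Annotation for review purposes only, not part of the patch: once the neighbour counts and
// radii have been settled in both the joint and the conditioning space, each target spike i
// adds a contribution of the form
//   psi(kJointSpikes) - psi(kJointSamples) + (k + l) * (ln rJointSamples - ln rJointSpikes)
//     - psi(kConditioningSpikes) + psi(kConditioningSamples)
//     + k * (ln rConditioningSpikes - ln rConditioningSamples)
// (psi denoting the digamma function), matching the currentSum update further below;
// the accumulated sum is finally divided by the total observation time (timeSum) to give a rate.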
double radiusConditioningSpikes = kdTreeConditioningAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0]; double radiusConditioningSamples = kdTreeConditioningAtSamples.findKNearestNeighbours(Knns, - new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0]; + new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0]; int kConditioningSpikes = 0; int kConditioningSamples = 0; if (radiusConditioningSpikes >= radiusConditioningSamples) { kConditioningSpikes = Knns; - int[] indicesWithinR = new int[ConditioningEmbeddingsFromSamples.size()]; - boolean[] isWithinR = new boolean[ConditioningEmbeddingsFromSamples.size()]; + int[] indicesWithinR = new int[conditioningEmbeddingsFromSamples.size()]; + boolean[] isWithinR = new boolean[conditioningEmbeddingsFromSamples.size()]; kdTreeConditioningAtSamples.findPointsWithinR(radiusConditioningSpikes + eps, - new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) }, + new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) }, true, isWithinR, indicesWithinR); - distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(ConditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR, - ConditioningEmbeddingsFromSamples); + distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(conditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR, + conditioningEmbeddingsFromSamples); kConditioningSamples = temp.numPoints; radiusConditioningSamples = temp.distance; } else { kConditioningSamples = Knns; - int[] indicesWithinR = new int[ConditioningEmbeddingsFromSamples.size()]; - boolean[] isWithinR = new boolean[ConditioningEmbeddingsFromSamples.size()]; + int[] indicesWithinR = new int[conditioningEmbeddingsFromSamples.size()]; + boolean[] isWithinR = new boolean[conditioningEmbeddingsFromSamples.size()]; kdTreeConditioningAtSpikes.findPointsWithinR(radiusConditioningSamples + eps, - new double[][] { ConditioningEmbeddingsFromSpikes.elementAt(i) }, + new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) }, true, isWithinR, indicesWithinR); - distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(ConditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR, - ConditioningEmbeddingsFromSpikes); + distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(conditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR, + conditioningEmbeddingsFromSpikes); // -1 due to the point itself being in the set kConditioningSpikes = temp.numPoints - 1; radiusConditioningSpikes = temp.distance; } + /* + * TODO + * The KdTree class defaults to the squared euclidean distance when the euclidean norm is specified. This is fine for Kraskov estimators + * (as the radii are never used, just the numbers of points within radii). It causes problems here though, as we do use the radii and the + * squared euclidean distance is not a distance metric. We get around this by just taking the square root here, but it might be better to + * fix this in the KdTree class. 
+ */ + if (normType == EuclideanUtils.NORM_EUCLIDEAN) { + radiusJointSpikes = Math.sqrt(radiusJointSpikes); + radiusJointSamples = Math.sqrt(radiusJointSamples); + radiusConditioningSpikes = Math.sqrt(radiusConditioningSpikes); + radiusConditioningSamples = Math.sqrt(radiusConditioningSamples); + } - currentSum += (MathsUtils.digamma(kJointSpikes) - MathsUtils.digamma(kJointSamples) + ((k + l) * (-Math.log(radiusJointSpikes) + Math.log(radiusJointSamples))) - MathsUtils.digamma(kConditioningSpikes) + MathsUtils.digamma(kConditioningSamples) + + (k * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples)))); if (Double.isNaN(currentSum)) { - throw new Exception(kJointSpikes + " " + kJointSamples + " " + kConditioningSpikes + " " + kConditioningSamples + "\n" + - radiusJointSpikes + " " + radiusJointSamples + " " + radiusConditioningSpikes + " " + radiusConditioningSamples); + throw new Exception("NaNs in TE clac"); } } // Normalise by time - double time_sum = 0; + double timeSum = 0; for (Double time : processTimeLengths) { - time_sum += time; + timeSum += time; } - currentSum /= time_sum; + currentSum /= timeSum; return currentSum; } + + @Override + public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue) throws Exception { + return computeSignificance(numPermutationsToCheck, + estimatedValue, + System.currentTimeMillis()); + } + + @Override + public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, + double estimatedValue, long randomSeed) throws Exception{ + + Random random = new Random(randomSeed); + double[] surrogateTEValues = new double[numPermutationsToCheck]; + + for (int permutationNumber = 0; permutationNumber < numPermutationsToCheck; permutationNumber++) { + Vector resampledConditioningEmbeddingsFromSamples = new Vector(); + Vector resampledJointEmbeddingsFromSamples = new Vector(); + + // Send all of the observations through: + Iterator sourceIterator = vectorOfSourceSpikeTimes.iterator(); + Iterator conditionalIterator = null; + if (vectorOfConditionalSpikeTimes.size() > 0) { + conditionalIterator = vectorOfConditionalSpikeTimes.iterator(); + } + int timeSeriesIndex = 0; + for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) { + double[] sourceSpikeTimes = sourceIterator.next(); + double[][] conditionalSpikeTimes = null; + if (vectorOfConditionalSpikeTimes.size() > 0) { + conditionalSpikeTimes = conditionalIterator.next(); + } else { + conditionalSpikeTimes = new double[][] {}; + } + processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, + resampledConditioningEmbeddingsFromSamples, resampledJointEmbeddingsFromSamples, + surrogateNumSamplesMultiplier); + } + + // Convert the vectors to arrays so that they can be put in the trees + double[][] arrayedResampledConditioningEmbeddingsFromSamples = new double[resampledConditioningEmbeddingsFromSamples.size()][k]; + for (int i = 0; i < resampledConditioningEmbeddingsFromSamples.size(); i++) { + arrayedResampledConditioningEmbeddingsFromSamples[i] = resampledConditioningEmbeddingsFromSamples.elementAt(i); + } + + KdTree resampledKdTreeConditioningAtSamples = new KdTree(arrayedResampledConditioningEmbeddingsFromSamples); + resampledKdTreeConditioningAtSamples.setNormType(normType); + + Vector conditionallyPermutedJointEmbeddingsFromSpikes = new Vector(jointEmbeddingsFromSpikes); + + Vector usedIndices = new Vector(); + for (int i = 0; i < 
conditionallyPermutedJointEmbeddingsFromSpikes.size(); i++) { + PriorityQueue neighbours = + resampledKdTreeConditioningAtSamples.findKNearestNeighbours(kPerm, + new double[][] {conditioningEmbeddingsFromSpikes.elementAt(i)}); + ArrayList foundIndices = new ArrayList(); + for (int j = 0; j < kPerm; j++) { + foundIndices.add(neighbours.poll().sampleIndex); + } + + ArrayList prunedIndices = new ArrayList(foundIndices); + prunedIndices.removeAll(usedIndices); + int chosenIndex = 0; + if (prunedIndices.size() > 0) { + chosenIndex = prunedIndices.get(random.nextInt(prunedIndices.size())); + } else { + chosenIndex = foundIndices.get(random.nextInt(foundIndices.size())); + } + usedIndices.add(chosenIndex); + int embeddingLength = conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i).length; + for(int j = 0; j < l; j++) { + conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i)[embeddingLength - l + j] = + resampledJointEmbeddingsFromSamples.elementAt(chosenIndex)[embeddingLength - l + j]; + } + + } + + double[][] arrayedConditionallyPermutedJointEmbeddingsFromSpikes = new double[conditionallyPermutedJointEmbeddingsFromSpikes.size()][]; + for (int i = 0; i < conditionallyPermutedJointEmbeddingsFromSpikes.size(); i++) { + arrayedConditionallyPermutedJointEmbeddingsFromSpikes[i] = conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i); + } + KdTree conditionallyPermutedKdTreeJointFromSpikes = new KdTree(arrayedConditionallyPermutedJointEmbeddingsFromSpikes); + conditionallyPermutedKdTreeJointFromSpikes.setNormType(normType); + + surrogateTEValues[permutationNumber] = computeAverageLocalOfObservations(conditionallyPermutedKdTreeJointFromSpikes, + conditionallyPermutedJointEmbeddingsFromSpikes); + } + return new EmpiricalMeasurementDistribution(surrogateTEValues, estimatedValue); + } + + /* * (non-Javadoc) * @@ -682,29 +851,6 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr return null; } - /* - * (non-Javadoc) - * - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking# - * computeSignificance(int) - */ - @Override - public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck) throws Exception { - // TODO Auto-generated method stub - return null; - } - - /* - * (non-Javadoc) - * - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking# - * computeSignificance(int[][]) - */ - @Override - public EmpiricalMeasurementDistribution computeSignificance(int[][] newOrderings) throws Exception { - // TODO Auto-generated method stub - return null; - } /* * (non-Javadoc) diff --git a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegrationOldRepresentation.java b/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegrationOldRepresentation.java deleted file mode 100644 index d6dad78..0000000 --- a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegrationOldRepresentation.java +++ /dev/null @@ -1,1103 +0,0 @@ -package infodynamics.measures.spiking.integration; - -import java.util.Arrays; -import java.util.Iterator; -import java.util.PriorityQueue; -import java.util.Vector; - -import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking; -import infodynamics.utils.EmpiricalMeasurementDistribution; -import infodynamics.utils.KdTree; -import infodynamics.utils.MathsUtils; -import infodynamics.utils.MatrixUtils; -import infodynamics.utils.NeighbourNodeData; -import 
infodynamics.utils.FirstIndexComparatorDouble; -import infodynamics.utils.UnivariateNearestNeighbourSearcher; - -/** - * Computes the transfer entropy between a pair of spike trains, - * using an integration-based measure in order to match the theoretical - * form of TE between such spike trains. - * - *

- * Usage paradigm is as per the interface {@link TransferEntropyCalculatorSpiking}
- * - * @author Joseph Lizier (email, - * www) - */ -public class TransferEntropyCalculatorSpikingIntegrationOldRepresentation implements - TransferEntropyCalculatorSpiking { - - /** - * Number of past destination spikes to consider (akin to embedding length) - */ - protected int k = 1; - /** - * Number of past source spikes to consider (akin to embedding length) - */ - protected int l = 1; - - /** - * Number of nearest neighbours to search for in the full joint space - */ - protected int Knns = 4; - - /** - * Storage for source observations supplied via {@link #addObservations(double[], double[])} etc. - */ - protected Vector vectorOfSourceSpikeTimes = null; - - /** - * Storage for destination observations supplied via {@link #addObservations(double[], double[])} etc. - */ - protected Vector vectorOfDestinationSpikeTimes = null; - - // constants for indexing our data storage - protected final static int PREV_DEST = 0; - protected final static int PREV_SOURCE = 1; - protected final static int PREV_POSSIBILITIES = 2; - protected final static int NEXT_DEST = 0; - protected final static int NEXT_SOURCE = 1; - protected final static int NEXT_POSSIBILITIES = 2; - - /** - * Cache of the timing data for each new observed spiking event in both the source - * and destination - */ - Vector[][] eventTimings = null; - /** - * Cache of the timing data for each new observed spiking event for the - * destination only - */ - Vector destPastAndNextTimings = null; - /** - * Cache of the type of event for each new observed spiking event in both the source - * and destination (i.e. which spiked previously, which spiked next - */ - Vector eventTypeLocator = null; - /** - * Cache for each new observed spiking event of which index it has in the vector - * of spiking events of the same type - */ - Vector eventIndexLocator = null; - /** - * Cache for each time-series of observed spiking events of how many - * observations were in that set. 
- */ - Vector numEventsPerObservationSet = null; - - /** - * KdTrees for searching the joint past spaces and time to next spike, - * for each combination of which spiked previously and next - */ - protected KdTree[][] kdTreesJoint = null; - - /** - * KdTrees for searching the joint past spaces, - * for each combination of which spiked previously and next - */ - protected KdTree[][] kdTreesSourceDestHistories = null; - - /** - * KdTrees for searching the past destination space and time to next spike - */ - protected KdTree kdTreeDestNext = null; - - /** - * KdTrees for searching the past destination space - */ - protected KdTree kdTreeDestHistory = null; - - /** - * NN searcher for the time to next spike space only, if required - */ - protected UnivariateNearestNeighbourSearcher nnSearcherDestTimeToNextSpike = null; - - /** - * Property name for the number of nearest neighbours to search - */ - public static final String KNNS_PROP_NAME = "Knns"; - - - /** - * Stores whether we are in debug mode - */ - protected boolean debug = false; - - public TransferEntropyCalculatorSpikingIntegrationOldRepresentation() { - super(); - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int) - */ - @Override - public void initialise() throws Exception { - initialise(k,l); - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int) - */ - @Override - public void initialise(int k) throws Exception { - initialise(k,this.l); - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(int, int) - */ - @Override - public void initialise(int k, int l) throws Exception { - if ((k < 1) || (l < 1)) { - throw new Exception("Zero history length not supported"); - } - this.k = k; - this.l = l; - vectorOfSourceSpikeTimes = null; - vectorOfDestinationSpikeTimes = null; - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setProperty(java.lang.String, java.lang.String) - */ - @Override - public void setProperty(String propertyName, String propertyValue) - throws Exception { - boolean propertySet = true; - if (propertyName.equalsIgnoreCase(K_PROP_NAME)) { - k = Integer.parseInt(propertyValue); - } else if (propertyName.equalsIgnoreCase(L_PROP_NAME)) { - l = Integer.parseInt(propertyValue); - } else if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) { - Knns = Integer.parseInt(propertyValue); - } else { - // No property was set on this class - propertySet = false; - } - if (debug && propertySet) { - System.out.println(this.getClass().getSimpleName() + ": Set property " + propertyName + - " to " + propertyValue); - } - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getProperty(java.lang.String) - */ - @Override - public String getProperty(String propertyName) throws Exception { - if (propertyName.equalsIgnoreCase(K_PROP_NAME)) { - return Integer.toString(k); - } else if (propertyName.equalsIgnoreCase(L_PROP_NAME)) { - return Integer.toString(l); - } else if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) { - return Integer.toString(Knns); - } else { - // No property matches for this class - return null; - } - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setObservations(double[], double[]) - */ - @Override - public void setObservations(double[] source, double[] destination) - throws Exception { - startAddObservations(); - 
addObservations(source, destination); - finaliseAddObservations(); - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#startAddObservations() - */ - @Override - public void startAddObservations() { - vectorOfSourceSpikeTimes = new Vector(); - vectorOfDestinationSpikeTimes = new Vector(); - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#addObservations(double[], double[]) - */ - @Override - public void addObservations(double[] source, double[] destination) - throws Exception { - // Store these observations in our vector for now - vectorOfSourceSpikeTimes.add(source); - vectorOfDestinationSpikeTimes.add(destination); - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#finaliseAddObservations() - */ - @Override - public void finaliseAddObservations() throws Exception { - // TODO Auto embed if required - // preFinaliseAddObservations(); - - // Run through each spiking time series set and pull out the observation - // tuples we'll store. - // Initialise our data stores: - eventTimings = new Vector[PREV_POSSIBILITIES][NEXT_POSSIBILITIES]; - for (int prev = 0; prev < PREV_POSSIBILITIES; prev++) { - for (int next = 0; next < NEXT_POSSIBILITIES; next++) { - eventTimings[prev][next] = new Vector(); - } - } - destPastAndNextTimings = new Vector(); - eventTypeLocator = new Vector(); - eventIndexLocator = new Vector(); - numEventsPerObservationSet = new Vector(); - - // Send all of the observations through: - Iterator sourceIterator = vectorOfSourceSpikeTimes.iterator(); - int timeSeriesIndex = 0; - for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) { - double[] sourceSpikeTimes = sourceIterator.next(); - timeSeriesIndex++; - - processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, - timeSeriesIndex, eventTimings, destPastAndNextTimings, - eventTypeLocator, eventIndexLocator, numEventsPerObservationSet); - } - - // Now we have collected all the events. - // Load up the search structures: - // 1. Full joint space: - // 2. Histories of source and dest only: - kdTreesJoint = new KdTree[PREV_POSSIBILITIES][NEXT_POSSIBILITIES]; - kdTreesSourceDestHistories = new KdTree[PREV_POSSIBILITIES][NEXT_POSSIBILITIES]; - for (int prev = 0; prev < PREV_POSSIBILITIES; prev++) { - for (int next = 0; next < NEXT_POSSIBILITIES; next++) { - // This line does not work: - // double[][][] jointEventTimings = (double[][][]) eventTimings[prev][next].toArray(); - // So we'll do it manually: - double[][] sourcePastTimings = new double[eventTimings[prev][next].size()][]; - double[][] destPastTimings = new double[eventTimings[prev][next].size()][]; - double[][] nextTimings = new double[eventTimings[prev][next].size()][]; - int i = 0; - for (double[][] timing : eventTimings[prev][next]) { - sourcePastTimings[i] = timing[0]; - destPastTimings[i] = timing[1]; - nextTimings[i] = timing[2]; - i++; - } - // TODO Should we normalise before we supply to the KdTree? - // Think about this later. I'm not convinced it's the best - // approach in this particular case. - kdTreesJoint[prev][next] = new KdTree( - new int[] {prev == PREV_DEST ? l : l - 1, - prev == PREV_DEST ? k - 1 : k, - 1}, - new double[][][] {sourcePastTimings, destPastTimings, nextTimings}); - kdTreesSourceDestHistories[prev][next] = new KdTree( - new int[] {prev == PREV_DEST ? l : l - 1, - prev == PREV_DEST ? k - 1 : k}, - new double[][][] {sourcePastTimings, destPastTimings}); - } - } - // 3. 
For the dest past and time to next spike - // 4. For the dest past only - double[][] destPastOnlyTimings = new double[destPastAndNextTimings.size()][]; - double[][] nextTimingsForDestPastOnly = new double[destPastAndNextTimings.size()][]; - int i = 0; - for (double[][] timing : destPastAndNextTimings) { - destPastOnlyTimings[i] = timing[0]; - nextTimingsForDestPastOnly[i] = timing[1]; - i++; - } - kdTreeDestNext = new KdTree( - new int[] {k - 1, 1}, - new double[][][] {destPastOnlyTimings, nextTimingsForDestPastOnly}); - - if (k == 1) { - // We need an NN searcher for the time to next spike (dest only) - nnSearcherDestTimeToNextSpike = new UnivariateNearestNeighbourSearcher(nextTimingsForDestPastOnly); - } else { - kdTreeDestHistory = new KdTree(destPastOnlyTimings); - } - } - - protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, - int timeSeriesIndex, Vector[][] eventTimings, - Vector destPastAndNextTimings, Vector eventTypeLocator, - Vector eventIndexLocator, Vector numEventsPerObservationSet) throws Exception { - // addObservationsAfterParamsDetermined(sourceSpikeTimes, destSpikeTimes); - - // First sort the spike times in case they were not properly in ascending order: - Arrays.sort(sourceSpikeTimes); - Arrays.sort(destSpikeTimes); - - // Scan to find the indices by which we have k and l spikes for dest and source - // respectively - int dest_index = k - 1; - int source_index = l - 1; - boolean previousIsDest = false; - double[] spikeTimesForPreviousSpiker = sourceSpikeTimes; - if (sourceSpikeTimes[source_index] > destSpikeTimes[dest_index]) { - // Minimum required Source spikes are later than the dest. - previousIsDest = false; - spikeTimesForPreviousSpiker = sourceSpikeTimes; - // Need to advance dest_index until it's the most recent before source_index - for(;dest_index < destSpikeTimes.length; dest_index++) { - if (destSpikeTimes[dest_index] > sourceSpikeTimes[source_index]) { - // We've gone past the set of source spikes we have, we - // can move back one in the dest series - dest_index--; - break; - } - } - if (dest_index == destSpikeTimes.length) { - // We didn't have enough spikes in this series to generate any observations - // TODO work out how to handle this later -- I think this is ok - numEventsPerObservationSet.add(0); - return; - // throw new Exception("Dest spikes stop before enough source spikes in time-series " + timeSeriesIndex); - } - } else { - // Minimum required Dest spikes are later than the source. - previousIsDest = true; - spikeTimesForPreviousSpiker = destSpikeTimes; - // Need to advance source_index until it's the most recent before dest_index - for(;source_index < sourceSpikeTimes.length; source_index++) { - if (sourceSpikeTimes[source_index] > destSpikeTimes[dest_index]) { - // We've gone past the set of dest spikes we have, we - // can move back one in the source series - source_index--; - break; - } - } - if (source_index == sourceSpikeTimes.length) { - // We didn't have enough spikes in this series to generate any observations - // TODO work out how to handle this later -- I think this is ok - numEventsPerObservationSet.add(0); - return; - // throw new Exception("Source spikes stop before enough dest spikes in time-series " + timeSeriesIndex); - } - } - // Post-condition: dest_index and source_index are set correctly for the first set of pasts - int indexForPreviousSpiker = previousIsDest ? 
dest_index : source_index; - - double timeToNextSpike; - boolean nextIsDest = false; - double[] spikeTimesForNextSpiker = sourceSpikeTimes; - double timeOfPrevSpike = spikeTimesForPreviousSpiker[indexForPreviousSpiker]; - int numEvents = 0; - while(true) { - // 0. Check whether we're finished - if ((source_index == sourceSpikeTimes.length - 1) && - (dest_index == destSpikeTimes.length - 1)) { - // We have no next spike so we can't take an observation here - // and we're done - break; - } - // Otherwise: - // 1. Determine which of source / dest fires next - if (source_index == sourceSpikeTimes.length - 1) { - nextIsDest = true; - } else if (dest_index == destSpikeTimes.length - 1) { - nextIsDest = false; - } else if (sourceSpikeTimes[source_index+1] < destSpikeTimes[dest_index+1]) { - nextIsDest = false; - } else { - nextIsDest = true; - } - spikeTimesForNextSpiker = nextIsDest ? destSpikeTimes : sourceSpikeTimes; - int indexForNextSpiker = nextIsDest ? dest_index : source_index; - timeToNextSpike = spikeTimesForNextSpiker[indexForNextSpiker+1] - timeOfPrevSpike; - // 2. Embed the past spikes - double[] sourcePast = new double[previousIsDest ? l : l - 1]; - double[] destPast = new double[previousIsDest ? k - 1 : k]; - /* if (debug) { - System.out.println("previousIsDest = " + previousIsDest + " and nextIsDest = " + nextIsDest); - }*/ - if (previousIsDest) { - sourcePast[0] = timeOfPrevSpike - - sourceSpikeTimes[source_index]; - } else { - destPast[0] = timeOfPrevSpike - - destSpikeTimes[dest_index]; - } - for (int i = 1; i < k; i++) { - destPast[previousIsDest ? i - 1 : i] = destSpikeTimes[dest_index - i + 1] - - destSpikeTimes[dest_index - i]; - } - for (int i = 1; i < l; i++) { - sourcePast[previousIsDest ? i : i - 1] = sourceSpikeTimes[source_index - i + 1] - - sourceSpikeTimes[source_index - i]; - } - // 3. Store these embedded observations - double[][] observations = new double[][]{sourcePast, destPast, - new double[] {timeToNextSpike}}; - if (debug) { - System.out.printf("Adding event %d with: timeToNextSpike=%.4f, sourceSpikeTimes=", numEvents, timeToNextSpike); - MatrixUtils.printArray(System.out, sourcePast, 3); - System.out.printf(", destSpikeTimes="); - MatrixUtils.printArray(System.out, destPast, 3); - System.out.println(); - } - // Add the index locator first so it gets the index correct before - // we add the new event in: - eventIndexLocator.add(eventTimings[previousIsDest ? PREV_DEST : PREV_SOURCE][nextIsDest ? NEXT_DEST : NEXT_SOURCE].size()); - eventTimings[previousIsDest ? PREV_DEST : PREV_SOURCE][nextIsDest ? NEXT_DEST : NEXT_SOURCE].add(observations); - eventTypeLocator.add(new int[] {previousIsDest ? PREV_DEST : PREV_SOURCE, - nextIsDest ? NEXT_DEST : NEXT_SOURCE}); - // And finally store the observations for the dest only - // search structure if required: - if (nextIsDest) { - double[][] destOnlyObservations; - if (previousIsDest) { - destOnlyObservations = new double[][] { - destPast, - new double[] {timeToNextSpike} - }; - } else { - // previous is source: - // We can take a copy of destPast, removing the first entry - // (since this only signals time the dest last fired before the source) - // and add that entry to the timeToNextSpike (which was back to the - // source firing). 
- double[] destPastOnly = Arrays.copyOfRange(destPast, 1, destPast.length); - double timeToNextSpikeSincePreviousDestSpike = - destPast[0] + timeToNextSpike; - destOnlyObservations = new double[][] { - destPastOnly, - new double[] {timeToNextSpikeSincePreviousDestSpike} - }; - } - destPastAndNextTimings.add(destOnlyObservations); - } - // 4. Reset prev as next ... - previousIsDest = nextIsDest; - if (previousIsDest) { - dest_index++; - } else { - source_index++; - } - spikeTimesForPreviousSpiker = previousIsDest ? destSpikeTimes : sourceSpikeTimes; - indexForPreviousSpiker = previousIsDest ? dest_index : source_index; - timeOfPrevSpike = spikeTimesForPreviousSpiker[indexForPreviousSpiker]; - numEvents++; - } - numEventsPerObservationSet.add(numEvents); - if (debug) { - System.out.printf("Finished processing %d source-target events for observation set %d\n", numEvents, timeSeriesIndex); - } - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getAddedMoreThanOneObservationSet() - */ - @Override - public boolean getAddedMoreThanOneObservationSet() { - return (vectorOfDestinationSpikeTimes != null) && - (vectorOfDestinationSpikeTimes.size() > 1); - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeAverageLocalOfObservations() - */ - @Override - public double computeAverageLocalOfObservations() throws Exception { - - int numberOfEvents = eventTypeLocator.size(); - - double te = 0; - double contributionFromSpikes = 0; - double totalTimeLength = 0; - - double digammaK = MathsUtils.digamma(Knns); - double inverseKTerm = 2.0 / (double) k; - - // Create temporary storage for arrays used in the neighbour counting: - boolean[] isWithinR = new boolean[numberOfEvents]; // dummy, we don't really use this - int[] indicesWithinR = new int[numberOfEvents]; - - // Iterate over all the spiking events: - Iterator eventIndexIterator = eventIndexLocator.iterator(); - int eventIndex = -1; - int indexForNextIsDest = -1; - for (int[] eventType : eventTypeLocator) { - eventIndex++; - int eventIndexWithinType = eventIndexIterator.next().intValue(); - double[][] thisEventTimings = eventTimings[eventType[0]][eventType[1]].elementAt(eventIndexWithinType); - totalTimeLength += thisEventTimings[2][0]; - // Find the Knns nearest neighbour matches to this event, - // with the same previous spiker and the next. - // TODO Add dynamic exclusion time later - PriorityQueue nnPQ = - kdTreesJoint[eventType[0]][eventType[1]].findKNearestNeighbours( - Knns, eventIndexWithinType); - - // Find eps_{x,y,z} as the maximum x, y and z norms amongst this set: - double radius_sourcePast = 0.0; - double radius_destPast = 0.0; - double radius_destNext = 0.0; - int radius_destNext_sampleIndex = -1; - for (int j = 0; j < k; j++) { - // Take the furthest remaining of the nearest neighbours from the PQ: - NeighbourNodeData nnData = nnPQ.poll(); - if (nnData.norms[0] > radius_sourcePast) { - radius_sourcePast = nnData.norms[0]; - } - if (nnData.norms[1] > radius_destPast) { - radius_destPast = nnData.norms[1]; - } - if (nnData.norms[2] > radius_destNext) { - radius_destNext = nnData.norms[2]; - radius_destNext_sampleIndex = nnData.sampleIndex; - } - } - // TODO Do we need to correct radius_destNext to have a different value below, chopping it - // where it pushes into negative times (i.e. *before* the previous spike)? 
- - if (debug && (eventIndex < 10000)) { - // Pull out the data for this observation: - System.out.print("index = " + eventIndex + ", " + - eventIndexWithinType + " for " + - (eventType[0] == PREV_DEST ? "dst" : "src") + - "->" + - (eventType[1] == NEXT_DEST ? "dst" : "src") + - ", timings: src: "); - MatrixUtils.printArray(System.out, thisEventTimings[0], 3); - System.out.print(", dest: "); - MatrixUtils.printArray(System.out, thisEventTimings[1], 3); - System.out.print(", time to next: "); - MatrixUtils.printArray(System.out, thisEventTimings[2], 3); - System.out.printf("index=%d: K=%d NNs at next_range %.5f (point %d)", eventIndexWithinType, Knns, radius_destNext, radius_destNext_sampleIndex); - } - - // Select only events where the destination spiked next: - if (eventType[1] == NEXT_DEST) { - indexForNextIsDest++; - - // Now find the matching samples in each sub-space; - // first match dest history and source history, with a next spike in dest: - kdTreesSourceDestHistories[eventType[0]][NEXT_DEST]. - findPointsWithinRs(eventIndexWithinType, - new double[] {radius_sourcePast, radius_destPast}, 0, - true, isWithinR, indicesWithinR); - // And check which of these samples had spike time in dest after ours: - int countOfDestNextAndGreater = 0; - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[eventType[0]][NEXT_DEST].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] > thisEventTimings[2][0] + radius_destNext) { - // This sample had a matched history and next spike was a destination - // spike with a longer interval than the current sample - countOfDestNextAndGreater++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - // And count how many samples with the matching history actually had a - // *source* spike next, after ours. - // Note that we now must go to the other kdTree for next source spike - kdTreesSourceDestHistories[eventType[0]][NEXT_SOURCE]. - findPointsWithinRs( - new double[] {radius_sourcePast, radius_destPast}, thisEventTimings, - true, isWithinR, indicesWithinR); - // And check which of these samples had spike time in source at or after ours: - int countOfSourceNextAndGreater = 0; - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[eventType[0]][NEXT_SOURCE].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] >= thisEventTimings[2][0] - radius_destNext) { - // This sample had a matched history and next spike was a source - // spike with an interval longer than or considered equal to the current sample. - // (The "equal to" is why we look for matches within radius_destNext here as well.) - countOfSourceNextAndGreater++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - - if (debug && (eventIndex < 10000)) { - System.out.printf(" of %d + %d + %d points with matching S-D history", - Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater); - } - - // Now find the matching samples in the dest history and - // with a next spike timing. 
- // Construct the appropriate timings to compare to here: - double timeToNextSpikeSincePreviousDestSpike; - if (eventType[0] == PREV_DEST) { - timeToNextSpikeSincePreviousDestSpike = thisEventTimings[2][0]; - } else { - // previous is source: - timeToNextSpikeSincePreviousDestSpike = - thisEventTimings[1][0] + thisEventTimings[2][0]; - } - int countOfDestNextAndGreaterMatchedDest = 0; - int countOfDestNextMatched = 0; - if (k > 1) { - // Search only the space of dest past -- no point - // searching dest past and next, since we need to run through - // all matches of dest past to count those with greater next spike - // times we might as well count those with matching spike times - // while we're at it. - kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radius_destPast, - true, isWithinR, indicesWithinR); - // And check which of these samples had next spike time after ours: - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the dest history space - double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[1][0] >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) { - // This sample had a matched history and next spike was a - // spike with an interval longer than or considered equal to the current sample. - // (The "equal to" is why we look for matches within kthNnData.distance here as well.) - countOfDestNextAndGreaterMatchedDest++; - if (matchedHistoryEventTimings[1][0] <= timeToNextSpikeSincePreviousDestSpike + radius_destNext) { - // Then we also have a match on the next spike itself - countOfDestNextMatched++; - } - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - } else { - // We don't take any past dest spike ISIs into account, so we just need to look at the proportion of next - // spike times that match. - countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsWithinOrOnR(indexForNextIsDest, radius_destNext); - countOfDestNextAndGreaterMatchedDest = countOfDestNextMatched + - nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest, radius_destNext, true); - } - - if (debug && (eventIndex < 10000)) { - System.out.printf(", and %d of %d points for D history only; ", - countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest); - } - - // With these neighbours counted, we're ready to compute the probability of the spike given the past - // of source and dest. 
- double logPGivenSourceAndDest = digammaK - inverseKTerm - - MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater) - + 1.0 / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater); - double logPGivenDest = MathsUtils.digamma(countOfDestNextMatched) - - MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest) - + 1.0 / ((double) countOfDestNextAndGreaterMatchedDest); - if (debug && (eventIndex < 10000)) { - System.out.printf(" te ~~ log (%d/%d)/(%d/%d) = %.4f -> %.4f (inferred rates %.4f vs %.4f)\n", Knns, - Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater, - countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest, - Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) / - ((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest))), - logPGivenSourceAndDest - logPGivenDest, - (double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater) / (2.0*radius_destNext), - (double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest) / (2.0*radius_destNext)); - } - contributionFromSpikes += logPGivenSourceAndDest - logPGivenDest; - } else { - if (debug) { - System.out.println(); - } - } - - } - contributionFromSpikes /= totalTimeLength; - te = contributionFromSpikes; - return te; - } - - public double computeAverageLocalOfObservationsAlg1() throws Exception { - - int numberOfEvents = eventTypeLocator.size(); - - double te = 0; - double contributionFromSpikes = 0; - double contributionFromNonSpikes = 0; - double contributionFromNonSpikes_destOnly = 0; - double contributionFromNonSpikes_destAndSource = 0; - double totalTimeLength = 0; - - // Create temporary storage for arrays used in the neighbour counting: - boolean[] isWithinR = new boolean[numberOfEvents]; // dummy, we don't really use this - int[] indicesWithinR = new int[numberOfEvents]; - - // Iterate over all the spiking events: - Iterator eventIndexIterator = eventIndexLocator.iterator(); - int eventIndex = -1; - int indexForNextIsDest = -1; - for (int[] eventType : eventTypeLocator) { - eventIndex++; - int eventIndexWithinType = eventIndexIterator.next().intValue(); - double[][] thisEventTimings = eventTimings[eventType[0]][eventType[1]].elementAt(eventIndexWithinType); - totalTimeLength += thisEventTimings[2][0]; - // Find the Knns nearest neighbour matches to this event, - // with the same previous spiker and the next. - // TODO Add dynamic exclusion time later - PriorityQueue nnPQ = - kdTreesJoint[eventType[0]][eventType[1]].findKNearestNeighbours( - Knns, eventIndexWithinType); - // First element in the PQ is the kth NN, - // and epsilon = kthNnData.distance - NeighbourNodeData kthNnData = nnPQ.poll(); - double radiusToKnn = kthNnData.distance; - if (debug && (eventIndex < 10000)) { - // Pull out the data for this observation: - System.out.print("index = " + eventIndex + ", " + - eventIndexWithinType + " for " + - (eventType[0] == PREV_DEST ? "dst" : "src") + - "->" + - (eventType[1] == NEXT_DEST ? 
"dst" : "src") + - ", timings: src: "); - MatrixUtils.printArray(System.out, thisEventTimings[0], 3); - System.out.print(", dest: "); - MatrixUtils.printArray(System.out, thisEventTimings[1], 3); - System.out.print(", time to next: "); - MatrixUtils.printArray(System.out, thisEventTimings[2], 3); - System.out.printf("index=%d: K=%d NNs at range %.5f (point %d)", eventIndexWithinType, Knns, radiusToKnn, kthNnData.sampleIndex); - } - - // Select only events where the destination spiked next: - if (eventType[1] == NEXT_DEST) { - indexForNextIsDest++; - - // Now find the matching samples in each sub-space; - // first match dest history and source history, with a next spike in dest: - kdTreesSourceDestHistories[eventType[0]][NEXT_DEST]. - findPointsWithinR(eventIndexWithinType, radiusToKnn, 0, - false, isWithinR, indicesWithinR); - // And check which of these samples had spike time in dest after ours: - int countOfDestNextAndGreater = 0; - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[eventType[0]][NEXT_DEST].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] >= thisEventTimings[2][0] + radiusToKnn) { - // This sample had a matched history and next spike was a destination - // spike with a longer interval than the current sample - countOfDestNextAndGreater++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - // And count how many samples with the matching history actually had a - // *source* spike next, after ours. - // Note that we now must go to the other kdTree for next source spike - kdTreesSourceDestHistories[eventType[0]][NEXT_SOURCE]. - findPointsWithinR(radiusToKnn, thisEventTimings, - false, isWithinR, indicesWithinR); - // And check which of these samples had spike time in source after ours: - int countOfSourceNextAndGreater = 0; - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[eventType[0]][NEXT_SOURCE].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] > thisEventTimings[2][0] - radiusToKnn) { - // This sample had a matched history and next spike was a source - // spike with an interval longer than or considered equal to the current sample. - // (The "equal to" is why we look for matches within kthNnData.distance here as well.) - countOfSourceNextAndGreater++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - - if (debug && (eventIndex < 10000)) { - System.out.printf(" of %d + %d + %d points with matching S-D history", - Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater); - } - - // Now find the matching samples in the dest history and - // with a next spike timing. 
- // Construct the appropriate timings to compare to here: - double[][] destOnlyObservations; - double[][] destPastOnlyObservations; - double timeToNextSpikeSincePreviousDestSpike; - if (eventType[0] == PREV_DEST) { - destOnlyObservations = new double[][] { - thisEventTimings[1], // timing of past dest spikes - thisEventTimings[2] // time to next spike - }; - timeToNextSpikeSincePreviousDestSpike = thisEventTimings[2][0]; - destPastOnlyObservations = new double[][] { - thisEventTimings[1] // timing of past dest spikes - }; - } else { - // previous is source: - // We can take a copy of the dest past timings, removing the first entry - // (since this only signals time the dest last fired before the source) - // and add that entry to the timeToNextSpike (which was back to the - // source firing). - double[] destPastOnly = Arrays.copyOfRange( - thisEventTimings[1], 1, thisEventTimings[1].length); - timeToNextSpikeSincePreviousDestSpike = - thisEventTimings[1][0] + thisEventTimings[2][0]; - destOnlyObservations = new double[][] { - destPastOnly, - new double[] {timeToNextSpikeSincePreviousDestSpike} - }; - destPastOnlyObservations = new double[][] { - destPastOnly - }; - } - int countOfDestNextAndGreaterMatchedDest = 0; - int countOfDestNextMatched = 0; - if (k > 1) { - // Search only the space of dest past -- no point - // searching dest past and next, since we need to run through - // all matches of dest past to count those with greater next spike - // times we might as well count those with matching spike times - // while we're at it. - - // OLD WAY: - // NO NO NO -- Can't search for it this way, because it's biased -- - // should search for it by giving the index of this dest past-next - // observation, so that it doesn't match to this observation. - // Should be able to use indexForNextIsDest here - //kdTreeDestHistory.findPointsWithinR(radiusToKnn, destPastOnlyObservations, - // false, isWithinR, indicesWithinR); - // Proper way: - kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radiusToKnn, - false, isWithinR, indicesWithinR); - // And check which of these samples had next spike time after ours: - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the dest history space - double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[1][0] > timeToNextSpikeSincePreviousDestSpike - radiusToKnn) { - // This sample had a matched history and next spike was a - // spike with an interval longer than or considered equal to the current sample. - // (The "equal to" is why we look for matches within kthNnData.distance here as well.) - countOfDestNextAndGreaterMatchedDest++; - if (matchedHistoryEventTimings[1][0] < timeToNextSpikeSincePreviousDestSpike + radiusToKnn) { - // Then we also have a match on the next spike itself - countOfDestNextMatched++; - } - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - } else { - // We don't take any past dest spike times into account, so we just need to look at the proportion of next - // spike times that match. 
- countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsStrictlyWithinR(indexForNextIsDest, radiusToKnn); - countOfDestNextAndGreaterMatchedDest = countOfDestNextMatched + - nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest, radiusToKnn, false); - } - - if (debug && (eventIndex < 10000)) { - System.out.printf(", and %d of %d points for D history only; ", - countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest); - } - - // With these neighbours counted, we're ready to compute the probability of the spike given the past - // of source and dest. - // Digammas for algorithm 1 include the extra "+1" on all terms except - // for the full joint space - double logPGivenSourceAndDest = MathsUtils.digamma(Knns) - - MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater + 1); - double logPGivenDest = MathsUtils.digamma(countOfDestNextMatched + 1) - - MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest + 1); - if (debug && (eventIndex < 10000)) { - System.out.printf(" te ~~ log (%d/%d)/(%d/%d) = %.4f -> %.4f (inferred rates %.4f vs %.4f)\n", Knns, - Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater + 1, - countOfDestNextMatched + 1, countOfDestNextAndGreaterMatchedDest + 1, - Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater + 1)) / - ((double) (countOfDestNextMatched + 1) / (double) (countOfDestNextAndGreaterMatchedDest + 1))), - logPGivenSourceAndDest - logPGivenDest, - (double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater + 1) / (2.0*radiusToKnn), - (double) (countOfDestNextMatched + 1) / (double) (countOfDestNextAndGreaterMatchedDest + 1) / (2.0*radiusToKnn)); - } - contributionFromSpikes += logPGivenSourceAndDest - logPGivenDest; - } else { - if (debug) { - System.out.println(); - } - } - - // Regardless of which type of event it was, we need to integrate - // the spiking rates up until the next spiking event - // Our first attempt at a solution uses the search width defined - // using the history and the next spike. - - // Consider first the destination process only. - // Match dest history - double[][] destPastOnlyObservations; - double timeToNextSpikeSincePreviousDestSpike; - if (eventType[0] == PREV_DEST) { - timeToNextSpikeSincePreviousDestSpike = thisEventTimings[2][0]; - destPastOnlyObservations = new double[][] { - thisEventTimings[1] // timing of past dest spikes - }; - } else { - // previous is source: - // We can take a copy of the dest past timings, removing the first entry - // (since this only signals time the dest last fired before the source) - // and add that entry to the timeToNextSpike (which was back to the - // source firing). 
- double[] destPastOnly = Arrays.copyOfRange( - thisEventTimings[1], 1, thisEventTimings[1].length); - timeToNextSpikeSincePreviousDestSpike = - thisEventTimings[1][0] + thisEventTimings[2][0]; - destPastOnlyObservations = new double[][] { - destPastOnly - }; - } - int countOfDestNextEarlier = 0; - int countOfDestMatches = 0; - if (k > 1) { - kdTreeDestHistory.findPointsWithinR(radiusToKnn, destPastOnlyObservations, - false, isWithinR, indicesWithinR); - // And check which of these samples had next spike time before ours: - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the dest history space - double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[1][0] < timeToNextSpikeSincePreviousDestSpike) { - // This sample had a matched history and next spike was a - // spike with an interval shorter than the current sample. - countOfDestNextEarlier++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - countOfDestMatches++; - } - } else { - // We're not using the past, so we match on everything up to the last spike - countOfDestMatches = nnSearcherDestTimeToNextSpike.getNumObservations(); - countOfDestNextEarlier = nnSearcherDestTimeToNextSpike.countPointsSmallerAndOutsideR( - // indexForNextIsDest must point to the next event (possibly this one) where the dest spikes next. - (eventType[1] == NEXT_DEST) ? indexForNextIsDest : indexForNextIsDest + 1, - radiusToKnn, false); - - } - // And include the contribution for each of these - double integralForDestHistorySpace = 0; - for (int hi = 0; hi < countOfDestNextEarlier; hi++) { - integralForDestHistorySpace += (double) 1 / - (double) (countOfDestMatches - hi); - } - - // First match dest history and source history, with a next spike in dest: - kdTreesSourceDestHistories[eventType[0]][NEXT_DEST]. - findPointsWithinR(radiusToKnn, thisEventTimings, - false, isWithinR, indicesWithinR); - // And store which of these samples had spike time in dest before ours. - // Store them in a vector of double arrays, with each array holding the - // spike time then -1 for a next dest spike and +1 for a source spike - Vector spikesBeforeOurs = new Vector(); - int countOfSpikesAfterAndIncludingOurs = 0; - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[eventType[0]][NEXT_DEST].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] < thisEventTimings[2][0]) { - // This sample had a matched history and next spike was a destination - // spike with a shorted interval than the current sample - spikesBeforeOurs.add(new double[] { - matchedHistoryEventTimings[2][0], -1}); - } else { - countOfSpikesAfterAndIncludingOurs++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - // And store which of these samples had spike time in source before ours. - // Store them in a vector of double arrays, with each array holding the - // spike time then -1 for a next dest spike and +1 for a source spike - // Note that we now must go to the other kdTree for next source spike - kdTreesSourceDestHistories[eventType[0]][NEXT_SOURCE]. 
- findPointsWithinR(radiusToKnn, thisEventTimings, - false, isWithinR, indicesWithinR); - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[eventType[0]][NEXT_SOURCE].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] < thisEventTimings[2][0]) { - // This sample had a matched history and next spike was a source - // spike with a shorted interval than the current sample - spikesBeforeOurs.add(new double[] { - matchedHistoryEventTimings[2][0], +1}); - } else { - countOfSpikesAfterAndIncludingOurs++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - // Now we can sort the spikes which occur before ours and process - // them in order: - // Next line doesn't work, so replaced with clunkier code: - // double[][] nextSpikeTimesAndType = (double[][]) spikesBeforeOurs.toArray(); - double[][] nextSpikeTimesAndType = new double[spikesBeforeOurs.size()][]; - for (int si = 0; si < nextSpikeTimesAndType.length; si++) { - nextSpikeTimesAndType[si] = spikesBeforeOurs.elementAt(si); - } - Arrays.sort(nextSpikeTimesAndType, FirstIndexComparatorDouble.getInstance()); - double integralForJointSpace = 0; - for (int si = 0; si < nextSpikeTimesAndType.length; si++) { - if (nextSpikeTimesAndType[si][1] < 0) { - // We have a next spike from the dest, which is - // earlier than our spike. - // Integrated Prob for getting a spike here is 1 / N, where - // N is the number of properly matched histories (i.e. - // which don't have a next spike before this one) - double intPNext = (double) 1 / (double) - (nextSpikeTimesAndType.length - si + countOfSpikesAfterAndIncludingOurs); - integralForJointSpace += intPNext; - } - // Ignore next spikes on the source, they simply get removed - // from the matched histories count - } - - // We now have the integral of spike rates given the dest and - // joint histories, so subtract this out: - contributionFromNonSpikes += integralForDestHistorySpace - integralForJointSpace; - contributionFromNonSpikes_destAndSource += integralForJointSpace; - contributionFromNonSpikes_destOnly += integralForDestHistorySpace; - } - contributionFromSpikes /= totalTimeLength; - contributionFromNonSpikes /= totalTimeLength; - contributionFromNonSpikes_destAndSource /= totalTimeLength; - contributionFromNonSpikes_destOnly /= totalTimeLength; - te = contributionFromSpikes + contributionFromNonSpikes; - System.out.printf("TE = %.4f (spikes) + %.4f (non-spikes: d:%.4f - s-d:%.4f) = %.4f\n", - contributionFromSpikes, contributionFromNonSpikes, - contributionFromNonSpikes_destOnly, contributionFromNonSpikes_destAndSource, te); - return te; - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeLocalOfPreviousObservations() - */ - @Override - public SpikingLocalInformationValues computeLocalOfPreviousObservations() - throws Exception { - // TODO Auto-generated method stub - return null; - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeSignificance(int) - */ - @Override - public EmpiricalMeasurementDistribution computeSignificance( - int numPermutationsToCheck) throws Exception { - // TODO Auto-generated method stub - return null; - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#computeSignificance(int[][]) - */ - @Override - public 
EmpiricalMeasurementDistribution computeSignificance( - int[][] newOrderings) throws Exception { - // TODO Auto-generated method stub - return null; - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setDebug(boolean) - */ - @Override - public void setDebug(boolean debug) { - this.debug = debug; - } - - /* (non-Javadoc) - * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getLastAverage() - */ - @Override - public double getLastAverage() { - // TODO Auto-generated method stub - return 0; - } - -} diff --git a/java/source/infodynamics/utils/KdTree.java b/java/source/infodynamics/utils/KdTree.java index 0f58755..a8c8f09 100644 --- a/java/source/infodynamics/utils/KdTree.java +++ b/java/source/infodynamics/utils/KdTree.java @@ -275,10 +275,10 @@ public class KdTree extends NearestNeighbourSearcher { // Could remove these since the code is now functional, // but may be better to leave them in just in case the code breaks: if (leftIndex > leftStart + leftNumPoints) { - throw new RuntimeException("Exceeded expected number of points on left"); + throw new RuntimeException("Exceeded expected number of points on left - likely had an NaN in the data"); } if (rightIndex > rightStart + rightNumPoints) { - throw new RuntimeException("Exceeded expected number of points on right"); + throw new RuntimeException("Exceeded expected number of points on right - likely had an NaN in the data"); } // Update the pointer for the sorted indices for this dimension, // and keep the new temporary array @@ -337,7 +337,7 @@ public class KdTree extends NearestNeighbourSearcher { double difference = x1[d] - x2[d]; distance += difference * difference; } - return Math.sqrt(distance); + return distance; } } @@ -359,7 +359,6 @@ public class KdTree extends NearestNeighbourSearcher { */ public final static double normWithAbort(double[] x1, double[] x2, double limit, int normToUse) { - double distance = 0.0; switch (normToUse) { case EuclideanUtils.NORM_MAX_NORM: @@ -380,17 +379,15 @@ public class KdTree extends NearestNeighbourSearcher { return distance; // case EuclideanUtils.NORM_EUCLIDEAN_SQUARED: default: - // Limit is often r, so must square - limit = limit * limit; // Inlined from {@link EuclideanUtils}: for (int d = 0; d < x1.length; d++) { double difference = x1[d] - x2[d]; distance += difference * difference; if (distance > limit) { - return Double.POSITIVE_INFINITY; + return Double.POSITIVE_INFINITY; } } - return Math.sqrt(distance); + return distance; } } @@ -1308,10 +1305,9 @@ public class KdTree extends NearestNeighbourSearcher { if (normTypeToUse == EuclideanUtils.NORM_MAX_NORM) { absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim; } else { - absDistOnThisDim = (distOnThisDim > 0) ? 
distOnThisDim : - distOnThisDim;
 				// norm type is EuclideanUtils#NORM_EUCLIDEAN_SQUARED
 				// Track the square distance
-				//absDistOnThisDim = distOnThisDim;
+				absDistOnThisDim = distOnThisDim * distOnThisDim;
 			}
 			if ((node.indexOfThisPoint != sampleIndex) &&
@@ -2057,7 +2053,7 @@ public class KdTree extends NearestNeighbourSearcher {
 			KdTreeNode node, int level, double r, boolean allowEqualToR,
 			boolean[] isWithinR, int[] indicesWithinR, int nextIndexInIndicesWithinR) {
-		//System.out.println("foo");
+
 		// Point to the correct array for the data at this level
 		int currentDim = level % totalDimensions;
 		double[][] data = dimensionToArray[currentDim];
@@ -2072,9 +2068,9 @@ public class KdTree extends NearestNeighbourSearcher {
 			if (normTypeToUse == EuclideanUtils.NORM_MAX_NORM) {
 				absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
 			} else {
-				absDistOnThisDim = (distOnThisDim > 0) ? distOnThisDim : - distOnThisDim;
 				// norm type is EuclideanUtils#NORM_EUCLIDEAN_SQUARED
 				// Track the square distance
+				absDistOnThisDim = distOnThisDim * distOnThisDim;
 			}
 			if ((absDistOnThisDim < r) ||
diff --git a/tester.py b/tester.py
index 09817c4..cae3631 100755
--- a/tester.py
+++ b/tester.py
@@ -27,8 +27,8 @@ import os
 import numpy as np
 
-NUM_REPS = 5
-NUM_SPIKES = int(5e3)
+NUM_REPS = 2
+NUM_SPIKES = int(2e3)
 NUM_OBSERVATIONS = 2
 
 # Params for canonical example generation
@@ -99,6 +99,8 @@ for i in range(NUM_REPS):
     teCalc.finaliseAddObservations();
     result = teCalc.computeAverageLocalOfObservations()
     print("TE result %.4f nats" % (result,))
+    sig = teCalc.computeSignificance(10, result)
+    print(sig.pValue)
     results_poisson[i] = result
 print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
 
@@ -111,59 +113,67 @@ print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_po
 teCalc = teCalcClass()
 teCalc.setProperty("knns", "4")
 print("Noisy copy zero TE")
-teCalc.setProperty("COND_EMBED_LENGTHS", "2")
+teCalc.setProperty("COND_EMBED_LENGTHS", "1")
 teCalc.setProperty("k_HISTORY", "2")
-teCalc.setProperty("l_HISTORY", "2")
+teCalc.setProperty("l_HISTORY", "1")
+#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
 
 results_noisy_zero = np.zeros(NUM_REPS)
 for i in range(NUM_REPS):
     teCalc.startAddObservations()
     for j in range(NUM_OBSERVATIONS):
-        condArray = NUM_SPIKES*np.random.random((1, NUM_SPIKES))
+        condArray = np.ones((1, NUM_SPIKES)) + 0.05 * np.random.random((1, NUM_SPIKES))
+        condArray = np.cumsum(condArray, axis = 1)
         condArray.sort(axis = 1)
-        sourceArray = condArray[0, :] + 0.25 + 0.1 * np.random.normal(size = condArray.shape[1])
+        sourceArray = condArray[0, :] + 0.25 + 0.05 * np.random.normal(size = condArray.shape[1])
         sourceArray.sort()
-        destArray = condArray[0, :] + 0.5 + 0.1 * np.random.normal(size = condArray.shape[1])
+        destArray = condArray[0, :] + 0.5 + 0.05 * np.random.normal(size = condArray.shape[1])
         destArray.sort()
         teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
     teCalc.finaliseAddObservations();
     result = teCalc.computeAverageLocalOfObservations()
     print("TE result %.4f nats" % (result,))
-    results_poisson[i] = result
-print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
+    sig = teCalc.computeSignificance(10, result)
+    print(sig.pValue)
+    results_noisy_zero[i] = result
+print("Summary: mean ", np.mean(results_noisy_zero), " std dev ", np.std(results_noisy_zero))
 
 teCalc = teCalcClass()
 teCalc.setProperty("knns", "4")
 print("Noisy copy non-zero TE")
-teCalc.setProperty("COND_EMBED_LENGTHS", "2")
+teCalc.setProperty("COND_EMBED_LENGTHS", "1")
 teCalc.setProperty("k_HISTORY", "2")
-teCalc.setProperty("l_HISTORY", "2")
+teCalc.setProperty("l_HISTORY", "1")
+#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
 
-results_noisy_zero = np.zeros(NUM_REPS)
+results_noisy_non_zero = np.zeros(NUM_REPS)
 for i in range(NUM_REPS):
     teCalc.startAddObservations()
     for j in range(NUM_OBSERVATIONS):
-        sourceArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
+        sourceArray = np.ones((1, NUM_SPIKES)) + 0.05 * np.random.random((1, NUM_SPIKES))
+        sourceArray = np.cumsum(sourceArray)
         sourceArray.sort()
-        condArray = sourceArray + 0.25 + 0.1 * np.random.normal(size = sourceArray.shape)
+        condArray = sourceArray + 0.25 + 0.05 * np.random.normal(size = sourceArray.shape)
         condArray.sort()
         condArray = np.expand_dims(condArray, 0)
-        destArray = sourceArray + 0.5 + 0.1 * np.random.normal(size = sourceArray.shape)
+        destArray = sourceArray + 0.5 + 0.05 * np.random.normal(size = sourceArray.shape)
         destArray.sort()
         teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
     teCalc.finaliseAddObservations();
     result = teCalc.computeAverageLocalOfObservations()
     print("TE result %.4f nats" % (result,))
-    results_poisson[i] = result
-print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
+    sig = teCalc.computeSignificance(10, result)
+    print(sig.pValue)
+    results_noisy_non_zero[i] = result
+print("Summary: mean ", np.mean(results_noisy_non_zero), " std dev ", np.std(results_noisy_non_zero))
 
 print("Canonical example")
 teCalc = teCalcClass()
 teCalc.setProperty("knns", "4")
-teCalc.setProperty("k_HISTORY", "2")
-teCalc.setProperty("l_HISTORY", "1")
+teCalc.setProperty("k_HISTORY", "1")
+teCalc.setProperty("l_HISTORY", "2")
 #teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", "1")
 #teCalc.setProperty("NORM_TYPE", "MAX_NORM")
@@ -174,4 +184,6 @@ for i in range(NUM_REPS):
     result = teCalc.computeAverageLocalOfObservations()
     results_canonical[i] = result
     print("TE result %.4f nats" % (result,))
+    sig = teCalc.computeSignificance(10, result)
+    print(sig.pValue)
 print("Summary: mean ", np.mean(results_canonical), " std dev ", np.std(results_canonical))
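Usage sketch for the surrogate significance test exercised in tester.py above. The two-argument computeSignificance(numPermutationsToCheck, estimatedValue) call returns a distribution object whose pValue field tester.py prints directly; tester.py uses only 10 permutations, presumably to keep the test run short, and a larger count gives a better-resolved p-value. The sketch below is illustrative only: it assumes the same JPype setup already made in tester.py (teCalcClass, JArray, JDouble, numpy as np), and the helper name run_surrogate_test, the permutation count of 100 and the 0.05 threshold are choices made for the example, not values used by this patch.

# Hypothetical helper, not part of this patch: wraps the calls made per
# repetition in tester.py, including the new surrogate significance test.
def run_surrogate_test(teCalc, sourceArray, destArray, condArray,
                       num_perms=100, alpha=0.05):
    # teCalc is assumed to be an already-configured spiking TE calculator
    # (knns, k_HISTORY, l_HISTORY etc. set as in tester.py)
    teCalc.startAddObservations()
    teCalc.addObservations(JArray(JDouble, 1)(sourceArray),
                           JArray(JDouble, 1)(destArray),
                           JArray(JDouble, 2)(condArray))
    teCalc.finaliseAddObservations()
    te = teCalc.computeAverageLocalOfObservations()
    # Build the null distribution from num_perms surrogates and compare it
    # against the estimate just computed:
    sig = teCalc.computeSignificance(num_perms, te)
    print("TE = %.4f nats, surrogate p-value = %.4f" % (te, sig.pValue))
    return te, sig.pValue < alpha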