integrating artemis changes

This commit is contained in:
David Shorten 2022-02-22 14:39:15 +11:00
parent b4e32727c9
commit 285c95947a
1 changed files with 34 additions and 16 deletions

View File

@ -1,10 +1,6 @@
package infodynamics.measures.spiking.integration;
import java.util.Arrays;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Vector;
import java.util.Collections;
import java.util.Iterator;
import java.util.PriorityQueue;
@ -22,7 +18,6 @@ import infodynamics.utils.MatrixUtils;
import infodynamics.utils.NeighbourNodeData;
import infodynamics.utils.FirstIndexComparatorDouble;
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
import infodynamics.utils.EuclideanUtils;
import infodynamics.utils.ParsedProperties;
@ -78,11 +73,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
*/
protected int Knns = 4;
/**
* Storage for source observations supplied via
* {@link #addObservations(double[], double[])} etc.
*/
protected Vector<double[]> vectorOfSourceSpikeTimes = null;
@ -164,6 +157,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
super();
}
/*
* (non-Javadoc)
*
@ -296,6 +290,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
}
}
public void appendConditionalIntervals(int[] intervals) throws Exception{
for (int interval : intervals) {
if (interval < 1) {
@ -445,10 +440,10 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
// Initialise the starting points of all the tracking variables
int embeddingPointIndex = indexOfFirstPointToUse;
int mostRecentDestIndex = destPastIntervals[destPastIntervals.length - 1];
int mostRecentSourceIndex = sourcePastIntervals[sourcePastIntervals.length - 1];
int mostRecentSourceIndex = sourcePastIntervals[sourcePastIntervals.length - 1] - 1;
int[] mostRecentConditioningIndices = new int[vectorOfCondPastIntervals.size()];
for (int i = 0; i < vectorOfCondPastIntervals.size(); i++) {
mostRecentConditioningIndices[i] = vectorOfCondPastIntervals.elementAt(i).length;
mostRecentConditioningIndices[i] = vectorOfCondPastIntervals.elementAt(i)[vectorOfCondPastIntervals.elementAt(i).length - 1] - 1;
}
@ -542,11 +537,14 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
// Add Gaussian noise, if necessary
if (addNoise) {
for (int i = 0; i < conditioningPast.length; i++) {
//conditioningPast[i] = Math.exp(-conditioningPast[i]);
conditioningPast[i] = Math.log(conditioningPast[i] + 1.1);
conditioningPast[i] += random.nextGaussian() * noiseLevel;
}
for (int i = 0; i < jointPast.length; i++) {
//jointPast[i] = Math.exp(-jointPast[i]);
if (jointPast[i] < 0) {
System.out.println("NEGATIVE");
}
jointPast[i] = Math.log(jointPast[i] + 1.1);
jointPast[i] += random.nextGaussian() * noiseLevel;
}
}
@ -568,7 +566,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
int firstTargetIndexOfEmbedding = destPastIntervals[destPastIntervals.length - 1];
int furthestInterval = sourcePastIntervals[sourcePastIntervals.length - 1];
while (destSpikeTimes[firstTargetIndexOfEmbedding] <= sourceSpikeTimes[furthestInterval - 1]) {
while (destSpikeTimes[firstTargetIndexOfEmbedding] < sourceSpikeTimes[furthestInterval - 1]) {
firstTargetIndexOfEmbedding++;
}
if (conditionalSpikeTimes.length != vectorOfCondPastIntervals.size()) {
@ -597,11 +595,27 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
int num_samples = (int) Math.round(actualNumSamplesMultiplier * (destSpikeTimes.length - firstTargetIndexOfEmbedding + 1));
double[] randomSampleTimes = new double[num_samples];
Random rand = new Random();
boolean doCellCulture = true;
if (doCellCulture) {
for (int i = 0; i < randomSampleTimes.length; i++) {
randomSampleTimes[i] = destSpikeTimes[firstTargetIndexOfEmbedding + (i % (destSpikeTimes.length - firstTargetIndexOfEmbedding - 1))]
+ 200 * (rand.nextDouble() - 0.5);
//randomSampleTimes[i] = -1.0;
if ((randomSampleTimes[i] > sampleUpperBound) || (randomSampleTimes[i] < sampleLowerBound)) {
randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound);
}
}
} else {
for (int i = 0; i < randomSampleTimes.length; i++) {
randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound);
}
}
Arrays.sort(randomSampleTimes);
/*System.out.println(sampleLowerBound + " " + sampleUpperBound);
for (int i = 0; i < randomSampleTimes.length; i += 100) {
System.out.print(randomSampleTimes[i] + " ");
}
System.out.println("\n\n\n");*/
return randomSampleTimes;
}
@ -774,6 +788,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
* squared euclidean distance is not a distance metric. We get around this by just taking the square root here, but it might be better to
* fix this in the KdTree class.
*/
double tempRadiusJointSamples = radiusJointSamples;
if (normType == EuclideanUtils.NORM_EUCLIDEAN) {
radiusJointSpikes = Math.sqrt(radiusJointSpikes);
radiusJointSamples = Math.sqrt(radiusJointSamples);
@ -786,7 +801,10 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr
MathsUtils.digamma(kConditioningSpikes) + MathsUtils.digamma(kConditioningSamples) +
+ ((numDestPastIntervals + numCondPastIntervals) * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples))));
if (Double.isNaN(currentSum)) {
throw new Exception("NaNs in TE clac");
for (double[] embed : jointEmbeddingsFromSamples) {
System.out.println(Arrays.toString(embed));
}
throw new Exception("NaNs in TE clac " + radiusJointSpikes + " " + radiusJointSamples + " " + tempRadiusJointSamples);
}
}
// Normalise by time