From 285c95947a037e8f72b59679f1cc5219f84436fc Mon Sep 17 00:00:00 2001 From: David Shorten Date: Tue, 22 Feb 2022 14:39:15 +1100 Subject: [PATCH] integrating artemis changes --- ...erEntropyCalculatorSpikingIntegration.java | 50 +++++++++++++------ 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java b/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java index 89094d8..ede39a4 100644 --- a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java +++ b/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java @@ -1,10 +1,6 @@ package infodynamics.measures.spiking.integration; import java.util.Arrays; -import java.util.Iterator; -import java.util.PriorityQueue; -import java.util.Vector; - import java.util.Collections; import java.util.Iterator; import java.util.PriorityQueue; @@ -22,7 +18,6 @@ import infodynamics.utils.MatrixUtils; import infodynamics.utils.NeighbourNodeData; import infodynamics.utils.FirstIndexComparatorDouble; import infodynamics.utils.UnivariateNearestNeighbourSearcher; - import infodynamics.utils.EuclideanUtils; import infodynamics.utils.ParsedProperties; @@ -78,11 +73,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr */ protected int Knns = 4; - /** * Storage for source observations supplied via * {@link #addObservations(double[], double[])} etc. - */ protected Vector vectorOfSourceSpikeTimes = null; @@ -163,6 +156,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr public TransferEntropyCalculatorSpikingIntegration() { super(); } + /* * (non-Javadoc) @@ -296,6 +290,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr } } + public void appendConditionalIntervals(int[] intervals) throws Exception{ for (int interval : intervals) { if (interval < 1) { @@ -445,10 +440,10 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr // Initialise the starting points of all the tracking variables int embeddingPointIndex = indexOfFirstPointToUse; int mostRecentDestIndex = destPastIntervals[destPastIntervals.length - 1]; - int mostRecentSourceIndex = sourcePastIntervals[sourcePastIntervals.length - 1]; + int mostRecentSourceIndex = sourcePastIntervals[sourcePastIntervals.length - 1] - 1; int[] mostRecentConditioningIndices = new int[vectorOfCondPastIntervals.size()]; for (int i = 0; i < vectorOfCondPastIntervals.size(); i++) { - mostRecentConditioningIndices[i] = vectorOfCondPastIntervals.elementAt(i).length; + mostRecentConditioningIndices[i] = vectorOfCondPastIntervals.elementAt(i)[vectorOfCondPastIntervals.elementAt(i).length - 1] - 1; } @@ -542,11 +537,14 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr // Add Gaussian noise, if necessary if (addNoise) { for (int i = 0; i < conditioningPast.length; i++) { - //conditioningPast[i] = Math.exp(-conditioningPast[i]); + conditioningPast[i] = Math.log(conditioningPast[i] + 1.1); conditioningPast[i] += random.nextGaussian() * noiseLevel; } for (int i = 0; i < jointPast.length; i++) { - //jointPast[i] = Math.exp(-jointPast[i]); + if (jointPast[i] < 0) { + System.out.println("NEGATIVE"); + } + jointPast[i] = Math.log(jointPast[i] + 1.1); jointPast[i] += random.nextGaussian() * noiseLevel; } } @@ -568,7 +566,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr int firstTargetIndexOfEmbedding = destPastIntervals[destPastIntervals.length - 1]; int furthestInterval = sourcePastIntervals[sourcePastIntervals.length - 1]; - while (destSpikeTimes[firstTargetIndexOfEmbedding] <= sourceSpikeTimes[furthestInterval - 1]) { + while (destSpikeTimes[firstTargetIndexOfEmbedding] < sourceSpikeTimes[furthestInterval - 1]) { firstTargetIndexOfEmbedding++; } if (conditionalSpikeTimes.length != vectorOfCondPastIntervals.size()) { @@ -597,11 +595,27 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr int num_samples = (int) Math.round(actualNumSamplesMultiplier * (destSpikeTimes.length - firstTargetIndexOfEmbedding + 1)); double[] randomSampleTimes = new double[num_samples]; Random rand = new Random(); - for (int i = 0; i < randomSampleTimes.length; i++) { - randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound); + boolean doCellCulture = true; + if (doCellCulture) { + for (int i = 0; i < randomSampleTimes.length; i++) { + randomSampleTimes[i] = destSpikeTimes[firstTargetIndexOfEmbedding + (i % (destSpikeTimes.length - firstTargetIndexOfEmbedding - 1))] + + 200 * (rand.nextDouble() - 0.5); + //randomSampleTimes[i] = -1.0; + if ((randomSampleTimes[i] > sampleUpperBound) || (randomSampleTimes[i] < sampleLowerBound)) { + randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound); + } + } + } else { + for (int i = 0; i < randomSampleTimes.length; i++) { + randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound); + } } Arrays.sort(randomSampleTimes); - + /*System.out.println(sampleLowerBound + " " + sampleUpperBound); + for (int i = 0; i < randomSampleTimes.length; i += 100) { + System.out.print(randomSampleTimes[i] + " "); + } + System.out.println("\n\n\n");*/ return randomSampleTimes; } @@ -774,6 +788,7 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr * squared euclidean distance is not a distance metric. We get around this by just taking the square root here, but it might be better to * fix this in the KdTree class. */ + double tempRadiusJointSamples = radiusJointSamples; if (normType == EuclideanUtils.NORM_EUCLIDEAN) { radiusJointSpikes = Math.sqrt(radiusJointSpikes); radiusJointSamples = Math.sqrt(radiusJointSamples); @@ -786,7 +801,10 @@ public class TransferEntropyCalculatorSpikingIntegration implements TransferEntr MathsUtils.digamma(kConditioningSpikes) + MathsUtils.digamma(kConditioningSamples) + + ((numDestPastIntervals + numCondPastIntervals) * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples)))); if (Double.isNaN(currentSum)) { - throw new Exception("NaNs in TE clac"); + for (double[] embed : jointEmbeddingsFromSamples) { + System.out.println(Arrays.toString(embed)); + } + throw new Exception("NaNs in TE clac " + radiusJointSpikes + " " + radiusJointSamples + " " + tempRadiusJointSamples); } } // Normalise by time