some progress towards PCB spiking estimator

This commit is contained in:
David Shorten 2021-07-22 14:19:36 +10:00
parent 4f889df786
commit 56dacc5232
2 changed files with 180 additions and 7 deletions

View File

@ -1,6 +1,7 @@
package infodynamics.measures.spiking.integration;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;
@ -70,6 +71,13 @@ public class TransferEntropyCalculatorSpikingIntegration implements
* Cache of the timing data for each new observed spiking event for the
* destination only
*/
Vector<double[][]> targetEmbeddingsFromSpikes = null;
Vector<double[][]> sourceEmbeddingsFromSpikes = null;
Vector<double[][]> targetEmbeddingsFromSamples = null;
Vector<double[][]> sourceEmbeddingsFromSamples = null;
Vector<double[][]> destPastAndNextTimings = null;
/**
* Cache of the type of event for each new observed spiking event in both the source
@ -286,6 +294,12 @@ public class TransferEntropyCalculatorSpikingIntegration implements
eventIndexLocator = new Vector<Integer>();
numEventsPerObservationSet = new Vector<Integer>();
// New
targetEmbeddingsFromSpikes = new Vector<double[][]>();
sourceEmbeddingsFromSpikes = new Vector<double[][]>();
targetEmbeddingsFromSamples = new Vector<double[][]>();
sourceEmbeddingsFromSamples = new Vector<double[][]>();
// Send all of the observations through:
Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
int timeSeriesIndex = 0;
@ -295,7 +309,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes,
timeSeriesIndex, eventTimings, destPastAndNextTimings,
eventTypeLocator, eventIndexLocator, numEventsPerObservationSet);
eventTypeLocator, eventIndexLocator, numEventsPerObservationSet,
targetEmbeddingsFromSpikes, sourceEmbeddingsFromSpikes,
targetEmbeddingsFromSamples, sourceEmbeddingsFromSamples);
}
// Now we have collected all the events.
@ -353,9 +369,14 @@ public class TransferEntropyCalculatorSpikingIntegration implements
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes,
int timeSeriesIndex, Vector<double[][]>[] eventTimings,
Vector<double[][]> destPastAndNextTimings, Vector<Integer> eventTypeLocator,
Vector<Integer> eventIndexLocator, Vector<Integer> numEventsPerObservationSet) throws Exception {
Vector<Integer> eventIndexLocator, Vector<Integer> numEventsPerObservationSet,
Vector<double[][]> targetEmbeddingsFromSpikes, Vector<double[][]> sourceEmbeddingsFromSpikes,
Vector<double[][]> targetEmbeddingsFromSamples, Vector<double[][]> sourceEmbeddingsFromSamples)
throws Exception {
// addObservationsAfterParamsDetermined(sourceSpikeTimes, destSpikeTimes);
System.out.println("foo");
// First sort the spike times in case they were not properly in ascending order:
Arrays.sort(sourceSpikeTimes);
Arrays.sort(destSpikeTimes);
@ -501,6 +522,158 @@ public class TransferEntropyCalculatorSpikingIntegration implements
if (debug) {
System.out.printf("Finished processing %d source-target events for observation set %d\n", numEvents, timeSeriesIndex);
}
// New
int NUM_SAMPLES = 1000;
double sample_lower_bound = Arrays.stream(sourceSpikeTimes).min().getAsDouble();
double sample_upper_bound = Arrays.stream(sourceSpikeTimes).max().getAsDouble();
double[] randomSampleTimes = new double[NUM_SAMPLES];
Random rand = new Random();
for (int i = 0; i < randomSampleTimes.length; i++) {
randomSampleTimes[i] = sample_lower_bound + rand.nextDouble() * (sample_upper_bound - sample_lower_bound);
}
// // Scan to find the indices by which we have k and l spikes for dest and source
// // respectively
// int dest_index = k - 1;
// int source_index = l - 1;
// if (sourceSpikeTimes[source_index] > destSpikeTimes[dest_index]) {
// // Minimum required Source spikes are later than the dest.
// // Need to advance dest_index until it's the most recent before source_index
// for(;dest_index < destSpikeTimes.length; dest_index++) {
// if (destSpikeTimes[dest_index] > sourceSpikeTimes[source_index]) {
// // We've gone past the set of source spikes we have, we
// // can move back one in the dest series
// dest_index--;
// break;
// }
// }
// if (dest_index == destSpikeTimes.length) {
// // We didn't have enough spikes in this series to generate any observations
// // TODO work out how to handle this later -- I think this is ok
// numEventsPerObservationSet.add(0);
// return;
// // throw new Exception("Dest spikes stop before enough source spikes in time-series " + timeSeriesIndex);
// }
// } else {
// // Minimum required Dest spikes are later than the source.
// // Need to advance source_index until it's the most recent before dest_index
// for(;source_index < sourceSpikeTimes.length; source_index++) {
// if (sourceSpikeTimes[source_index] > destSpikeTimes[dest_index]) {
// // We've gone past the set of dest spikes we have, we
// // can move back one in the source series
// source_index--;
// break;
// }
// }
// if (source_index == sourceSpikeTimes.length) {
// // We didn't have enough spikes in this series to generate any observations
// // TODO work out how to handle this later -- I think this is ok
// numEventsPerObservationSet.add(0);
// return;
// // throw new Exception("Source spikes stop before enough dest spikes in time-series " + timeSeriesIndex);
// }
// }
// // Post-condition: dest_index and source_index are set correctly for the first set of pasts
// double timeToNextSpike;
// boolean nextIsDest = false;
// double[] spikeTimesForNextSpiker;
// double timeOfPrevDestSpike = destSpikeTimes[dest_index];
// int numEvents = 0;
// Random random = null;
// if (addNoise) {
// random = new Random();
// }
// while(true) {
// // 0. Check whether we're finished
// if ((source_index == sourceSpikeTimes.length - 1) &&
// (dest_index == destSpikeTimes.length - 1)) {
// // We have no next spike so we can't take an observation here
// // and we're done
// break;
// }
// // Otherwise:
// // 1. Determine which of source / dest fires next
// if (source_index == sourceSpikeTimes.length - 1) {
// nextIsDest = true;
// } else if (dest_index == destSpikeTimes.length - 1) {
// nextIsDest = false;
// } else if (sourceSpikeTimes[source_index+1] < destSpikeTimes[dest_index+1]) {
// nextIsDest = false;
// } else {
// nextIsDest = true;
// }
// spikeTimesForNextSpiker = nextIsDest ? destSpikeTimes : sourceSpikeTimes;
// int indexForNextSpiker = nextIsDest ? dest_index : source_index;
// timeToNextSpike = spikeTimesForNextSpiker[indexForNextSpiker+1] - timeOfPrevDestSpike;
// if (addNoise) {
// timeToNextSpike += random.nextGaussian()*noiseLevel;
// }
// // 2. Embed the past spikes
// double[] sourcePast = new double[l];
// double[] destPast = new double[k - 1];
// /* if (debug) {
// System.out.println("previousIsDest = " + previousIsDest + " and nextIsDest = " + nextIsDest);
// }*/
// sourcePast[0] = timeOfPrevDestSpike -
// sourceSpikeTimes[source_index];
// if (addNoise) {
// sourcePast[0] += random.nextGaussian()*noiseLevel;
// }
// for (int i = 1; i < k; i++) {
// destPast[i - 1] = destSpikeTimes[dest_index - i + 1] -
// destSpikeTimes[dest_index - i];
// if (addNoise) {
// destPast[i - 1] += random.nextGaussian()*noiseLevel;
// }
// }
// for (int i = 1; i < l; i++) {
// sourcePast[i] = sourceSpikeTimes[source_index - i + 1] -
// sourceSpikeTimes[source_index - i];
// if (addNoise) {
// sourcePast[i] += random.nextGaussian()*noiseLevel;
// }
// }
// // 3. Store these embedded observations
// double[][] observations = new double[][]{sourcePast, destPast,
// new double[] {timeToNextSpike}};
// if (debug) {
// System.out.printf("Adding event %d with: timeToNextSpike=%.4f, sourceSpikeTimes=", numEvents, timeToNextSpike);
// MatrixUtils.printArray(System.out, sourcePast, 3);
// System.out.printf(", destSpikeTimes=");
// MatrixUtils.printArray(System.out, destPast, 3);
// System.out.println();
// }
// // Add the index locator first so it gets the index correct before
// // we add the new event in:
// eventIndexLocator.add(eventTimings[nextIsDest ? NEXT_DEST : NEXT_SOURCE].size());
// eventTimings[nextIsDest ? NEXT_DEST : NEXT_SOURCE].add(observations);
// // TODO Switch eventTypeLocation to be of type Integer rather than int[]
// eventTypeLocator.add(nextIsDest ? NEXT_DEST : NEXT_SOURCE);
// // And finally store the observations for the dest only
// // search structure if required:
// if (nextIsDest) {
// double[][] destOnlyObservations;
// destOnlyObservations = new double[][] {
// destPast,
// new double[] {timeToNextSpike}
// };
// destPastAndNextTimings.add(destOnlyObservations);
// }
// // 4. Reset variables
// if (nextIsDest) {
// dest_index++;
// } else {
// source_index++;
// }
// timeOfPrevDestSpike = destSpikeTimes[dest_index];
// numEvents++;
// }
// numEventsPerObservationSet.add(numEvents);
// if (debug) {
// System.out.printf("Finished processing %d source-target events for observation set %d\n", numEvents, timeSeriesIndex);
// }
}
/* (non-Javadoc)

View File

@ -58,7 +58,7 @@ teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntr
teCalc = teCalcClass()
teCalc.setProperty("NORMALISE", "true") # Normalise the individual variables
teCalc.initialise(1) # Use history length 1 (Schreiber k=1)
teCalc.setProperty("k", "4") # Use Kraskov parameter K=4 for 4 nearest points
teCalc.setProperty("knns", "4") # Use Kraskov parameter K=4 for 4 nearest points
# # Perform calculation with correlated source:
teCalc.setObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray))
result = teCalc.computeAverageLocalOfObservations()