mirror of https://github.com/jlizier/jidt
some progress towards PCB spiking estimator
This commit is contained in:
parent
4f889df786
commit
56dacc5232
|
@ -1,6 +1,7 @@
|
|||
package infodynamics.measures.spiking.integration;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
import java.util.PriorityQueue;
|
||||
import java.util.Random;
|
||||
|
@ -70,6 +71,13 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
* Cache of the timing data for each new observed spiking event for the
|
||||
* destination only
|
||||
*/
|
||||
|
||||
Vector<double[][]> targetEmbeddingsFromSpikes = null;
|
||||
Vector<double[][]> sourceEmbeddingsFromSpikes = null;
|
||||
Vector<double[][]> targetEmbeddingsFromSamples = null;
|
||||
Vector<double[][]> sourceEmbeddingsFromSamples = null;
|
||||
|
||||
|
||||
Vector<double[][]> destPastAndNextTimings = null;
|
||||
/**
|
||||
* Cache of the type of event for each new observed spiking event in both the source
|
||||
|
@ -286,6 +294,12 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
eventIndexLocator = new Vector<Integer>();
|
||||
numEventsPerObservationSet = new Vector<Integer>();
|
||||
|
||||
// New
|
||||
targetEmbeddingsFromSpikes = new Vector<double[][]>();
|
||||
sourceEmbeddingsFromSpikes = new Vector<double[][]>();
|
||||
targetEmbeddingsFromSamples = new Vector<double[][]>();
|
||||
sourceEmbeddingsFromSamples = new Vector<double[][]>();
|
||||
|
||||
// Send all of the observations through:
|
||||
Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
|
||||
int timeSeriesIndex = 0;
|
||||
|
@ -295,7 +309,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
|
||||
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes,
|
||||
timeSeriesIndex, eventTimings, destPastAndNextTimings,
|
||||
eventTypeLocator, eventIndexLocator, numEventsPerObservationSet);
|
||||
eventTypeLocator, eventIndexLocator, numEventsPerObservationSet,
|
||||
targetEmbeddingsFromSpikes, sourceEmbeddingsFromSpikes,
|
||||
targetEmbeddingsFromSamples, sourceEmbeddingsFromSamples);
|
||||
}
|
||||
|
||||
// Now we have collected all the events.
|
||||
|
@ -353,9 +369,14 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes,
|
||||
int timeSeriesIndex, Vector<double[][]>[] eventTimings,
|
||||
Vector<double[][]> destPastAndNextTimings, Vector<Integer> eventTypeLocator,
|
||||
Vector<Integer> eventIndexLocator, Vector<Integer> numEventsPerObservationSet) throws Exception {
|
||||
Vector<Integer> eventIndexLocator, Vector<Integer> numEventsPerObservationSet,
|
||||
Vector<double[][]> targetEmbeddingsFromSpikes, Vector<double[][]> sourceEmbeddingsFromSpikes,
|
||||
Vector<double[][]> targetEmbeddingsFromSamples, Vector<double[][]> sourceEmbeddingsFromSamples)
|
||||
throws Exception {
|
||||
// addObservationsAfterParamsDetermined(sourceSpikeTimes, destSpikeTimes);
|
||||
|
||||
System.out.println("foo");
|
||||
|
||||
// First sort the spike times in case they were not properly in ascending order:
|
||||
Arrays.sort(sourceSpikeTimes);
|
||||
Arrays.sort(destSpikeTimes);
|
||||
|
@ -501,6 +522,158 @@ public class TransferEntropyCalculatorSpikingIntegration implements
|
|||
if (debug) {
|
||||
System.out.printf("Finished processing %d source-target events for observation set %d\n", numEvents, timeSeriesIndex);
|
||||
}
|
||||
|
||||
// New
|
||||
int NUM_SAMPLES = 1000;
|
||||
double sample_lower_bound = Arrays.stream(sourceSpikeTimes).min().getAsDouble();
|
||||
double sample_upper_bound = Arrays.stream(sourceSpikeTimes).max().getAsDouble();
|
||||
double[] randomSampleTimes = new double[NUM_SAMPLES];
|
||||
Random rand = new Random();
|
||||
for (int i = 0; i < randomSampleTimes.length; i++) {
|
||||
randomSampleTimes[i] = sample_lower_bound + rand.nextDouble() * (sample_upper_bound - sample_lower_bound);
|
||||
}
|
||||
|
||||
// // Scan to find the indices by which we have k and l spikes for dest and source
|
||||
// // respectively
|
||||
// int dest_index = k - 1;
|
||||
// int source_index = l - 1;
|
||||
// if (sourceSpikeTimes[source_index] > destSpikeTimes[dest_index]) {
|
||||
// // Minimum required Source spikes are later than the dest.
|
||||
// // Need to advance dest_index until it's the most recent before source_index
|
||||
// for(;dest_index < destSpikeTimes.length; dest_index++) {
|
||||
// if (destSpikeTimes[dest_index] > sourceSpikeTimes[source_index]) {
|
||||
// // We've gone past the set of source spikes we have, we
|
||||
// // can move back one in the dest series
|
||||
// dest_index--;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// if (dest_index == destSpikeTimes.length) {
|
||||
// // We didn't have enough spikes in this series to generate any observations
|
||||
// // TODO work out how to handle this later -- I think this is ok
|
||||
// numEventsPerObservationSet.add(0);
|
||||
// return;
|
||||
// // throw new Exception("Dest spikes stop before enough source spikes in time-series " + timeSeriesIndex);
|
||||
// }
|
||||
// } else {
|
||||
// // Minimum required Dest spikes are later than the source.
|
||||
// // Need to advance source_index until it's the most recent before dest_index
|
||||
// for(;source_index < sourceSpikeTimes.length; source_index++) {
|
||||
// if (sourceSpikeTimes[source_index] > destSpikeTimes[dest_index]) {
|
||||
// // We've gone past the set of dest spikes we have, we
|
||||
// // can move back one in the source series
|
||||
// source_index--;
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// if (source_index == sourceSpikeTimes.length) {
|
||||
// // We didn't have enough spikes in this series to generate any observations
|
||||
// // TODO work out how to handle this later -- I think this is ok
|
||||
// numEventsPerObservationSet.add(0);
|
||||
// return;
|
||||
// // throw new Exception("Source spikes stop before enough dest spikes in time-series " + timeSeriesIndex);
|
||||
// }
|
||||
// }
|
||||
// // Post-condition: dest_index and source_index are set correctly for the first set of pasts
|
||||
|
||||
// double timeToNextSpike;
|
||||
// boolean nextIsDest = false;
|
||||
// double[] spikeTimesForNextSpiker;
|
||||
// double timeOfPrevDestSpike = destSpikeTimes[dest_index];
|
||||
// int numEvents = 0;
|
||||
// Random random = null;
|
||||
// if (addNoise) {
|
||||
// random = new Random();
|
||||
// }
|
||||
// while(true) {
|
||||
// // 0. Check whether we're finished
|
||||
// if ((source_index == sourceSpikeTimes.length - 1) &&
|
||||
// (dest_index == destSpikeTimes.length - 1)) {
|
||||
// // We have no next spike so we can't take an observation here
|
||||
// // and we're done
|
||||
// break;
|
||||
// }
|
||||
// // Otherwise:
|
||||
// // 1. Determine which of source / dest fires next
|
||||
// if (source_index == sourceSpikeTimes.length - 1) {
|
||||
// nextIsDest = true;
|
||||
// } else if (dest_index == destSpikeTimes.length - 1) {
|
||||
// nextIsDest = false;
|
||||
// } else if (sourceSpikeTimes[source_index+1] < destSpikeTimes[dest_index+1]) {
|
||||
// nextIsDest = false;
|
||||
// } else {
|
||||
// nextIsDest = true;
|
||||
// }
|
||||
// spikeTimesForNextSpiker = nextIsDest ? destSpikeTimes : sourceSpikeTimes;
|
||||
// int indexForNextSpiker = nextIsDest ? dest_index : source_index;
|
||||
// timeToNextSpike = spikeTimesForNextSpiker[indexForNextSpiker+1] - timeOfPrevDestSpike;
|
||||
// if (addNoise) {
|
||||
// timeToNextSpike += random.nextGaussian()*noiseLevel;
|
||||
// }
|
||||
// // 2. Embed the past spikes
|
||||
// double[] sourcePast = new double[l];
|
||||
// double[] destPast = new double[k - 1];
|
||||
// /* if (debug) {
|
||||
// System.out.println("previousIsDest = " + previousIsDest + " and nextIsDest = " + nextIsDest);
|
||||
// }*/
|
||||
// sourcePast[0] = timeOfPrevDestSpike -
|
||||
// sourceSpikeTimes[source_index];
|
||||
// if (addNoise) {
|
||||
// sourcePast[0] += random.nextGaussian()*noiseLevel;
|
||||
// }
|
||||
// for (int i = 1; i < k; i++) {
|
||||
// destPast[i - 1] = destSpikeTimes[dest_index - i + 1] -
|
||||
// destSpikeTimes[dest_index - i];
|
||||
// if (addNoise) {
|
||||
// destPast[i - 1] += random.nextGaussian()*noiseLevel;
|
||||
// }
|
||||
// }
|
||||
// for (int i = 1; i < l; i++) {
|
||||
// sourcePast[i] = sourceSpikeTimes[source_index - i + 1] -
|
||||
// sourceSpikeTimes[source_index - i];
|
||||
// if (addNoise) {
|
||||
// sourcePast[i] += random.nextGaussian()*noiseLevel;
|
||||
// }
|
||||
// }
|
||||
// // 3. Store these embedded observations
|
||||
// double[][] observations = new double[][]{sourcePast, destPast,
|
||||
// new double[] {timeToNextSpike}};
|
||||
// if (debug) {
|
||||
// System.out.printf("Adding event %d with: timeToNextSpike=%.4f, sourceSpikeTimes=", numEvents, timeToNextSpike);
|
||||
// MatrixUtils.printArray(System.out, sourcePast, 3);
|
||||
// System.out.printf(", destSpikeTimes=");
|
||||
// MatrixUtils.printArray(System.out, destPast, 3);
|
||||
// System.out.println();
|
||||
// }
|
||||
// // Add the index locator first so it gets the index correct before
|
||||
// // we add the new event in:
|
||||
// eventIndexLocator.add(eventTimings[nextIsDest ? NEXT_DEST : NEXT_SOURCE].size());
|
||||
// eventTimings[nextIsDest ? NEXT_DEST : NEXT_SOURCE].add(observations);
|
||||
// // TODO Switch eventTypeLocation to be of type Integer rather than int[]
|
||||
// eventTypeLocator.add(nextIsDest ? NEXT_DEST : NEXT_SOURCE);
|
||||
// // And finally store the observations for the dest only
|
||||
// // search structure if required:
|
||||
// if (nextIsDest) {
|
||||
// double[][] destOnlyObservations;
|
||||
// destOnlyObservations = new double[][] {
|
||||
// destPast,
|
||||
// new double[] {timeToNextSpike}
|
||||
// };
|
||||
// destPastAndNextTimings.add(destOnlyObservations);
|
||||
// }
|
||||
// // 4. Reset variables
|
||||
// if (nextIsDest) {
|
||||
// dest_index++;
|
||||
// } else {
|
||||
// source_index++;
|
||||
// }
|
||||
// timeOfPrevDestSpike = destSpikeTimes[dest_index];
|
||||
// numEvents++;
|
||||
// }
|
||||
// numEventsPerObservationSet.add(numEvents);
|
||||
// if (debug) {
|
||||
// System.out.printf("Finished processing %d source-target events for observation set %d\n", numEvents, timeSeriesIndex);
|
||||
// }
|
||||
}
|
||||
|
||||
/* (non-Javadoc)
|
||||
|
|
|
@ -58,7 +58,7 @@ teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntr
|
|||
teCalc = teCalcClass()
|
||||
teCalc.setProperty("NORMALISE", "true") # Normalise the individual variables
|
||||
teCalc.initialise(1) # Use history length 1 (Schreiber k=1)
|
||||
teCalc.setProperty("k", "4") # Use Kraskov parameter K=4 for 4 nearest points
|
||||
teCalc.setProperty("knns", "4") # Use Kraskov parameter K=4 for 4 nearest points
|
||||
# # Perform calculation with correlated source:
|
||||
teCalc.setObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray))
|
||||
result = teCalc.computeAverageLocalOfObservations()
|
||||
|
|
Loading…
Reference in New Issue