mirror of https://github.com/jlizier/jidt

commit ec01683592 (parent fa9a45a9cf): Incorporated JL code from 21/11/2018
@@ -3,6 +3,7 @@ package infodynamics.measures.spiking.integration;
import java.util.Arrays;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Vector;

import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking;
@@ -120,7 +121,20 @@ public class TransferEntropyCalculatorSpikingIntegration implements
     * source or destination spike)
     */
    public static final String TRIM_TO_POS_PROP_NAME = "TRIM_RANGE_TO_POS_TIMES";

    /**
     * Property name for an amount of random Gaussian noise to be
     * added to the data (default is 1e-8, matching the MILCA toolkit).
     */
    public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
    /**
     * Whether to add an amount of random noise to the incoming data
     */
    protected boolean addNoise = true;
    /**
     * Amount of random Gaussian noise to add to the incoming data
     */
    protected double noiseLevel = (double) 1e-8;

    protected boolean trimToPosNextSpikeTimes = false;

    /**
@@ -177,6 +191,15 @@ public class TransferEntropyCalculatorSpikingIntegration implements
            Knns = Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(TRIM_TO_POS_PROP_NAME)) {
            trimToPosNextSpikeTimes = Boolean.parseBoolean(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
            if (propertyValue.equals("0") ||
                    propertyValue.equalsIgnoreCase("false")) {
                addNoise = false;
                noiseLevel = 0;
            } else {
                addNoise = true;
                noiseLevel = Double.parseDouble(propertyValue);
            }
        } else {
            // No property was set on this class
            propertySet = false;
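For reference, a minimal usage sketch of the two new properties (hypothetical surrounding setup; it assumes only the setProperty(String, String) semantics shown in the hunk above, and setProperty may throw Exception):

    TransferEntropyCalculatorSpikingIntegration teCalc =
            new TransferEntropyCalculatorSpikingIntegration();
    // "0" or "false" disables the Gaussian jitter; anything else parses as the noise width:
    teCalc.setProperty(TransferEntropyCalculatorSpikingIntegration.PROP_ADD_NOISE, "1e-8");
    // Trim neighbour-search ranges to positive next-spike times:
    teCalc.setProperty(TransferEntropyCalculatorSpikingIntegration.TRIM_TO_POS_PROP_NAME, "true");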
@@ -200,6 +223,8 @@ public class TransferEntropyCalculatorSpikingIntegration implements
            return Integer.toString(Knns);
        } else if (propertyName.equalsIgnoreCase(TRIM_TO_POS_PROP_NAME)) {
            return Boolean.toString(trimToPosNextSpikeTimes);
        } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
            return Double.toString(noiseLevel);
        } else {
            // No property matches for this class
            return null;
@@ -379,6 +404,10 @@ public class TransferEntropyCalculatorSpikingIntegration implements
        double[] spikeTimesForNextSpiker;
        double timeOfPrevDestSpike = destSpikeTimes[dest_index];
        int numEvents = 0;
        Random random = null;
        if (addNoise) {
            random = new Random();
        }
        while (true) {
            // 0. Check whether we're finished
            if ((source_index == sourceSpikeTimes.length - 1) &&
@@ -401,6 +430,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements
            spikeTimesForNextSpiker = nextIsDest ? destSpikeTimes : sourceSpikeTimes;
            int indexForNextSpiker = nextIsDest ? dest_index : source_index;
            timeToNextSpike = spikeTimesForNextSpiker[indexForNextSpiker+1] - timeOfPrevDestSpike;
            if (addNoise) {
                timeToNextSpike += random.nextGaussian()*noiseLevel;
            }
            // 2. Embed the past spikes
            double[] sourcePast = new double[l];
            double[] destPast = new double[k - 1];
@@ -409,13 +441,22 @@ public class TransferEntropyCalculatorSpikingIntegration implements
            }*/
            sourcePast[0] = timeOfPrevDestSpike -
                    sourceSpikeTimes[source_index];
            if (addNoise) {
                sourcePast[0] += random.nextGaussian()*noiseLevel;
            }
            for (int i = 1; i < k; i++) {
                destPast[i - 1] = destSpikeTimes[dest_index - i + 1] -
                        destSpikeTimes[dest_index - i];
                if (addNoise) {
                    destPast[i - 1] += random.nextGaussian()*noiseLevel;
                }
            }
            for (int i = 1; i < l; i++) {
                sourcePast[i] = sourceSpikeTimes[source_index - i + 1] -
                        sourceSpikeTimes[source_index - i];
                if (addNoise) {
                    sourcePast[i] += random.nextGaussian()*noiseLevel;
                }
            }
            // 3. Store these embedded observations
            double[][] observations = new double[][]{sourcePast, destPast,
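Reading off the embedding loops above: with $t^d_n$ the previous destination spike for this event and $t^s_m$ the most recent source spike before it, the embedded history vectors are built from inter-spike intervals,

    sourcePast = ( t^d_n - t^s_m,  t^s_m - t^s_{m-1},  ...,  t^s_{m-l+2} - t^s_{m-l+1} )   (length l)
    destPast   = ( t^d_n - t^d_{n-1},  ...,  t^d_{n-k+2} - t^d_{n-k+1} )                   (length k-1)

with each coordinate (and the time to the next spike) jittered by Gaussian noise of standard deviation noiseLevel when addNoise is set.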
@@ -478,10 +519,6 @@ public class TransferEntropyCalculatorSpikingIntegration implements
        double contributionFromSpikes = 0;
        double totalTimeLength = 0;

        double digammaK = MathsUtils.digamma(Knns);
        double twoInverseKTerm = 2.0 / (double) Knns;
        double inverseKTerm = 1.0 / (double) Knns;

        // Create temporary storage for arrays used in the neighbour counting:
        boolean[] isWithinR = new boolean[numberOfEvents]; // dummy, we don't really use this
        int[] indicesWithinR = new int[numberOfEvents];
@@ -509,314 +546,313 @@ public class TransferEntropyCalculatorSpikingIntegration implements
            }

            // Select only events where the destination spiked next:
            if (eventType == NEXT_DEST) {
                // Find the Knns nearest neighbour matches to this event,
                // with the same previous spiker and the next.
                // TODO Add dynamic exclusion time later
                PriorityQueue<NeighbourNodeData> nnPQ =
                        kdTreesJoint[eventType].findKNearestNeighbours(
                                Knns, eventIndexWithinType);

                // Find eps_{x,y,z} as the maximum x, y and z norms amongst this set:
                double radius_sourcePast = 0.0;
                double radius_destPast = 0.0;
                double radius_destNext = 0.0;
                int radius_destNext_sampleIndex = -1;
                for (int j = 0; j < Knns; j++) {
                    // Take the furthest remaining of the nearest neighbours from the PQ:
                    NeighbourNodeData nnData = nnPQ.poll();
                    if (nnData.norms[0] > radius_sourcePast) {
                        radius_sourcePast = nnData.norms[0];
                    }
                    if (nnData.norms[1] > radius_destPast) {
                        radius_destPast = nnData.norms[1];
                    }
                    if (nnData.norms[2] > radius_destNext) {
                        radius_destNext = nnData.norms[2];
                        radius_destNext_sampleIndex = nnData.sampleIndex;
                    }
                }

                if (debug && (eventIndex < 10000)) {
                    System.out.print(", timings: src: ");
                    MatrixUtils.printArray(System.out, thisEventTimings[0], 5);
                    System.out.print(", dest: ");
                    MatrixUtils.printArray(System.out, thisEventTimings[1], 5);
                    System.out.print(", time to next: ");
                    MatrixUtils.printArray(System.out, thisEventTimings[2], 5);
                    System.out.printf("index=%d: K=%d NNs at next_range %.5f (point %d)", eventIndexWithinType, Knns, radius_destNext, radius_destNext_sampleIndex);
                }

                indexForNextIsDest++;

                // Now find the matching samples in each sub-space;
                // first match dest history and source history, with a next spike in dest:
                kdTreesSourceDestHistories[NEXT_DEST].
                        findPointsWithinRs(eventIndexWithinType,
                                new double[] {radius_sourcePast, radius_destPast}, 0,
                                true, isWithinR, indicesWithinR);
                // And check which of these samples had spike time in dest after ours:
                int countOfDestNextAndGreater = 0;
                for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
                    // Pull out this matching event from the full joint space
                    double[][] matchedHistoryEventTimings = eventTimings[NEXT_DEST].elementAt(indicesWithinR[nIndex]);
                    if (matchedHistoryEventTimings[2][0] > thisEventTimings[2][0] + radius_destNext) {
                        // This sample had a matched history and next spike was a destination
                        // spike with a longer interval than the current sample
                        countOfDestNextAndGreater++;
                    }
                    // Reset the isWithinR array while we're here
                    isWithinR[indicesWithinR[nIndex]] = false;
                }
                // And count how many samples with the matching history actually had a
                // *source* spike next, after ours.
                // Note that we now must go to the other kdTree for next source spike
                kdTreesSourceDestHistories[NEXT_SOURCE].
                        findPointsWithinRs(
                                new double[] {radius_sourcePast, radius_destPast}, thisEventTimings,
                                true, isWithinR, indicesWithinR);
                // And check which of these samples had spike time in source at or after ours:
                int countOfSourceNextAndGreater = 0;
                for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
                    // Pull out this matching event from the full joint space
                    double[][] matchedHistoryEventTimings = eventTimings[NEXT_SOURCE].elementAt(indicesWithinR[nIndex]);
                    if (matchedHistoryEventTimings[2][0] >= thisEventTimings[2][0] - radius_destNext) {
                        // This sample had a matched history and next spike was a source
                        // spike with an interval longer than or considered equal to the current sample.
                        // (The "equal to" is why we look for matches within radius_destNext here as well.)
                        countOfSourceNextAndGreater++;
                    }
                    // Reset the isWithinR array while we're here
                    isWithinR[indicesWithinR[nIndex]] = false;
                }

                // We need to correct radius_destNext to have a different value below, chopping it
                // where it pushes into negative times (i.e. *before* the previous spike).
                // TODO I think this is justified by the approach of Greg Ver Steeg et al. in
                // http://www.jmlr.org/proceedings/papers/v38/gao15.pdf -- it's not the same
                // as that approach, but a principled version where we know the space cannot
                // be explored (so we don't need to only apply the correction where it is large).
                double searchAreaRatio = 0; // Ratio of actual search area for full joint space compared to if the timings were independent
                double prevSource_timing_upper_original = 0, prevSource_timing_lower = 0;
                double destNext_timing_lower_original = timeToNextSpikeSincePreviousDestSpike - radius_destNext;
                double destNext_timing_upper = timeToNextSpikeSincePreviousDestSpike + radius_destNext;
                double radius_destNext_lower = radius_destNext;
                if (trimToPosNextSpikeTimes) {
                    if (destNext_timing_lower_original < 0) {
                        // This is the range that's actually being searched in the lower
                        // dimensional space.
                        // (We don't need to adjust margins in lower dimensional space for
                        // this because it simply won't get any points below this!)
                        destNext_timing_lower_original = 0;
                    }
                    // timePreviousSourceSpikeBeforePreviousDestSpike is negative if source is after dest.
                    prevSource_timing_upper_original = -timePreviousSourceSpikeBeforePreviousDestSpike + radius_sourcePast;
                    prevSource_timing_lower = -timePreviousSourceSpikeBeforePreviousDestSpike - radius_sourcePast;
                    double destNext_timing_lower = Math.max(destNext_timing_lower_original,
                            prevSource_timing_lower);
                    double prevSource_timing_upper = Math.min(destNext_timing_upper, prevSource_timing_upper_original);
                    // Now look at various cases for the corrective ratio:
                    // TODO I think we still need to correct destNext_timing_lower_original to radius_destNext_lower here!
                    double denominator = (destNext_timing_upper - destNext_timing_lower_original) *
                            (prevSource_timing_upper_original - prevSource_timing_lower);
                    if (destNext_timing_upper - destNext_timing_lower >
                            prevSource_timing_upper - prevSource_timing_lower) {
                        // Case 1:
                        if (prevSource_timing_upper < destNext_timing_lower) {
                            // Windows do not overlap, so there will be no correction here
                            searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) *
                                    (prevSource_timing_upper - prevSource_timing_lower);
                        } else if (destNext_timing_upper < prevSource_timing_lower) {
                            // Should never happen, because next spike can't be before previous source spike
                            searchAreaRatio = 0;
                            throw new RuntimeException("Encountered a next target spike *before* the previous source spike");
                        } else if (destNext_timing_lower < prevSource_timing_lower) {
                            // Dest next lower bound is below source previous lower bound -- definitely
                            // can't get any points in the part below prevSource_timing_lower
                            // for dest next.
                            searchAreaRatio = (prevSource_timing_upper - prevSource_timing_lower) *
                                    (destNext_timing_upper - prevSource_timing_upper) +
                                    0.5 * (prevSource_timing_upper - prevSource_timing_lower) *
                                    (prevSource_timing_upper - prevSource_timing_lower);
                        } else {
                            // destNext_timing_lower >= prevSource_timing_lower
                            // Dest next lower bound is within the previous source window
                            // but above the source lower bound.
                            searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) *
                                    (prevSource_timing_upper - prevSource_timing_lower) -
                                    0.5 * (prevSource_timing_upper - destNext_timing_lower) *
                                    (prevSource_timing_upper - destNext_timing_lower);
                        }
                    } else {
                        // Case 2:
                        if (prevSource_timing_upper < destNext_timing_lower) {
                            // Windows do not overlap, so there will be no correction here
                            searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) *
                                    (prevSource_timing_upper - prevSource_timing_lower);
                        } else if (destNext_timing_upper < prevSource_timing_lower) {
                            // Should never happen, because next spike can't be before previous source spike
                            searchAreaRatio = 0;
                            throw new RuntimeException("Encountered a next target spike *before* the previous source spike");
                        } else if (destNext_timing_lower < prevSource_timing_lower) {
                            searchAreaRatio = 0.5 * (destNext_timing_upper - prevSource_timing_lower) *
                                    (destNext_timing_upper - prevSource_timing_lower);
                        } else {
                            // destNext_timing_lower >= prevSource_timing_lower
                            searchAreaRatio = (destNext_timing_lower - prevSource_timing_lower) *
                                    (destNext_timing_upper - destNext_timing_lower) +
                                    0.5 * (destNext_timing_upper - destNext_timing_lower) *
                                    (destNext_timing_upper - destNext_timing_lower);
                        }
                    }
                    if (denominator <= 0) {
                        // We'll get NaN if we divide by it. No correction is necessary in this
                        // case, the margins are too thin to cause a discrepancy anyway.
                        searchAreaRatio = 1;
                    } else {
                        searchAreaRatio /= denominator;
                    }
                    // And finally fix up the lower radius on destNext if required:
                    radius_destNext_lower = timeToNextSpikeSincePreviousDestSpike - destNext_timing_lower;
                }
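                // In summary (a reading of the case analysis above, not an independent
                // derivation): with the trimmed windows
                //   W_y = [destNext_timing_lower, destNext_timing_upper]  (next dest spike)
                //   W_x = [prevSource_timing_lower, prevSource_timing_upper]  (prev source spike),
                // each case computes the area of {(y, x) in W_y x W_x : y >= x}, i.e. the
                // part of the search rectangle where the next target spike does not precede
                // the previous source spike; searchAreaRatio then divides this by the
                // untrimmed product
                //   (destNext_timing_upper - destNext_timing_lower_original) *
                //   (prevSource_timing_upper_original - prevSource_timing_lower),
                // the area that would be searched if the two timings were independent.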

                if (debug && (eventIndex < 10000)) {
                    System.out.printf(" of %d + %d + %d points with matching S-D history",
                            Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater);
                }

                // Now find the matching samples in the dest history and
                // with a next spike timing.
                int countOfDestNextAndGreaterMatchedDest = 0;
                int countOfDestNextMatched = 0;
                if (k > 1) {
                    // Search only the space of dest past -- no point
                    // searching dest past and next: since we need to run through
                    // all matches of dest past to count those with greater next spike
                    // times, we might as well count those with matching spike times
                    // while we're at it.
                    kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radius_destPast,
                            true, isWithinR, indicesWithinR);
                    // And check which of these samples had next spike time after ours:
                    for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
                        // Pull out this matching event from the dest history space
                        double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]);
                        if (matchedHistoryEventTimings[1][0] >= timeToNextSpikeSincePreviousDestSpike - radius_destNext_lower) {
                            // This sample had a matched history and next spike was a
                            // spike with an interval longer than or considered equal to the current sample.
                            // (The "equal to" is why we look for matches within kthNnData.distance here as well.)
                            countOfDestNextAndGreaterMatchedDest++;
                            if (matchedHistoryEventTimings[1][0] <= timeToNextSpikeSincePreviousDestSpike + radius_destNext) {
                                // Then we also have a match on the next spike itself
                                countOfDestNextMatched++;
                            }
                        }
                        // Reset the isWithinR array while we're here
                        isWithinR[indicesWithinR[nIndex]] = false;
                    }
                } else {
                    // We don't take any past dest spike ISIs into account, so we just need to look at the proportion of next
                    // spike times that match.
                    countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsWithinRs(indexForNextIsDest,
                            radius_destNext, radius_destNext_lower, true);
                    // To check for points matching or larger, we just need to make this call with the
                    // revised lower radius, because we don't check the upper one.
                    countOfDestNextAndGreaterMatchedDest =
                            nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest, radius_destNext_lower, true);
                }
                double totalSearchTimeWindowCondDestPast = radius_destNext_lower + radius_destNext;

                if (debug && (eventIndex < 10000)) {
                    System.out.printf(", and %d of %d points for D history only; ",
                            countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest);
                }

                //============================
                // This code section applies if we wish to compute using digamma logs:
                //============================
                // With these neighbours counted, we're ready to compute the probability of the spike given the past
                // of source and dest.
                double logPGivenSourceAndDest;
                double logPGivenDest;
                if (k > 1) {
                    // We're handling three variables:
                    logPGivenSourceAndDest = digammaK - twoInverseKTerm
                            - MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)
                            + 1.0 / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater);
                    logPGivenDest = MathsUtils.digamma(countOfDestNextMatched)
                            - 1.0 / ((double) countOfDestNextMatched)
                            - MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest);
                } else {
                    // We're really only handling two variables like an MI:
                    logPGivenSourceAndDest = digammaK - inverseKTerm
                            - MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater);
                    logPGivenDest = MathsUtils.digamma(countOfDestNextMatched)
                            - MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest);
                }
                if (trimToPosNextSpikeTimes) {
                    logPGivenSourceAndDest -= Math.log(searchAreaRatio);
                }
                //============================
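                // Transcribed as formulas (psi = digamma, K = Knns,
                // nSrcGt = countOfSourceNextAndGreater, nDstGt = countOfDestNextAndGreater,
                // nEq = countOfDestNextMatched, nGe = countOfDestNextAndGreaterMatchedDest),
                // the k > 1 branch above computes:
                //   logPGivenSourceAndDest = psi(K) - 2/K - psi(K + nSrcGt + nDstGt) + 1/(K + nSrcGt + nDstGt)
                //   logPGivenDest          = psi(nEq) - 1/nEq - psi(nGe)
                // i.e. KSG-style bias-corrected log probability estimates formed from
                // the neighbour-count ratios.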

                if (debug && (eventIndex < 10000)) {
                    System.out.printf(" te ~~ log (%d/%d)/(%d/%d) = %.4f wc-> %.5f bc-> %.4f (inferred rates %.4f vs %.4f, " +
                            "win-cor %.5f vs %.5f, bias-corrected %.5f vs %.5f)\n", Knns,
                            Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater,
                            countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest,
                            Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
                                    ((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest))),
                            // TE from Window corrected rates:
                            Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
                                    ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) /
                                    (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
                                    totalSearchTimeWindowCondDestPast)),
                            // TE from Bias corrected rates:
                            logPGivenSourceAndDest - logPGivenDest,
                            // Inferred rates raw:
                            (double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater) / (2.0*radius_destNext),
                            (double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest) / (2.0*radius_destNext),
                            // Inferred rates with window correction:
                            ((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
                                    ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio),
                            ((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
                                    totalSearchTimeWindowCondDestPast,
                            // Transform the log likelihoods into bias corrected rates:
                            Math.exp(logPGivenSourceAndDest) / (2.0*radius_destNext), Math.exp(logPGivenDest) / (2.0*radius_destNext));
                    if (trimToPosNextSpikeTimes) {
                        System.out.printf("Search area ratio: %.5f, correction %.5f, t_y_upper %.5f, t_y_lower %.5f\n",
                                searchAreaRatio, -Math.log(searchAreaRatio), prevSource_timing_upper_original, prevSource_timing_lower);
                    }
                }
                // Unexplained case:
                if (countOfDestNextMatched < Knns) {
                    // TODO Should not happen, print something!
                    System.out.printf("SHOULD NOT HAPPEN!\n");
                    // So debug this:
                    nnPQ = kdTreesJoint[eventType].findKNearestNeighbours(
                            Knns, eventIndexWithinType);
                    for (int j = 0; j < Knns; j++) {
                        // Take the furthest remaining of the nearest neighbours from the PQ:
                        NeighbourNodeData nnData = nnPQ.poll();
                        System.out.printf("NN data %d norms: source %.5f, destPast %.5f, destNext %.5f, nextIndex %d\n",
                                j, nnData.norms[0], nnData.norms[1], nnData.norms[2], nnData.sampleIndex);
                    }

                }

                //======================
                // Add the contribution in:
                // a. If we are using digamma logs:
                contributionFromSpikes += logPGivenSourceAndDest - logPGivenDest;
                // b. If we are only using window corrections but actual ratios:
                //contributionFromSpikes +=
                //  Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
                //      ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) /
                //      (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
                //      totalSearchTimeWindowCondDestPast));
            } else {
                if (eventType != NEXT_DEST) {
                    // Pre-condition: the next event is a source spike, so we'll continue on to check the next event
                    if (debug && (eventIndex < 10000)) {
                        System.out.println();
                    }
                }
                continue;
            }
            // Post-condition: the next event is a destination spike:

            // Find the Knns nearest neighbour matches to this event,
            // with the same previous spiker and the next.
            // TODO Add dynamic exclusion time later
            PriorityQueue<NeighbourNodeData> nnPQ =
                    kdTreesJoint[NEXT_DEST].findKNearestNeighbours(
                            Knns, eventIndexWithinType);

            // Find eps_{x,y,z} as the maximum x, y and z norms amongst this set:
            double radius_sourcePast = 0.0;
            double radius_destPast = 0.0;
            double radius_destNext = 0.0;
            int radius_destNext_sampleIndex = -1;
            for (int j = 0; j < Knns; j++) {
                // Take the furthest remaining of the nearest neighbours from the PQ:
                NeighbourNodeData nnData = nnPQ.poll();
                if (nnData.norms[0] > radius_sourcePast) {
                    radius_sourcePast = nnData.norms[0];
                }
                if (nnData.norms[1] > radius_destPast) {
                    radius_destPast = nnData.norms[1];
                }
                if (nnData.norms[2] > radius_destNext) {
                    radius_destNext = nnData.norms[2];
                    radius_destNext_sampleIndex = nnData.sampleIndex;
                }
            }
            // Post-condition: the radius_* variables hold the search radii for the sourcePast, destPast and destNext matches.

            if (debug && (eventIndex < 10000)) {
                System.out.print(", timings: src: ");
                MatrixUtils.printArray(System.out, thisEventTimings[0], 5);
                System.out.print(", dest: ");
                MatrixUtils.printArray(System.out, thisEventTimings[1], 5);
                System.out.print(", time to next: ");
                MatrixUtils.printArray(System.out, thisEventTimings[2], 5);
                System.out.printf("index=%d: K=%d NNs at next_range %.5f (point %d)", eventIndexWithinType, Knns, radius_destNext, radius_destNext_sampleIndex);
            }

            indexForNextIsDest++;

            // Now find the matching samples in each sub-space;
            // first match dest history and source history, with a next spike in dest:
            int numMatches = kdTreesSourceDestHistories[NEXT_DEST].
                    findPointsWithinRs(eventIndexWithinType,
                            new double[] {radius_sourcePast, radius_destPast}, 0,
                            true, isWithinR, indicesWithinR);
            // Set the search point itself to be a neighbour -- this is necessary to include the waiting time
            // for it in our count:
            indicesWithinR[numMatches] = eventIndexWithinType;
            indicesWithinR[numMatches+1] = -1;
            isWithinR[eventIndexWithinType] = true;
            // And check which of these samples had spike time in dest in the window or after ours:
            int countOfDestNextAndGreater = 0;
            int countOfDestNextInWindow = 0; // Would be Knns except for one on the lower boundary (if there is one)
            double timeInWindowWithMatchingJointHistories = 0;
            for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
                // Pull out this matching event from the full joint space
                double[][] matchedHistoryEventTimings = eventTimings[NEXT_DEST].elementAt(indicesWithinR[nIndex]);
                // Use simple labels for relative times from the previous target spike
                double matchingHistoryTimeToNextSpike = matchedHistoryEventTimings[2][0];
                // (The time to the previous source spike from the previous target spike can be negative or positive:
                // matchedHistoryEventTimings[0][0] < 0 implies the previous source spike occurring later
                // than the previous target spike; we want the opposite sign here.)
                double matchingHistoryTimeToPrevSourceSpike = -matchedHistoryEventTimings[0][0];

                // We need to check how long we spent in the window matching the next spike
                // with a matching history.
                // First, we make sure that the next (target) spike was in the window or after it,
                // and that the previous source spike did not occur after the window
                // (this is possible since our neighbour match hasn't checked the next spike time):
                if ((matchingHistoryTimeToNextSpike >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) &&
                        (matchingHistoryTimeToPrevSourceSpike <= timeToNextSpikeSincePreviousDestSpike + radius_destNext)) {

                    // Real start of window cannot be before previous destination spike:
                    double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0);
                    // Also, real start cannot be before previous source spike:
                    // (previous source spike occurs at -matchedHistoryEventTimings[0][0] relative
                    // to previous destination spike)
                    realStartOfWindow = Math.max(realStartOfWindow, matchingHistoryTimeToPrevSourceSpike);

                    // Real end of window happened either when the spike occurred (which changes the history) or at
                    // the end of the window:
                    double realEndOfWindow = Math.min(matchingHistoryTimeToNextSpike,
                            timeToNextSpikeSincePreviousDestSpike + radius_destNext);

                    // Add in how much time with a matching history we spent in this window:
                    timeInWindowWithMatchingJointHistories +=
                            realEndOfWindow - realStartOfWindow;

                    countOfDestNextAndGreater++;

                    // Count spikes occurring here in the window (and check below)
                    if ((matchingHistoryTimeToNextSpike >= realStartOfWindow) &&
                            (matchingHistoryTimeToNextSpike <= realEndOfWindow)) {
                        countOfDestNextInWindow++;
                    }

                }
                // Reset the isWithinR array while we're here
                isWithinR[indicesWithinR[nIndex]] = false;
            }
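            // Restating the clipping rule just applied: with t = timeToNextSpikeSincePreviousDestSpike
            // and r = radius_destNext, each matched history contributes the time
            //   realEndOfWindow - realStartOfWindow, where
            //   realStartOfWindow = max(t - r, 0, matchingHistoryTimeToPrevSourceSpike)
            //   realEndOfWindow   = min(matchingHistoryTimeToNextSpike, t + r)
            // so the window is clipped below by the previous destination and source spikes,
            // and above by when the matched history's own next spike fires (the guard above
            // ensures realEndOfWindow >= realStartOfWindow).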
            // TODO Debug check:
            // if (countOfDestNextInWindow != Knns + 1) {
            //     throw new Exception("Unexpected value for countOfDestNextInWindow: " + countOfDestNextInWindow);
            // }

            // And count how many samples with the matching history actually had a
            // *source* spike next, during or after our window.
            // Note that we now must go to the other kdTree for next source spike
            kdTreesSourceDestHistories[NEXT_SOURCE].
                    findPointsWithinRs(
                            new double[] {radius_sourcePast, radius_destPast}, thisEventTimings,
                            true, isWithinR, indicesWithinR);
            // And check which of these samples had spike time in source at or after ours:
            int countOfSourceNextAndGreater = 0;
            for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
                // Pull out this matching event from the full joint space
                double[][] matchedHistoryEventTimings = eventTimings[NEXT_SOURCE].elementAt(indicesWithinR[nIndex]);
                // Use simple labels for relative times from the previous target spike
                double matchingHistoryTimeToNextSpike = matchedHistoryEventTimings[2][0];
                // (The time to the previous source spike from the previous target spike can be negative or positive:
                // matchedHistoryEventTimings[0][0] < 0 implies the previous source spike occurring later
                // than the previous target spike; we want the opposite sign here.)
                double matchingHistoryTimeToPrevSourceSpike = -matchedHistoryEventTimings[0][0];

                // We need to check how long we spent in the window matching the next spike
                // with a matching history.
                // First, we make sure that the next (source) spike was in the window or after it,
                // and that the previous source spike did not occur after the window
                // (this is possible since our neighbour match hasn't checked the next spike time):
                if ((matchingHistoryTimeToNextSpike >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) &&
                        (matchingHistoryTimeToPrevSourceSpike <= timeToNextSpikeSincePreviousDestSpike + radius_destNext)) {

                    // Real start of window cannot be before previous destination spike:
                    double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0);
                    // Also, real start cannot be before previous source spike:
                    // (previous source spike occurs at matchingHistoryTimeToPrevSourceSpike relative
                    // to previous destination spike)
                    realStartOfWindow = Math.max(realStartOfWindow, matchingHistoryTimeToPrevSourceSpike);

                    // Real end of window happened either when the spike occurred or at
                    // the end of the window:
                    double realEndOfWindow = Math.min(matchingHistoryTimeToNextSpike,
                            timeToNextSpikeSincePreviousDestSpike + radius_destNext);

                    // Add in how much time with a matching history we spent in this window:
                    timeInWindowWithMatchingJointHistories +=
                            realEndOfWindow - realStartOfWindow;

                    countOfSourceNextAndGreater++;
                }
                // Reset the isWithinR array while we're here
                isWithinR[indicesWithinR[nIndex]] = false;
            }

            // We need to count the spike rate for all the times we're actually within the matching window for the next spike.
            // This is kind of inspired by the Greg Ver Steeg et al. approach in
            // http://www.jmlr.org/proceedings/papers/v38/gao15.pdf
            // which is thinking about where the space is actually being explored.
            // This is where the window correction code was placed, which we're now replacing
            // with computing the length of actual time we spend in the next window.

            if (debug && (eventIndex < 10000)) {
                System.out.printf(" of %d + %d + %d points with matching S-D history",
                        Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater);
            }

            // Now find the matching samples in the dest history and
            // with a next spike timing.
            int countOfDestNextAndGreaterMatchedDest = 0;
            int countOfDestNextMatched = 0;
            double timeInWindowWithMatchingDestHistory = 0;
            // Real start of window cannot be before previous destination spike:
            double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0);
            if (k > 1) {
                // Search only the space of dest past -- no point
                // searching dest past and next: since we need to run through
                // all matches of dest past to count those with greater next spike
                // times, we might as well count those with matching spike times
                // while we're at it.
                numMatches = kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radius_destPast,
                        true, isWithinR, indicesWithinR);
                // Set the search point itself to be a neighbour -- this is necessary to include the waiting time
                // for it in our count:
                indicesWithinR[numMatches] = indexForNextIsDest;
                indicesWithinR[numMatches+1] = -1;
                isWithinR[indexForNextIsDest] = true;
                // And check which of these samples had next spike time after our window starts:
                for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
                    // Pull out this matching event from the dest history space
                    double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]);
                    if (matchedHistoryEventTimings[1][0] >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) {
                        // This sample had a matched history and next spike was a
                        // spike with an interval longer than or considered equal to the current sample.
                        // (The "equal to" is why we look for matches within kthNnData.distance here as well.)

                        // Real end of window happened either when the spike occurred or at
                        // the end of the window:
                        double realEndOfWindow = Math.min(matchedHistoryEventTimings[1][0],
                                timeToNextSpikeSincePreviousDestSpike + radius_destNext);

                        // Add in how much time with a matching history we spent in this window:
                        timeInWindowWithMatchingDestHistory +=
                                realEndOfWindow - realStartOfWindow;

                        countOfDestNextAndGreaterMatchedDest++;
                        if (matchedHistoryEventTimings[1][0] <= timeToNextSpikeSincePreviousDestSpike + radius_destNext) {
                            // Then we also have a match on the next spike itself
                            countOfDestNextMatched++;
                        }
                    }
                    // Reset the isWithinR array while we're here
                    isWithinR[indicesWithinR[nIndex]] = false;
                }
            } else {

                // For k = 1, we only care about the time since the last spike.
                // So count how many of the next spikes were within the window first
                // -- we don't take any past dest spike ISIs into account, so we just need to look at the proportion of next
                // spike times that match.
                countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsWithinRs(indexForNextIsDest,
                        radius_destNext, Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true);
                // And also check how long each of these spent in the window:
                timeInWindowWithMatchingDestHistory =
                        nnSearcherDestTimeToNextSpike.sumDistanceAboveThresholdForPointsWithinRs(indexForNextIsDest,
                                radius_destNext, Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true);
                // Now, to check for points matching or larger, we just need to make this call with the
                // revised lower radius, because we don't check the upper one:
                countOfDestNextAndGreaterMatchedDest =
                        nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest,
                                Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true);
                // And we need to add time in for all of the (countOfDestNextAndGreaterMatchedDest - countOfDestNextMatched)
                // points which didn't spike in the window.
                // TODO -- check whether this is really doing what it intends to?
                timeInWindowWithMatchingDestHistory += (double) (countOfDestNextAndGreaterMatchedDest - countOfDestNextMatched) *
                        (timeToNextSpikeSincePreviousDestSpike + radius_destNext - realStartOfWindow);

                // Add in the wait time contribution for the search point itself (and we need to include it in the counts):
                timeInWindowWithMatchingDestHistory += timeToNextSpikeSincePreviousDestSpike - realStartOfWindow;
                countOfDestNextMatched++;
                countOfDestNextAndGreaterMatchedDest++;

                if (timeInWindowWithMatchingDestHistory <= 0) {
                    // David thought he saw this occur; adding a debug exception so we can catch it if so:
                    throw new Exception("timeInWindowWithMatchingDestHistory is not > 0");
                }
            }

            if (debug && (eventIndex < 10000)) {
                System.out.printf(", and %d of %d points for D history only; ",
                        countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest);
            }

            //============================
            // This code section takes the counts of spikes and total intervals, and
            // estimates the log rates.
            // Inferred rates raw:
            double rawRateGivenSourceAndDest = (double) countOfDestNextInWindow / timeInWindowWithMatchingJointHistories;
            double rawRateGivenDest = (double) countOfDestNextMatched / timeInWindowWithMatchingDestHistory;
            // Attempt at bias correction:
            // Using digamma of neighbour_count - 1, since the neighbour count now includes our search point and it's really k waiting times
            double logRateGivenSourceAndDestCorrected = MathsUtils.digamma(Knns) // - (1.0 / (double) Knns) // I don't think this correction is required
                    - Math.log(timeInWindowWithMatchingJointHistories);
            double logRateGivenDestCorrected = MathsUtils.digamma(countOfDestNextMatched - 1) // - (1.0 / (double) countOfDestNextMatched) // I don't think this correction is required
                    - Math.log(timeInWindowWithMatchingDestHistory);
            //============================
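            // Transcribed as formulas (psi = digamma, T_joint = timeInWindowWithMatchingJointHistories,
            // T_dest = timeInWindowWithMatchingDestHistory):
            //   rawRateGivenSourceAndDest          = countOfDestNextInWindow / T_joint
            //   rawRateGivenDest                   = countOfDestNextMatched / T_dest
            //   logRateGivenSourceAndDestCorrected = psi(Knns) - log(T_joint)
            //   logRateGivenDestCorrected          = psi(countOfDestNextMatched - 1) - log(T_dest)
            // and the local TE contribution added below is the difference of the two corrected
            // log rates. As a hypothetical numeric check: Knns = 4 and T_joint = 0.5 time units
            // give psi(4) - log(0.5) = 1.2561 + 0.6931 ~= 1.9493, a log rate of ~7.02 spikes
            // per unit time.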

            if (debug && (eventIndex < 10000)) {
                System.out.printf(" te ~~ %.4f - %.4f = %.4f, log (%.4f)/(%.4f) = %.4f (counts %d/%d = %.4f, %d/%d = %.4f -> te %.4f)\n",
                        logRateGivenSourceAndDestCorrected,
                        logRateGivenDestCorrected,
                        logRateGivenSourceAndDestCorrected - logRateGivenDestCorrected,
                        rawRateGivenSourceAndDest,
                        rawRateGivenDest,
                        Math.log(rawRateGivenSourceAndDest / rawRateGivenDest),
                        Knns,
                        Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater,
                        (double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater),
                        countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest,
                        (double) countOfDestNextMatched / (double) countOfDestNextAndGreaterMatchedDest,
                        Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
                                ((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest))));
            }

            //======================
            // Add the contribution in:
            // a. If we are using digamma logs:
            contributionFromSpikes += logRateGivenSourceAndDestCorrected - logRateGivenDestCorrected;
            // contributionFromSpikes += Math.log(rawRateGivenSourceAndDest / rawRateGivenDest);

            // b. If we are only using window corrections but actual ratios:
            //contributionFromSpikes +=
            //  Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
            //      ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) /
            //      (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
            //      totalSearchTimeWindowCondDestPast));
        }
        System.out.println("All done!");
        contributionFromSpikes /= totalTimeLength;
java/source/infodynamics/utils/UnivariateNearestNeighbourSearcher.java | 0 (Executable file → Normal file)