Incorporated JL code from 21/11/2018

This commit is contained in:
David Shorten 2018-12-10 20:22:15 +11:00
parent fa9a45a9cf
commit ec01683592
4 changed files with 346 additions and 310 deletions

View File

@ -3,6 +3,7 @@ package infodynamics.measures.spiking.integration;
import java.util.Arrays;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Vector;
import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking;
@ -120,7 +121,20 @@ public class TransferEntropyCalculatorSpikingIntegration implements
* source or destination spike)
*/
public static final String TRIM_TO_POS_PROP_NAME = "TRIM_RANGE_TO_POS_TIMES";
/**
* Property name for an amount of random Gaussian noise to be
* added to the data (default is 1e-8, matching the MILCA toolkit).
*/
public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
/**
* Whether to add an amount of random noise to the incoming data
*/
protected boolean addNoise = true;
/**
* Amount of random Gaussian noise to add to the incoming data
*/
protected double noiseLevel = (double) 1e-8;
protected boolean trimToPosNextSpikeTimes = false;
/**
@ -177,6 +191,15 @@ public class TransferEntropyCalculatorSpikingIntegration implements
Knns = Integer.parseInt(propertyValue);
} else if (propertyName.equalsIgnoreCase(TRIM_TO_POS_PROP_NAME)) {
trimToPosNextSpikeTimes = Boolean.parseBoolean(propertyValue);
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
if (propertyValue.equals("0") ||
propertyValue.equalsIgnoreCase("false")) {
addNoise = false;
noiseLevel = 0;
} else {
addNoise = true;
noiseLevel = Double.parseDouble(propertyValue);
}
} else {
// No property was set on this class
propertySet = false;
@ -200,6 +223,8 @@ public class TransferEntropyCalculatorSpikingIntegration implements
return Integer.toString(Knns);
} else if (propertyName.equalsIgnoreCase(TRIM_TO_POS_PROP_NAME)) {
return Boolean.toString(trimToPosNextSpikeTimes);
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
return Double.toString(noiseLevel);
} else {
// No property matches for this class
return null;
@ -379,6 +404,10 @@ public class TransferEntropyCalculatorSpikingIntegration implements
double[] spikeTimesForNextSpiker;
double timeOfPrevDestSpike = destSpikeTimes[dest_index];
int numEvents = 0;
Random random = null;
if (addNoise) {
random = new Random();
}
while(true) {
// 0. Check whether we're finished
if ((source_index == sourceSpikeTimes.length - 1) &&
@ -401,6 +430,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements
spikeTimesForNextSpiker = nextIsDest ? destSpikeTimes : sourceSpikeTimes;
int indexForNextSpiker = nextIsDest ? dest_index : source_index;
timeToNextSpike = spikeTimesForNextSpiker[indexForNextSpiker+1] - timeOfPrevDestSpike;
if (addNoise) {
timeToNextSpike += random.nextGaussian()*noiseLevel;
}
// 2. Embed the past spikes
double[] sourcePast = new double[l];
double[] destPast = new double[k - 1];
@ -409,13 +441,22 @@ public class TransferEntropyCalculatorSpikingIntegration implements
}*/
sourcePast[0] = timeOfPrevDestSpike -
sourceSpikeTimes[source_index];
if (addNoise) {
sourcePast[0] += random.nextGaussian()*noiseLevel;
}
for (int i = 1; i < k; i++) {
destPast[i - 1] = destSpikeTimes[dest_index - i + 1] -
destSpikeTimes[dest_index - i];
if (addNoise) {
destPast[i - 1] += random.nextGaussian()*noiseLevel;
}
}
for (int i = 1; i < l; i++) {
sourcePast[i] = sourceSpikeTimes[source_index - i + 1] -
sourceSpikeTimes[source_index - i];
if (addNoise) {
sourcePast[i] += random.nextGaussian()*noiseLevel;
}
}
// 3. Store these embedded observations
double[][] observations = new double[][]{sourcePast, destPast,
@ -478,10 +519,6 @@ public class TransferEntropyCalculatorSpikingIntegration implements
double contributionFromSpikes = 0;
double totalTimeLength = 0;
double digammaK = MathsUtils.digamma(Knns);
double twoInverseKTerm = 2.0 / (double) Knns;
double inverseKTerm = 1.0 / (double) Knns;
// Create temporary storage for arrays used in the neighbour counting:
boolean[] isWithinR = new boolean[numberOfEvents]; // dummy, we don't really use this
int[] indicesWithinR = new int[numberOfEvents];
@ -509,314 +546,313 @@ public class TransferEntropyCalculatorSpikingIntegration implements
}
// Select only events where the destination spiked next:
if (eventType == NEXT_DEST) {
// Find the Knns nearest neighbour matches to this event,
// with the same previous spiker and the next.
// TODO Add dynamic exclusion time later
PriorityQueue<NeighbourNodeData> nnPQ =
kdTreesJoint[eventType].findKNearestNeighbours(
Knns, eventIndexWithinType);
// Find eps_{x,y,z} as the maximum x, y and z norms amongst this set:
double radius_sourcePast = 0.0;
double radius_destPast = 0.0;
double radius_destNext = 0.0;
int radius_destNext_sampleIndex = -1;
for (int j = 0; j < Knns; j++) {
// Take the furthest remaining of the nearest neighbours from the PQ:
NeighbourNodeData nnData = nnPQ.poll();
if (nnData.norms[0] > radius_sourcePast) {
radius_sourcePast = nnData.norms[0];
}
if (nnData.norms[1] > radius_destPast) {
radius_destPast = nnData.norms[1];
}
if (nnData.norms[2] > radius_destNext) {
radius_destNext = nnData.norms[2];
radius_destNext_sampleIndex = nnData.sampleIndex;
}
}
if (debug && (eventIndex < 10000)) {
System.out.print(", timings: src: ");
MatrixUtils.printArray(System.out, thisEventTimings[0], 5);
System.out.print(", dest: ");
MatrixUtils.printArray(System.out, thisEventTimings[1], 5);
System.out.print(", time to next: ");
MatrixUtils.printArray(System.out, thisEventTimings[2], 5);
System.out.printf("index=%d: K=%d NNs at next_range %.5f (point %d)", eventIndexWithinType, Knns, radius_destNext, radius_destNext_sampleIndex);
}
indexForNextIsDest++;
// Now find the matching samples in each sub-space;
// first match dest history and source history, with a next spike in dest:
kdTreesSourceDestHistories[NEXT_DEST].
findPointsWithinRs(eventIndexWithinType,
new double[] {radius_sourcePast, radius_destPast}, 0,
true, isWithinR, indicesWithinR);
// And check which of these samples had spike time in dest after ours:
int countOfDestNextAndGreater = 0;
for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
// Pull out this matching event from the full joint space
double[][] matchedHistoryEventTimings = eventTimings[NEXT_DEST].elementAt(indicesWithinR[nIndex]);
if (matchedHistoryEventTimings[2][0] > thisEventTimings[2][0] + radius_destNext) {
// This sample had a matched history and next spike was a destination
// spike with a longer interval than the current sample
countOfDestNextAndGreater++;
}
// Reset the isWithinR array while we're here
isWithinR[indicesWithinR[nIndex]] = false;
}
// And count how many samples with the matching history actually had a
// *source* spike next, after ours.
// Note that we now must go to the other kdTree for next source spike
kdTreesSourceDestHistories[NEXT_SOURCE].
findPointsWithinRs(
new double[] {radius_sourcePast, radius_destPast}, thisEventTimings,
true, isWithinR, indicesWithinR);
// And check which of these samples had spike time in source at or after ours:
int countOfSourceNextAndGreater = 0;
for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
// Pull out this matching event from the full joint space
double[][] matchedHistoryEventTimings = eventTimings[NEXT_SOURCE].elementAt(indicesWithinR[nIndex]);
if (matchedHistoryEventTimings[2][0] >= thisEventTimings[2][0] - radius_destNext) {
// This sample had a matched history and next spike was a source
// spike with an interval longer than or considered equal to the current sample.
// (The "equal to" is why we look for matches within radius_destNext here as well.)
countOfSourceNextAndGreater++;
}
// Reset the isWithinR array while we're here
isWithinR[indicesWithinR[nIndex]] = false;
}
// We need to correct radius_destNext to have a different value below, chopping it
// where it pushes into negative times (i.e. *before* the previous spike).
// TODO I think this is justified by the approach of Greg Ver Steeg et al. in
// http://www.jmlr.org/proceedings/papers/v38/gao15.pdf -- it's not the same
// as that approach, but a principled version where we know the space cannot
// be explored (so we don't need to only apply the correction where it is large).
double searchAreaRatio = 0; // Ratio of actual search area for full joint space compared to if the timings were independent
double prevSource_timing_upper_original = 0, prevSource_timing_lower = 0;
double destNext_timing_lower_original = timeToNextSpikeSincePreviousDestSpike - radius_destNext;
double destNext_timing_upper = timeToNextSpikeSincePreviousDestSpike + radius_destNext;
double radius_destNext_lower = radius_destNext;
if (trimToPosNextSpikeTimes) {
if (destNext_timing_lower_original < 0) {
// This is the range that's actually being searched in the lower
// dimensional space.
// (We don't need to adjust margins in lower dimensional space for
// this because it simply won't get any points below this!)
destNext_timing_lower_original = 0;
}
// timePreviousSourceSpikeBeforePreviousDestSpike is negative if source is after dest.
prevSource_timing_upper_original = -timePreviousSourceSpikeBeforePreviousDestSpike + radius_sourcePast;
prevSource_timing_lower = -timePreviousSourceSpikeBeforePreviousDestSpike - radius_sourcePast;
double destNext_timing_lower = Math.max(destNext_timing_lower_original,
prevSource_timing_lower);
double prevSource_timing_upper = Math.min(destNext_timing_upper, prevSource_timing_upper_original);
// Now look at various cases for the corrective ratio:
// TODO I think we still need to correct destNext_timing_lower_original to radius_destNext_lower here!
double denominator = (destNext_timing_upper - destNext_timing_lower_original) *
(prevSource_timing_upper_original - prevSource_timing_lower);
if (destNext_timing_upper - destNext_timing_lower >
prevSource_timing_upper - prevSource_timing_lower) {
// Case 1:
if (prevSource_timing_upper < destNext_timing_lower) {
// Windows do not overlap, so there will be no correction here
searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) *
(prevSource_timing_upper - prevSource_timing_lower);
} else if (destNext_timing_upper < prevSource_timing_lower) {
// Should never happen, because next spike can't be before previous source spike
searchAreaRatio = 0;
throw new RuntimeException("Encountered a next target spike *before* the previous source spike");
} else if (destNext_timing_lower < prevSource_timing_lower) {
// Dest next lower bound is below source previous lower bound -- definitely
// can't get any points in the part below prevSource_timing_lower
// for dest next.
searchAreaRatio = (prevSource_timing_upper - prevSource_timing_lower) *
(destNext_timing_upper - prevSource_timing_upper) +
0.5 * (prevSource_timing_upper - prevSource_timing_lower) *
(prevSource_timing_upper - prevSource_timing_lower);
} else {
// destNext_timing_lower >= prevSource_timing_lower
// Dest next lower bound is within the previous source window
// but above the source lower bound.
searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) *
(prevSource_timing_upper - prevSource_timing_lower) -
0.5 * (prevSource_timing_upper - destNext_timing_lower) *
(prevSource_timing_upper - destNext_timing_lower);
}
} else {
// Case 2:
if (prevSource_timing_upper < destNext_timing_lower) {
// Windows do not overlap, so there will be no correction here
searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) *
(prevSource_timing_upper - prevSource_timing_lower);
} else if (destNext_timing_upper < prevSource_timing_lower) {
// Should never happen, because next spike can't be before previous source spike
searchAreaRatio = 0;
throw new RuntimeException("Encountered a next target spike *before* the previous source spike");
} else if (destNext_timing_lower < prevSource_timing_lower) {
searchAreaRatio = 0.5 * (destNext_timing_upper - prevSource_timing_lower) *
(destNext_timing_upper - prevSource_timing_lower);
} else {
// destNext_timing_lower >= prevSource_timing_lower
searchAreaRatio = (destNext_timing_lower - prevSource_timing_lower) *
(destNext_timing_upper - destNext_timing_lower) +
0.5 * (destNext_timing_upper - destNext_timing_lower) *
(destNext_timing_upper - destNext_timing_lower);
}
}
if (denominator <= 0) {
// We'll get NaN if we divide by it. No correction is necessary in this
// case, the margins are too thin to cause a discrepancy anyway.
searchAreaRatio = 1;
} else {
searchAreaRatio /= denominator;
}
// And finally fix up the lower radius on destNext if required:
radius_destNext_lower = timeToNextSpikeSincePreviousDestSpike - destNext_timing_lower;
}
if (debug && (eventIndex < 10000)) {
System.out.printf(" of %d + %d + %d points with matching S-D history",
Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater);
}
// Now find the matching samples in the dest history and
// with a next spike timing.
int countOfDestNextAndGreaterMatchedDest = 0;
int countOfDestNextMatched = 0;
if (k > 1) {
// Search only the space of dest past -- no point
// searching dest past and next, since we need to run through
// all matches of dest past to count those with greater next spike
// times we might as well count those with matching spike times
// while we're at it.
kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radius_destPast,
true, isWithinR, indicesWithinR);
// And check which of these samples had next spike time after ours:
for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
// Pull out this matching event from the dest history space
double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]);
if (matchedHistoryEventTimings[1][0] >= timeToNextSpikeSincePreviousDestSpike - radius_destNext_lower) {
// This sample had a matched history and next spike was a
// spike with an interval longer than or considered equal to the current sample.
// (The "equal to" is why we look for matches within kthNnData.distance here as well.)
countOfDestNextAndGreaterMatchedDest++;
if (matchedHistoryEventTimings[1][0] <= timeToNextSpikeSincePreviousDestSpike + radius_destNext) {
// Then we also have a match on the next spike itself
countOfDestNextMatched++;
}
}
// Reset the isWithinR array while we're here
isWithinR[indicesWithinR[nIndex]] = false;
}
} else {
// We don't take any past dest spike ISIs into account, so we just need to look at the proportion of next
// spike times that match.
countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsWithinRs(indexForNextIsDest,
radius_destNext, radius_destNext_lower, true);
// To check for points matching or larger, we just need to make this call with the
// revised lower radius, because we don't check the upper one.
countOfDestNextAndGreaterMatchedDest =
nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest, radius_destNext_lower, true);
}
double totalSearchTimeWindowCondDestPast = radius_destNext_lower + radius_destNext;
if (debug && (eventIndex < 10000)) {
System.out.printf(", and %d of %d points for D history only; ",
countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest);
}
//============================
// This code section is if we wish to compute using digamma logs:
//============================
// With these neighbours counted, we're ready to compute the probability of the spike given the past
// of source and dest.
double logPGivenSourceAndDest;
double logPGivenDest;
if (k > 1) {
// We're handling three variables:
logPGivenSourceAndDest = digammaK - twoInverseKTerm
- MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)
+ 1.0 / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater);
logPGivenDest = MathsUtils.digamma(countOfDestNextMatched)
- 1.0 / ((double) countOfDestNextMatched)
- MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest);
} else {
// We're really only handling two variables like an MI:
logPGivenSourceAndDest = digammaK - inverseKTerm
- MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater);
logPGivenDest = MathsUtils.digamma(countOfDestNextMatched)
- MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest);
}
if (trimToPosNextSpikeTimes) {
logPGivenSourceAndDest -= Math.log(searchAreaRatio);
}
//============================
if (debug && (eventIndex < 10000)) {
System.out.printf(" te ~~ log (%d/%d)/(%d/%d) = %.4f wc-> %.5f bc-> %.4f (inferred rates %.4f vs %.4f, " +
"win-cor %.5f vs %.5f, bias-corrected %.5f vs %.5f)\n", Knns,
Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater,
countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest,
Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest))),
// TE from Window corrected rates:
Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) /
(((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
totalSearchTimeWindowCondDestPast)),
// TE from Bias corrected rates:
logPGivenSourceAndDest - logPGivenDest,
// Inferred rates raw:
(double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater) / (2.0*radius_destNext),
(double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest) / (2.0*radius_destNext),
// Inferred rates with window correction:
((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio),
((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
totalSearchTimeWindowCondDestPast,
// Transform the log likelihoods into bias corrected rates:
Math.exp(logPGivenSourceAndDest) / (2.0*radius_destNext), Math.exp(logPGivenDest) / (2.0*radius_destNext));
if (trimToPosNextSpikeTimes) {
System.out.printf("Search area ratio: %.5f, correction %.5f, t_y_upper %.5f, t_y_lower %.5f\n",
searchAreaRatio, -Math.log(searchAreaRatio), prevSource_timing_upper_original, prevSource_timing_lower);
}
}
// Unexplained case:
if (countOfDestNextMatched < Knns) {
// TODO Should not happen, print something!
System.out.printf("SHOULD NOT HAPPEN!\n");
// So debug this:
nnPQ = kdTreesJoint[eventType].findKNearestNeighbours(
Knns, eventIndexWithinType);
for (int j = 0; j < Knns; j++) {
// Take the furthest remaining of the nearest neighbours from the PQ:
NeighbourNodeData nnData = nnPQ.poll();
System.out.printf("NN data %d norms: source %.5f, destPast %.5f, destNext %.5f, nextIndex %d\n",
nnData.norms[0], nnData.norms[1], nnData.norms[2], nnData.sampleIndex);
}
}
//======================
// Add the contribution in:
// a.If we were using digamma logs:
contributionFromSpikes += logPGivenSourceAndDest - logPGivenDest;
// b. If we are only using window corrections but actual ratios:
//contributionFromSpikes +=
// Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
// ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) /
// (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
// totalSearchTimeWindowCondDestPast));
} else {
if (eventType != NEXT_DEST) {
// Pre-condition: next event is a source spike so we'll continue to check next event
if (debug && (eventIndex < 10000)) {
System.out.println();
}
}
continue;
}
// Post-condition: the next event is a destination spike:
// Find the Knns nearest neighbour matches to this event,
// with the same previous spiker and the next.
// TODO Add dynamic exclusion time later
PriorityQueue<NeighbourNodeData> nnPQ =
kdTreesJoint[NEXT_DEST].findKNearestNeighbours(
Knns, eventIndexWithinType);
// Find eps_{x,y,z} as the maximum x, y and z norms amongst this set:
double radius_sourcePast = 0.0;
double radius_destPast = 0.0;
double radius_destNext = 0.0;
int radius_destNext_sampleIndex = -1;
for (int j = 0; j < Knns; j++) {
// Take the furthest remaining of the nearest neighbours from the PQ:
NeighbourNodeData nnData = nnPQ.poll();
if (nnData.norms[0] > radius_sourcePast) {
radius_sourcePast = nnData.norms[0];
}
if (nnData.norms[1] > radius_destPast) {
radius_destPast = nnData.norms[1];
}
if (nnData.norms[2] > radius_destNext) {
radius_destNext = nnData.norms[2];
radius_destNext_sampleIndex = nnData.sampleIndex;
}
}
// Postcondition: radius_* variables hold the search radius for each sourcePast, destPast and destNext matches.
if (debug && (eventIndex < 10000)) {
System.out.print(", timings: src: ");
MatrixUtils.printArray(System.out, thisEventTimings[0], 5);
System.out.print(", dest: ");
MatrixUtils.printArray(System.out, thisEventTimings[1], 5);
System.out.print(", time to next: ");
MatrixUtils.printArray(System.out, thisEventTimings[2], 5);
System.out.printf("index=%d: K=%d NNs at next_range %.5f (point %d)", eventIndexWithinType, Knns, radius_destNext, radius_destNext_sampleIndex);
}
indexForNextIsDest++;
// Now find the matching samples in each sub-space;
// first match dest history and source history, with a next spike in dest:
int numMatches = kdTreesSourceDestHistories[NEXT_DEST].
findPointsWithinRs(eventIndexWithinType,
new double[] {radius_sourcePast, radius_destPast}, 0,
true, isWithinR, indicesWithinR);
// Set the search point itself to be a neighbour - this is necessary to include the waiting time
// for it in our count:
indicesWithinR[numMatches] = eventIndexWithinType;
indicesWithinR[numMatches+1] = -1;
isWithinR[eventIndexWithinType] = true;
// And check which of these samples had spike time in dest in the window or after ours:
int countOfDestNextAndGreater = 0;
int countOfDestNextInWindow = 0; // Would be Knns except for one on the lower boundary (if there is one)
double timeInWindowWithMatchingJointHistories = 0;
for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
// Pull out this matching event from the full joint space
double[][] matchedHistoryEventTimings = eventTimings[NEXT_DEST].elementAt(indicesWithinR[nIndex]);
// Use simple labels for relative times from the previous target spike
double matchingHistoryTimeToNextSpike = matchedHistoryEventTimings[2][0];
// (time to previous source spike from the previous target spike can be negative or positive.
// matchedHistoryEventTimings[0][0] < 0 implies previous source spike occuring later
// than previous target spike; we want opposite sign here).
double matchingHistoryTimeToPrevSourceSpike = -matchedHistoryEventTimings[0][0];
// We need to check how long we spent in the window matching the next spike
// with a matching history.
// First, we make sure that the next (target) spike was in the window or after it,
// and that the previous source spike did not occur after the window
// (this is possible since our neighbour match hasn't checked for the next spike time)
if ((matchingHistoryTimeToNextSpike >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) &&
(matchingHistoryTimeToPrevSourceSpike <= timeToNextSpikeSincePreviousDestSpike + radius_destNext)) {
// Real start of window cannot be before previous destination spike:
double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0);
// Also, real start cannot be before previous source spike:
// (previous source spike occurs at -matchedHistoryEventTimings[0][0] relative
// to previous destination spike)
realStartOfWindow = Math.max(realStartOfWindow, matchingHistoryTimeToPrevSourceSpike);
// Real end of window happened either when the spike occurred (which changes the history) or at
// the end of the window:
double realEndOfWindow = Math.min(matchingHistoryTimeToNextSpike,
timeToNextSpikeSincePreviousDestSpike + radius_destNext);
// Add in how much time with a matching history we spent in this window:
timeInWindowWithMatchingJointHistories +=
realEndOfWindow - realStartOfWindow;
countOfDestNextAndGreater++;
// Count spikes occurring here in the window (and check below)
if ((matchingHistoryTimeToNextSpike >= realStartOfWindow) &&
(matchingHistoryTimeToNextSpike <= realEndOfWindow)){
countOfDestNextInWindow++;
}
}
// Reset the isWithinR array while we're here
isWithinR[indicesWithinR[nIndex]] = false;
}
// TODO Debug check:
// if (countOfDestNextInWindow != Knns + 1) {
// throw new Exception("Unexpected value for countOfDestNextInWindow: " + countOfDestNextInWindow);
//}
// And count how many samples with the matching history actually had a
// *source* spike next, during or after our window.
// Note that we now must go to the other kdTree for next source spike
kdTreesSourceDestHistories[NEXT_SOURCE].
findPointsWithinRs(
new double[] {radius_sourcePast, radius_destPast}, thisEventTimings,
true, isWithinR, indicesWithinR);
// And check which of these samples had spike time in source at or after ours:
int countOfSourceNextAndGreater = 0;
for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
// Pull out this matching event from the full joint space
double[][] matchedHistoryEventTimings = eventTimings[NEXT_SOURCE].elementAt(indicesWithinR[nIndex]);
// Use simple labels for relative times from the previous target spike
double matchingHistoryTimeToNextSpike = matchedHistoryEventTimings[2][0];
// (time to previous source spike from the previous target spike can be negative or positive.
// matchedHistoryEventTimings[0][0] < 0 implies previous source spike occuring later
// than previous target spike; we want opposite sign here).
double matchingHistoryTimeToPrevSourceSpike = -matchedHistoryEventTimings[0][0];
// We need to check how long we spent in the window matching the next spike
// with a matching history.
// First, we make sure that the next (source) spike was in the window or after it,
// and that the previous source spike did not occur after the window
// (this is possible since our neighbour match hasn't checked for the next spike time)
if ((matchingHistoryTimeToNextSpike >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) &&
(matchingHistoryTimeToPrevSourceSpike <= timeToNextSpikeSincePreviousDestSpike + radius_destNext)) {
// Real start of window cannot be before previous destination spike:
double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0);
// Also, real start cannot be before previous source spike:
// (previous source spike occurs at matchingHistoryTimeToPrevSourceSpike relative
// to previous destination spike)
realStartOfWindow = Math.max(realStartOfWindow, matchingHistoryTimeToPrevSourceSpike);
// Real end of window happened either when the spike occurred or at
// the end of the window:
double realEndOfWindow = Math.min(matchingHistoryTimeToNextSpike,
timeToNextSpikeSincePreviousDestSpike + radius_destNext);
// Add in how much time with a matching history we spent in this window:
timeInWindowWithMatchingJointHistories +=
realEndOfWindow - realStartOfWindow;
countOfSourceNextAndGreater++;
}
// Reset the isWithinR array while we're here
isWithinR[indicesWithinR[nIndex]] = false;
}
// We need to count spike rate for all the times we're actually within the matching window for the next spike
// This is kind of inspired by the Greg Ver Steeg et al. approach in
// http://www.jmlr.org/proceedings/papers/v38/gao15.pdf
// which is thinking about where the space is actually being explored.
// This is where the window correction code was placed, which we're now replacing
// with computing the length of actual time we spend in the next window.
if (debug && (eventIndex < 10000)) {
System.out.printf(" of %d + %d + %d points with matching S-D history",
Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater);
}
// Now find the matching samples in the dest history and
// with a next spike timing.
int countOfDestNextAndGreaterMatchedDest = 0;
int countOfDestNextMatched = 0;
double timeInWindowWithMatchingDestHistory = 0;
// Real start of window cannot be before previous destination spike:
double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0);
if (k > 1) {
// Search only the space of dest past -- no point
// searching dest past and next, since we need to run through
// all matches of dest past to count those with greater next spike
// times we might as well count those with matching spike times
// while we're at it.
numMatches = kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radius_destPast,
true, isWithinR, indicesWithinR);
// Set the search point itself to be a neighbour - this is necessary to include the waiting time
// for it in our count:
indicesWithinR[numMatches] = indexForNextIsDest;
indicesWithinR[numMatches+1] = -1;
isWithinR[indexForNextIsDest] = true;
// And check which of these samples had next spike time after our window starts:
for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) {
// Pull out this matching event from the dest history space
double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]);
if (matchedHistoryEventTimings[1][0] >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) {
// This sample had a matched history and next spike was a
// spike with an interval longer than or considered equal to the current sample.
// (The "equal to" is why we look for matches within kthNnData.distance here as well.)
// Real end of window happened either when the spike occurred or at
// the end of the window:
double realEndOfWindow = Math.min(matchedHistoryEventTimings[1][0],
timeToNextSpikeSincePreviousDestSpike + radius_destNext);
// Add in how much time with a matching history we spent in this window:
timeInWindowWithMatchingDestHistory +=
realEndOfWindow - realStartOfWindow;
countOfDestNextAndGreaterMatchedDest++;
if (matchedHistoryEventTimings[1][0] <= timeToNextSpikeSincePreviousDestSpike + radius_destNext) {
// Then we also have a match on the next spike itself
countOfDestNextMatched++;
}
}
// Reset the isWithinR array while we're here
isWithinR[indicesWithinR[nIndex]] = false;
}
} else {
// For k = 1, we only care about time since last spike.
// So count how many of the next spikes were within the window first:
// -- we don't take any past dest spike ISIs into account, so we just need to look at the proportion of next
// spike times that match.
countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsWithinRs(indexForNextIsDest,
radius_destNext, Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true);
// And also check how long each of these spent in the window:
timeInWindowWithMatchingDestHistory =
nnSearcherDestTimeToNextSpike.sumDistanceAboveThresholdForPointsWithinRs(indexForNextIsDest,
radius_destNext, Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true);
// Now check for points matching or larger, we just need to make this call with the
// revised lower radius, because we don't check the upper one.
countOfDestNextAndGreaterMatchedDest =
nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest,
Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true);
// And we need to add time in for all of the (countOfDestNextAndGreaterMatchedDest - countOfDestNextMatched)
// points which didn't spike in the window
// TODO - check whether this is really doing what it intends to?
timeInWindowWithMatchingDestHistory += (double) (countOfDestNextAndGreaterMatchedDest - countOfDestNextMatched) *
(timeToNextSpikeSincePreviousDestSpike + radius_destNext - realStartOfWindow);
// Add in the wait time contribution for the search point itself (and need to include it in counts):
timeInWindowWithMatchingDestHistory += timeToNextSpikeSincePreviousDestSpike - realStartOfWindow;
countOfDestNextMatched++;
countOfDestNextAndGreaterMatchedDest++;
if (timeInWindowWithMatchingDestHistory <= 0) {
// David thought he saw this occur, adding debug exception so we can catch it if so"
throw new Exception("timeInWindowWithMatchingDestHistory is not > 0");
}
}
if (debug && (eventIndex < 10000)) {
System.out.printf(", and %d of %d points for D history only; ",
countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest);
}
//============================
// This code section takes the counts of spikes and total intervals, and
// estimates the log rates.
// Inferred rates raw:
double rawRateGivenSourceAndDest = (double) countOfDestNextInWindow / timeInWindowWithMatchingJointHistories;
double rawRateGivenDest = (double) countOfDestNextMatched / timeInWindowWithMatchingDestHistory;
// Attempt at bias correction:
// Using digamma of neighbour_count - 1, since the neighbour count now includes our search point and it's really k waiting times
double logRateGivenSourceAndDestCorrected = MathsUtils.digamma(Knns) // - (1.0 / (double) Knns) // I don't think this correction is required
- Math.log(timeInWindowWithMatchingJointHistories);
double logRateGivenDestCorrected = MathsUtils.digamma(countOfDestNextMatched-1) // - (1.0 / (double) countOfDestNextMatched) // I don't think this correction is required
- Math.log(timeInWindowWithMatchingDestHistory);
//============================
if (debug && (eventIndex < 10000)) {
System.out.printf(" te ~~ %.4f - %.4f = %.4f, log (%.4f)/(%.4f) = %.4f (counts %d/%d = %.4f, %d/%d = %.4f -> te %.4f)\n",
logRateGivenSourceAndDestCorrected,
logRateGivenDestCorrected,
logRateGivenSourceAndDestCorrected - logRateGivenDestCorrected,
rawRateGivenSourceAndDest,
rawRateGivenDest,
Math.log(rawRateGivenSourceAndDest / rawRateGivenDest),
Knns,
Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater,
(double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater),
countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest,
(double) countOfDestNextMatched / (double) countOfDestNextAndGreaterMatchedDest,
Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest))));
}
//======================
// Add the contribution in:
// a.If we were using digamma logs:
contributionFromSpikes += logRateGivenSourceAndDestCorrected - logRateGivenDestCorrected;
// contributionFromSpikes += Math.log(rawRateGivenSourceAndDest / rawRateGivenDest);
// b. If we are only using window corrections but actual ratios:
//contributionFromSpikes +=
// Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) /
// ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) /
// (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) /
// totalSearchTimeWindowCondDestPast));
}
System.out.println("All done!");
contributionFromSpikes /= totalTimeLength;

0
java/source/infodynamics/utils/KdTree.java Executable file → Normal file
View File

View File

View File