diff --git a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java b/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java index 690c397..0b49943 100644 --- a/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java +++ b/java/source/infodynamics/measures/spiking/integration/TransferEntropyCalculatorSpikingIntegration.java @@ -3,6 +3,7 @@ package infodynamics.measures.spiking.integration; import java.util.Arrays; import java.util.Iterator; import java.util.PriorityQueue; +import java.util.Random; import java.util.Vector; import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking; @@ -120,7 +121,20 @@ public class TransferEntropyCalculatorSpikingIntegration implements * source or destination spike) */ public static final String TRIM_TO_POS_PROP_NAME = "TRIM_RANGE_TO_POS_TIMES"; - + /** + * Property name for an amount of random Gaussian noise to be + * added to the data (default is 1e-8, matching the MILCA toolkit). + */ + public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD"; + /** + * Whether to add an amount of random noise to the incoming data + */ + protected boolean addNoise = true; + /** + * Amount of random Gaussian noise to add to the incoming data + */ + protected double noiseLevel = (double) 1e-8; + protected boolean trimToPosNextSpikeTimes = false; /** @@ -177,6 +191,15 @@ public class TransferEntropyCalculatorSpikingIntegration implements Knns = Integer.parseInt(propertyValue); } else if (propertyName.equalsIgnoreCase(TRIM_TO_POS_PROP_NAME)) { trimToPosNextSpikeTimes = Boolean.parseBoolean(propertyValue); + } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) { + if (propertyValue.equals("0") || + propertyValue.equalsIgnoreCase("false")) { + addNoise = false; + noiseLevel = 0; + } else { + addNoise = true; + noiseLevel = Double.parseDouble(propertyValue); + } } else { // No property was set on this class propertySet = false; @@ -200,6 +223,8 @@ public class TransferEntropyCalculatorSpikingIntegration implements return Integer.toString(Knns); } else if (propertyName.equalsIgnoreCase(TRIM_TO_POS_PROP_NAME)) { return Boolean.toString(trimToPosNextSpikeTimes); + } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) { + return Double.toString(noiseLevel); } else { // No property matches for this class return null; @@ -379,6 +404,10 @@ public class TransferEntropyCalculatorSpikingIntegration implements double[] spikeTimesForNextSpiker; double timeOfPrevDestSpike = destSpikeTimes[dest_index]; int numEvents = 0; + Random random = null; + if (addNoise) { + random = new Random(); + } while(true) { // 0. Check whether we're finished if ((source_index == sourceSpikeTimes.length - 1) && @@ -401,6 +430,9 @@ public class TransferEntropyCalculatorSpikingIntegration implements spikeTimesForNextSpiker = nextIsDest ? destSpikeTimes : sourceSpikeTimes; int indexForNextSpiker = nextIsDest ? dest_index : source_index; timeToNextSpike = spikeTimesForNextSpiker[indexForNextSpiker+1] - timeOfPrevDestSpike; + if (addNoise) { + timeToNextSpike += random.nextGaussian()*noiseLevel; + } // 2. Embed the past spikes double[] sourcePast = new double[l]; double[] destPast = new double[k - 1]; @@ -409,13 +441,22 @@ public class TransferEntropyCalculatorSpikingIntegration implements }*/ sourcePast[0] = timeOfPrevDestSpike - sourceSpikeTimes[source_index]; + if (addNoise) { + sourcePast[0] += random.nextGaussian()*noiseLevel; + } for (int i = 1; i < k; i++) { destPast[i - 1] = destSpikeTimes[dest_index - i + 1] - destSpikeTimes[dest_index - i]; + if (addNoise) { + destPast[i - 1] += random.nextGaussian()*noiseLevel; + } } for (int i = 1; i < l; i++) { sourcePast[i] = sourceSpikeTimes[source_index - i + 1] - sourceSpikeTimes[source_index - i]; + if (addNoise) { + sourcePast[i] += random.nextGaussian()*noiseLevel; + } } // 3. Store these embedded observations double[][] observations = new double[][]{sourcePast, destPast, @@ -478,10 +519,6 @@ public class TransferEntropyCalculatorSpikingIntegration implements double contributionFromSpikes = 0; double totalTimeLength = 0; - double digammaK = MathsUtils.digamma(Knns); - double twoInverseKTerm = 2.0 / (double) Knns; - double inverseKTerm = 1.0 / (double) Knns; - // Create temporary storage for arrays used in the neighbour counting: boolean[] isWithinR = new boolean[numberOfEvents]; // dummy, we don't really use this int[] indicesWithinR = new int[numberOfEvents]; @@ -509,314 +546,313 @@ public class TransferEntropyCalculatorSpikingIntegration implements } // Select only events where the destination spiked next: - if (eventType == NEXT_DEST) { - // Find the Knns nearest neighbour matches to this event, - // with the same previous spiker and the next. - // TODO Add dynamic exclusion time later - PriorityQueue nnPQ = - kdTreesJoint[eventType].findKNearestNeighbours( - Knns, eventIndexWithinType); - - // Find eps_{x,y,z} as the maximum x, y and z norms amongst this set: - double radius_sourcePast = 0.0; - double radius_destPast = 0.0; - double radius_destNext = 0.0; - int radius_destNext_sampleIndex = -1; - for (int j = 0; j < Knns; j++) { - // Take the furthest remaining of the nearest neighbours from the PQ: - NeighbourNodeData nnData = nnPQ.poll(); - if (nnData.norms[0] > radius_sourcePast) { - radius_sourcePast = nnData.norms[0]; - } - if (nnData.norms[1] > radius_destPast) { - radius_destPast = nnData.norms[1]; - } - if (nnData.norms[2] > radius_destNext) { - radius_destNext = nnData.norms[2]; - radius_destNext_sampleIndex = nnData.sampleIndex; - } - } - - if (debug && (eventIndex < 10000)) { - System.out.print(", timings: src: "); - MatrixUtils.printArray(System.out, thisEventTimings[0], 5); - System.out.print(", dest: "); - MatrixUtils.printArray(System.out, thisEventTimings[1], 5); - System.out.print(", time to next: "); - MatrixUtils.printArray(System.out, thisEventTimings[2], 5); - System.out.printf("index=%d: K=%d NNs at next_range %.5f (point %d)", eventIndexWithinType, Knns, radius_destNext, radius_destNext_sampleIndex); - } - - indexForNextIsDest++; - - // Now find the matching samples in each sub-space; - // first match dest history and source history, with a next spike in dest: - kdTreesSourceDestHistories[NEXT_DEST]. - findPointsWithinRs(eventIndexWithinType, - new double[] {radius_sourcePast, radius_destPast}, 0, - true, isWithinR, indicesWithinR); - // And check which of these samples had spike time in dest after ours: - int countOfDestNextAndGreater = 0; - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[NEXT_DEST].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] > thisEventTimings[2][0] + radius_destNext) { - // This sample had a matched history and next spike was a destination - // spike with a longer interval than the current sample - countOfDestNextAndGreater++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - // And count how many samples with the matching history actually had a - // *source* spike next, after ours. - // Note that we now must go to the other kdTree for next source spike - kdTreesSourceDestHistories[NEXT_SOURCE]. - findPointsWithinRs( - new double[] {radius_sourcePast, radius_destPast}, thisEventTimings, - true, isWithinR, indicesWithinR); - // And check which of these samples had spike time in source at or after ours: - int countOfSourceNextAndGreater = 0; - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the full joint space - double[][] matchedHistoryEventTimings = eventTimings[NEXT_SOURCE].elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[2][0] >= thisEventTimings[2][0] - radius_destNext) { - // This sample had a matched history and next spike was a source - // spike with an interval longer than or considered equal to the current sample. - // (The "equal to" is why we look for matches within radius_destNext here as well.) - countOfSourceNextAndGreater++; - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - - // We need to correct radius_destNext to have a different value below, chopping it - // where it pushes into negative times (i.e. *before* the previous spike). - // TODO I think this is justified by the approach of Greg Ver Steeg et al. in - // http://www.jmlr.org/proceedings/papers/v38/gao15.pdf -- it's not the same - // as that approach, but a principled version where we know the space cannot - // be explored (so we don't need to only apply the correction where it is large). - double searchAreaRatio = 0; // Ratio of actual search area for full joint space compared to if the timings were independent - double prevSource_timing_upper_original = 0, prevSource_timing_lower = 0; - double destNext_timing_lower_original = timeToNextSpikeSincePreviousDestSpike - radius_destNext; - double destNext_timing_upper = timeToNextSpikeSincePreviousDestSpike + radius_destNext; - double radius_destNext_lower = radius_destNext; - if (trimToPosNextSpikeTimes) { - if (destNext_timing_lower_original < 0) { - // This is the range that's actually being searched in the lower - // dimensional space. - // (We don't need to adjust margins in lower dimensional space for - // this because it simply won't get any points below this!) - destNext_timing_lower_original = 0; - } - // timePreviousSourceSpikeBeforePreviousDestSpike is negative if source is after dest. - prevSource_timing_upper_original = -timePreviousSourceSpikeBeforePreviousDestSpike + radius_sourcePast; - prevSource_timing_lower = -timePreviousSourceSpikeBeforePreviousDestSpike - radius_sourcePast; - double destNext_timing_lower = Math.max(destNext_timing_lower_original, - prevSource_timing_lower); - double prevSource_timing_upper = Math.min(destNext_timing_upper, prevSource_timing_upper_original); - // Now look at various cases for the corrective ratio: - // TODO I think we still need to correct destNext_timing_lower_original to radius_destNext_lower here! - double denominator = (destNext_timing_upper - destNext_timing_lower_original) * - (prevSource_timing_upper_original - prevSource_timing_lower); - if (destNext_timing_upper - destNext_timing_lower > - prevSource_timing_upper - prevSource_timing_lower) { - // Case 1: - if (prevSource_timing_upper < destNext_timing_lower) { - // Windows do not overlap, so there will be no correction here - searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) * - (prevSource_timing_upper - prevSource_timing_lower); - } else if (destNext_timing_upper < prevSource_timing_lower) { - // Should never happen, because next spike can't be before previous source spike - searchAreaRatio = 0; - throw new RuntimeException("Encountered a next target spike *before* the previous source spike"); - } else if (destNext_timing_lower < prevSource_timing_lower) { - // Dest next lower bound is below source previous lower bound -- definitely - // can't get any points in the part below prevSource_timing_lower - // for dest next. - searchAreaRatio = (prevSource_timing_upper - prevSource_timing_lower) * - (destNext_timing_upper - prevSource_timing_upper) + - 0.5 * (prevSource_timing_upper - prevSource_timing_lower) * - (prevSource_timing_upper - prevSource_timing_lower); - } else { - // destNext_timing_lower >= prevSource_timing_lower - // Dest next lower bound is within the previous source window - // but above the source lower bound. - searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) * - (prevSource_timing_upper - prevSource_timing_lower) - - 0.5 * (prevSource_timing_upper - destNext_timing_lower) * - (prevSource_timing_upper - destNext_timing_lower); - } - } else { - // Case 2: - if (prevSource_timing_upper < destNext_timing_lower) { - // Windows do not overlap, so there will be no correction here - searchAreaRatio = (destNext_timing_upper - destNext_timing_lower) * - (prevSource_timing_upper - prevSource_timing_lower); - } else if (destNext_timing_upper < prevSource_timing_lower) { - // Should never happen, because next spike can't be before previous source spike - searchAreaRatio = 0; - throw new RuntimeException("Encountered a next target spike *before* the previous source spike"); - } else if (destNext_timing_lower < prevSource_timing_lower) { - searchAreaRatio = 0.5 * (destNext_timing_upper - prevSource_timing_lower) * - (destNext_timing_upper - prevSource_timing_lower); - } else { - // destNext_timing_lower >= prevSource_timing_lower - searchAreaRatio = (destNext_timing_lower - prevSource_timing_lower) * - (destNext_timing_upper - destNext_timing_lower) + - 0.5 * (destNext_timing_upper - destNext_timing_lower) * - (destNext_timing_upper - destNext_timing_lower); - } - } - if (denominator <= 0) { - // We'll get NaN if we divide by it. No correction is necessary in this - // case, the margins are too thin to cause a discrepancy anyway. - searchAreaRatio = 1; - } else { - searchAreaRatio /= denominator; - } - // And finally fix up the lower radius on destNext if required: - radius_destNext_lower = timeToNextSpikeSincePreviousDestSpike - destNext_timing_lower; - } - - if (debug && (eventIndex < 10000)) { - System.out.printf(" of %d + %d + %d points with matching S-D history", - Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater); - } - - // Now find the matching samples in the dest history and - // with a next spike timing. - int countOfDestNextAndGreaterMatchedDest = 0; - int countOfDestNextMatched = 0; - if (k > 1) { - // Search only the space of dest past -- no point - // searching dest past and next, since we need to run through - // all matches of dest past to count those with greater next spike - // times we might as well count those with matching spike times - // while we're at it. - kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radius_destPast, - true, isWithinR, indicesWithinR); - // And check which of these samples had next spike time after ours: - for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { - // Pull out this matching event from the dest history space - double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]); - if (matchedHistoryEventTimings[1][0] >= timeToNextSpikeSincePreviousDestSpike - radius_destNext_lower) { - // This sample had a matched history and next spike was a - // spike with an interval longer than or considered equal to the current sample. - // (The "equal to" is why we look for matches within kthNnData.distance here as well.) - countOfDestNextAndGreaterMatchedDest++; - if (matchedHistoryEventTimings[1][0] <= timeToNextSpikeSincePreviousDestSpike + radius_destNext) { - // Then we also have a match on the next spike itself - countOfDestNextMatched++; - } - } - // Reset the isWithinR array while we're here - isWithinR[indicesWithinR[nIndex]] = false; - } - } else { - // We don't take any past dest spike ISIs into account, so we just need to look at the proportion of next - // spike times that match. - countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsWithinRs(indexForNextIsDest, - radius_destNext, radius_destNext_lower, true); - // To check for points matching or larger, we just need to make this call with the - // revised lower radius, because we don't check the upper one. - countOfDestNextAndGreaterMatchedDest = - nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest, radius_destNext_lower, true); - } - double totalSearchTimeWindowCondDestPast = radius_destNext_lower + radius_destNext; - - if (debug && (eventIndex < 10000)) { - System.out.printf(", and %d of %d points for D history only; ", - countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest); - } - - //============================ - // This code section is if we wish to compute using digamma logs: - //============================ - // With these neighbours counted, we're ready to compute the probability of the spike given the past - // of source and dest. - double logPGivenSourceAndDest; - double logPGivenDest; - if (k > 1) { - // We're handling three variables: - logPGivenSourceAndDest = digammaK - twoInverseKTerm - - MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater) - + 1.0 / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater); - logPGivenDest = MathsUtils.digamma(countOfDestNextMatched) - - 1.0 / ((double) countOfDestNextMatched) - - MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest); - } else { - // We're really only handling two variables like an MI: - logPGivenSourceAndDest = digammaK - inverseKTerm - - MathsUtils.digamma(Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater); - logPGivenDest = MathsUtils.digamma(countOfDestNextMatched) - - MathsUtils.digamma(countOfDestNextAndGreaterMatchedDest); - } - if (trimToPosNextSpikeTimes) { - logPGivenSourceAndDest -= Math.log(searchAreaRatio); - } - //============================ - - if (debug && (eventIndex < 10000)) { - System.out.printf(" te ~~ log (%d/%d)/(%d/%d) = %.4f wc-> %.5f bc-> %.4f (inferred rates %.4f vs %.4f, " + - "win-cor %.5f vs %.5f, bias-corrected %.5f vs %.5f)\n", Knns, - Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater, - countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest, - Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) / - ((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest))), - // TE from Window corrected rates: - Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) / - ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) / - (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) / - totalSearchTimeWindowCondDestPast)), - // TE from Bias corrected rates: - logPGivenSourceAndDest - logPGivenDest, - // Inferred rates raw: - (double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater) / (2.0*radius_destNext), - (double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest) / (2.0*radius_destNext), - // Inferred rates with window correction: - ((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) / - ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio), - ((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) / - totalSearchTimeWindowCondDestPast, - // Transform the log likelihoods into bias corrected rates: - Math.exp(logPGivenSourceAndDest) / (2.0*radius_destNext), Math.exp(logPGivenDest) / (2.0*radius_destNext)); - if (trimToPosNextSpikeTimes) { - System.out.printf("Search area ratio: %.5f, correction %.5f, t_y_upper %.5f, t_y_lower %.5f\n", - searchAreaRatio, -Math.log(searchAreaRatio), prevSource_timing_upper_original, prevSource_timing_lower); - } - } - // Unexplained case: - if (countOfDestNextMatched < Knns) { - // TODO Should not happen, print something! - System.out.printf("SHOULD NOT HAPPEN!\n"); - // So debug this: - nnPQ = kdTreesJoint[eventType].findKNearestNeighbours( - Knns, eventIndexWithinType); - for (int j = 0; j < Knns; j++) { - // Take the furthest remaining of the nearest neighbours from the PQ: - NeighbourNodeData nnData = nnPQ.poll(); - System.out.printf("NN data %d norms: source %.5f, destPast %.5f, destNext %.5f, nextIndex %d\n", - nnData.norms[0], nnData.norms[1], nnData.norms[2], nnData.sampleIndex); - } - - } - - //====================== - // Add the contribution in: - // a.If we were using digamma logs: - contributionFromSpikes += logPGivenSourceAndDest - logPGivenDest; - // b. If we are only using window corrections but actual ratios: - //contributionFromSpikes += - // Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) / - // ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) / - // (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) / - // totalSearchTimeWindowCondDestPast)); - } else { + if (eventType != NEXT_DEST) { + // Pre-condition: next event is a source spike so we'll continue to check next event if (debug && (eventIndex < 10000)) { System.out.println(); } - } + continue; + } + // Post-condition: the next event is a destination spike: + + // Find the Knns nearest neighbour matches to this event, + // with the same previous spiker and the next. + // TODO Add dynamic exclusion time later + PriorityQueue nnPQ = + kdTreesJoint[NEXT_DEST].findKNearestNeighbours( + Knns, eventIndexWithinType); + + // Find eps_{x,y,z} as the maximum x, y and z norms amongst this set: + double radius_sourcePast = 0.0; + double radius_destPast = 0.0; + double radius_destNext = 0.0; + int radius_destNext_sampleIndex = -1; + for (int j = 0; j < Knns; j++) { + // Take the furthest remaining of the nearest neighbours from the PQ: + NeighbourNodeData nnData = nnPQ.poll(); + if (nnData.norms[0] > radius_sourcePast) { + radius_sourcePast = nnData.norms[0]; + } + if (nnData.norms[1] > radius_destPast) { + radius_destPast = nnData.norms[1]; + } + if (nnData.norms[2] > radius_destNext) { + radius_destNext = nnData.norms[2]; + radius_destNext_sampleIndex = nnData.sampleIndex; + } + } + // Postcondition: radius_* variables hold the search radius for each sourcePast, destPast and destNext matches. + + if (debug && (eventIndex < 10000)) { + System.out.print(", timings: src: "); + MatrixUtils.printArray(System.out, thisEventTimings[0], 5); + System.out.print(", dest: "); + MatrixUtils.printArray(System.out, thisEventTimings[1], 5); + System.out.print(", time to next: "); + MatrixUtils.printArray(System.out, thisEventTimings[2], 5); + System.out.printf("index=%d: K=%d NNs at next_range %.5f (point %d)", eventIndexWithinType, Knns, radius_destNext, radius_destNext_sampleIndex); + } + + indexForNextIsDest++; + + // Now find the matching samples in each sub-space; + // first match dest history and source history, with a next spike in dest: + int numMatches = kdTreesSourceDestHistories[NEXT_DEST]. + findPointsWithinRs(eventIndexWithinType, + new double[] {radius_sourcePast, radius_destPast}, 0, + true, isWithinR, indicesWithinR); + // Set the search point itself to be a neighbour - this is necessary to include the waiting time + // for it in our count: + indicesWithinR[numMatches] = eventIndexWithinType; + indicesWithinR[numMatches+1] = -1; + isWithinR[eventIndexWithinType] = true; + // And check which of these samples had spike time in dest in the window or after ours: + int countOfDestNextAndGreater = 0; + int countOfDestNextInWindow = 0; // Would be Knns except for one on the lower boundary (if there is one) + double timeInWindowWithMatchingJointHistories = 0; + for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { + // Pull out this matching event from the full joint space + double[][] matchedHistoryEventTimings = eventTimings[NEXT_DEST].elementAt(indicesWithinR[nIndex]); + // Use simple labels for relative times from the previous target spike + double matchingHistoryTimeToNextSpike = matchedHistoryEventTimings[2][0]; + // (time to previous source spike from the previous target spike can be negative or positive. + // matchedHistoryEventTimings[0][0] < 0 implies previous source spike occuring later + // than previous target spike; we want opposite sign here). + double matchingHistoryTimeToPrevSourceSpike = -matchedHistoryEventTimings[0][0]; + + // We need to check how long we spent in the window matching the next spike + // with a matching history. + // First, we make sure that the next (target) spike was in the window or after it, + // and that the previous source spike did not occur after the window + // (this is possible since our neighbour match hasn't checked for the next spike time) + if ((matchingHistoryTimeToNextSpike >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) && + (matchingHistoryTimeToPrevSourceSpike <= timeToNextSpikeSincePreviousDestSpike + radius_destNext)) { + + // Real start of window cannot be before previous destination spike: + double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0); + // Also, real start cannot be before previous source spike: + // (previous source spike occurs at -matchedHistoryEventTimings[0][0] relative + // to previous destination spike) + realStartOfWindow = Math.max(realStartOfWindow, matchingHistoryTimeToPrevSourceSpike); + + // Real end of window happened either when the spike occurred (which changes the history) or at + // the end of the window: + double realEndOfWindow = Math.min(matchingHistoryTimeToNextSpike, + timeToNextSpikeSincePreviousDestSpike + radius_destNext); + + // Add in how much time with a matching history we spent in this window: + timeInWindowWithMatchingJointHistories += + realEndOfWindow - realStartOfWindow; + + countOfDestNextAndGreater++; + + // Count spikes occurring here in the window (and check below) + if ((matchingHistoryTimeToNextSpike >= realStartOfWindow) && + (matchingHistoryTimeToNextSpike <= realEndOfWindow)){ + countOfDestNextInWindow++; + } + + } + // Reset the isWithinR array while we're here + isWithinR[indicesWithinR[nIndex]] = false; + } + // TODO Debug check: + // if (countOfDestNextInWindow != Knns + 1) { + // throw new Exception("Unexpected value for countOfDestNextInWindow: " + countOfDestNextInWindow); + //} + + // And count how many samples with the matching history actually had a + // *source* spike next, during or after our window. + // Note that we now must go to the other kdTree for next source spike + kdTreesSourceDestHistories[NEXT_SOURCE]. + findPointsWithinRs( + new double[] {radius_sourcePast, radius_destPast}, thisEventTimings, + true, isWithinR, indicesWithinR); + // And check which of these samples had spike time in source at or after ours: + int countOfSourceNextAndGreater = 0; + for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { + // Pull out this matching event from the full joint space + double[][] matchedHistoryEventTimings = eventTimings[NEXT_SOURCE].elementAt(indicesWithinR[nIndex]); + // Use simple labels for relative times from the previous target spike + double matchingHistoryTimeToNextSpike = matchedHistoryEventTimings[2][0]; + // (time to previous source spike from the previous target spike can be negative or positive. + // matchedHistoryEventTimings[0][0] < 0 implies previous source spike occuring later + // than previous target spike; we want opposite sign here). + double matchingHistoryTimeToPrevSourceSpike = -matchedHistoryEventTimings[0][0]; + + // We need to check how long we spent in the window matching the next spike + // with a matching history. + // First, we make sure that the next (source) spike was in the window or after it, + // and that the previous source spike did not occur after the window + // (this is possible since our neighbour match hasn't checked for the next spike time) + if ((matchingHistoryTimeToNextSpike >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) && + (matchingHistoryTimeToPrevSourceSpike <= timeToNextSpikeSincePreviousDestSpike + radius_destNext)) { + + // Real start of window cannot be before previous destination spike: + double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0); + // Also, real start cannot be before previous source spike: + // (previous source spike occurs at matchingHistoryTimeToPrevSourceSpike relative + // to previous destination spike) + realStartOfWindow = Math.max(realStartOfWindow, matchingHistoryTimeToPrevSourceSpike); + + // Real end of window happened either when the spike occurred or at + // the end of the window: + double realEndOfWindow = Math.min(matchingHistoryTimeToNextSpike, + timeToNextSpikeSincePreviousDestSpike + radius_destNext); + + // Add in how much time with a matching history we spent in this window: + timeInWindowWithMatchingJointHistories += + realEndOfWindow - realStartOfWindow; + + countOfSourceNextAndGreater++; + } + // Reset the isWithinR array while we're here + isWithinR[indicesWithinR[nIndex]] = false; + } + + // We need to count spike rate for all the times we're actually within the matching window for the next spike + // This is kind of inspired by the Greg Ver Steeg et al. approach in + // http://www.jmlr.org/proceedings/papers/v38/gao15.pdf + // which is thinking about where the space is actually being explored. + // This is where the window correction code was placed, which we're now replacing + // with computing the length of actual time we spend in the next window. + + + if (debug && (eventIndex < 10000)) { + System.out.printf(" of %d + %d + %d points with matching S-D history", + Knns, countOfSourceNextAndGreater, countOfDestNextAndGreater); + } + + // Now find the matching samples in the dest history and + // with a next spike timing. + int countOfDestNextAndGreaterMatchedDest = 0; + int countOfDestNextMatched = 0; + double timeInWindowWithMatchingDestHistory = 0; + // Real start of window cannot be before previous destination spike: + double realStartOfWindow = Math.max(timeToNextSpikeSincePreviousDestSpike - radius_destNext, 0); + if (k > 1) { + // Search only the space of dest past -- no point + // searching dest past and next, since we need to run through + // all matches of dest past to count those with greater next spike + // times we might as well count those with matching spike times + // while we're at it. + numMatches = kdTreeDestHistory.findPointsWithinR(indexForNextIsDest, radius_destPast, + true, isWithinR, indicesWithinR); + // Set the search point itself to be a neighbour - this is necessary to include the waiting time + // for it in our count: + indicesWithinR[numMatches] = indexForNextIsDest; + indicesWithinR[numMatches+1] = -1; + isWithinR[indexForNextIsDest] = true; + // And check which of these samples had next spike time after our window starts: + for (int nIndex = 0; indicesWithinR[nIndex] != -1; nIndex++) { + // Pull out this matching event from the dest history space + double[][] matchedHistoryEventTimings = destPastAndNextTimings.elementAt(indicesWithinR[nIndex]); + if (matchedHistoryEventTimings[1][0] >= timeToNextSpikeSincePreviousDestSpike - radius_destNext) { + // This sample had a matched history and next spike was a + // spike with an interval longer than or considered equal to the current sample. + // (The "equal to" is why we look for matches within kthNnData.distance here as well.) + + // Real end of window happened either when the spike occurred or at + // the end of the window: + double realEndOfWindow = Math.min(matchedHistoryEventTimings[1][0], + timeToNextSpikeSincePreviousDestSpike + radius_destNext); + + // Add in how much time with a matching history we spent in this window: + timeInWindowWithMatchingDestHistory += + realEndOfWindow - realStartOfWindow; + + countOfDestNextAndGreaterMatchedDest++; + if (matchedHistoryEventTimings[1][0] <= timeToNextSpikeSincePreviousDestSpike + radius_destNext) { + // Then we also have a match on the next spike itself + countOfDestNextMatched++; + } + } + // Reset the isWithinR array while we're here + isWithinR[indicesWithinR[nIndex]] = false; + } + } else { + + // For k = 1, we only care about time since last spike. + // So count how many of the next spikes were within the window first: + // -- we don't take any past dest spike ISIs into account, so we just need to look at the proportion of next + // spike times that match. + countOfDestNextMatched = nnSearcherDestTimeToNextSpike.countPointsWithinRs(indexForNextIsDest, + radius_destNext, Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true); + // And also check how long each of these spent in the window: + timeInWindowWithMatchingDestHistory = + nnSearcherDestTimeToNextSpike.sumDistanceAboveThresholdForPointsWithinRs(indexForNextIsDest, + radius_destNext, Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true); + // Now check for points matching or larger, we just need to make this call with the + // revised lower radius, because we don't check the upper one. + countOfDestNextAndGreaterMatchedDest = + nnSearcherDestTimeToNextSpike.countPointsWithinROrLarger(indexForNextIsDest, + Math.min(radius_destNext, timeToNextSpikeSincePreviousDestSpike), true); + // And we need to add time in for all of the (countOfDestNextAndGreaterMatchedDest - countOfDestNextMatched) + // points which didn't spike in the window + // TODO - check whether this is really doing what it intends to? + timeInWindowWithMatchingDestHistory += (double) (countOfDestNextAndGreaterMatchedDest - countOfDestNextMatched) * + (timeToNextSpikeSincePreviousDestSpike + radius_destNext - realStartOfWindow); + + // Add in the wait time contribution for the search point itself (and need to include it in counts): + timeInWindowWithMatchingDestHistory += timeToNextSpikeSincePreviousDestSpike - realStartOfWindow; + countOfDestNextMatched++; + countOfDestNextAndGreaterMatchedDest++; + + if (timeInWindowWithMatchingDestHistory <= 0) { + // David thought he saw this occur, adding debug exception so we can catch it if so" + throw new Exception("timeInWindowWithMatchingDestHistory is not > 0"); + } + } + + if (debug && (eventIndex < 10000)) { + System.out.printf(", and %d of %d points for D history only; ", + countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest); + } + + //============================ + // This code section takes the counts of spikes and total intervals, and + // estimates the log rates. + // Inferred rates raw: + double rawRateGivenSourceAndDest = (double) countOfDestNextInWindow / timeInWindowWithMatchingJointHistories; + double rawRateGivenDest = (double) countOfDestNextMatched / timeInWindowWithMatchingDestHistory; + // Attempt at bias correction: + // Using digamma of neighbour_count - 1, since the neighbour count now includes our search point and it's really k waiting times + double logRateGivenSourceAndDestCorrected = MathsUtils.digamma(Knns) // - (1.0 / (double) Knns) // I don't think this correction is required + - Math.log(timeInWindowWithMatchingJointHistories); + double logRateGivenDestCorrected = MathsUtils.digamma(countOfDestNextMatched-1) // - (1.0 / (double) countOfDestNextMatched) // I don't think this correction is required + - Math.log(timeInWindowWithMatchingDestHistory); + //============================ + + if (debug && (eventIndex < 10000)) { + System.out.printf(" te ~~ %.4f - %.4f = %.4f, log (%.4f)/(%.4f) = %.4f (counts %d/%d = %.4f, %d/%d = %.4f -> te %.4f)\n", + logRateGivenSourceAndDestCorrected, + logRateGivenDestCorrected, + logRateGivenSourceAndDestCorrected - logRateGivenDestCorrected, + rawRateGivenSourceAndDest, + rawRateGivenDest, + Math.log(rawRateGivenSourceAndDest / rawRateGivenDest), + Knns, + Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater, + (double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater), + countOfDestNextMatched, countOfDestNextAndGreaterMatchedDest, + (double) countOfDestNextMatched / (double) countOfDestNextAndGreaterMatchedDest, + Math.log(((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) / + ((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)))); + } + + //====================== + // Add the contribution in: + // a.If we were using digamma logs: + contributionFromSpikes += logRateGivenSourceAndDestCorrected - logRateGivenDestCorrected; + // contributionFromSpikes += Math.log(rawRateGivenSourceAndDest / rawRateGivenDest); + + + // b. If we are only using window corrections but actual ratios: + //contributionFromSpikes += + // Math.log((((double) Knns / (double) (Knns + countOfSourceNextAndGreater + countOfDestNextAndGreater)) / + // ((destNext_timing_upper - destNext_timing_lower_original)*searchAreaRatio)) / + // (((double) (countOfDestNextMatched) / (double) (countOfDestNextAndGreaterMatchedDest)) / + // totalSearchTimeWindowCondDestPast)); } System.out.println("All done!"); contributionFromSpikes /= totalTimeLength; diff --git a/java/source/infodynamics/utils/KdTree.java b/java/source/infodynamics/utils/KdTree.java old mode 100755 new mode 100644 diff --git a/java/source/infodynamics/utils/NearestNeighbourSearcher.java b/java/source/infodynamics/utils/NearestNeighbourSearcher.java old mode 100755 new mode 100644 diff --git a/java/source/infodynamics/utils/UnivariateNearestNeighbourSearcher.java b/java/source/infodynamics/utils/UnivariateNearestNeighbourSearcher.java old mode 100755 new mode 100644