Merge pull request #93 from dpshorten/master

Spike train TE estimation
This commit is contained in:
Joseph Lizier 2022-09-05 14:38:21 +10:00 committed by GitHub
commit e747e8aa93
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 4168 additions and 0 deletions

View File

@ -0,0 +1,259 @@
# Argument order: network_type_name num_spikes sim_number target_index
from jpype import *
import random
import math
import os
import numpy as np
import pickle
import copy
import sys
# net_type_name is useful if you are iterating over multiple files with different network types.
# Looking at the definition of SPIKES_FILE_NAME and OUTPUT_FILE_PREFIX will imply what the purpose of
# these command line arguments is.
net_type_name = sys.argv[1]
num_spikes_string = sys.argv[2]
repeat_num_string = sys.argv[3]
target_index_string = sys.argv[4]
# The number of surrogates to create for each significance test of a TE value
NUM_SURROGATES_PER_TE_VAL = 100
# The p level below which the null hypothesis will be rejected.
P_LEVEL = 0.05
# The number of nearest neighbours to consider in the TE estimation.
KNNS = 10
# The number of random sample points laid down will be NUM_SAMPLES_MULTIPLIER * length_of_target_train
NUM_SAMPLES_MULTIPLIER = 5.0
#SURROGATE_NUM_SAMPLES_MULTIPLIER = 5.0
# As above, but for the creation of surrogates
SURROGATE_NUM_SAMPLES_MULTIPLIER = 5.0
# The number of nearest neighbours to consider when using the local permutation method to create surrogates
K_PERM = 20
# The level of the noise to add to the random sample points used in creating surrogates
JITTERING_LEVEL = 2000
# When MAX_NUM_SECOND_INTERVALS sources have 2 or more history intervals added into the conditioning set, the inference stops
MAX_NUM_SECOND_INTERVALS = 2
# Exclude target spikes beyond this number
MAX_NUM_TARGET_SPIKES = int(num_spikes_string)
# The spikes file with the below name is expected to contain a single pickled Python list. This list contains numpy arrays. Each
# numpy array contains the spike times of each candidate target.
SPIKES_FILE_NAME = "spikes_LIF_" + net_type_name + "_" + repeat_num_string + ".pk"
# The ground truth file of the below name is expected to contain a single pickled Python list. This list contains tuples of the format(source, target).
# source and target are integers of the indices of true connections.
GROUND_TRUTH_FILE_NAME = "connections_LIF_"+ net_type_name + "_" + repeat_num_string + ".pk"
OUTPUT_FILE_PREFIX = "results/inferred_sources_target_2_" + net_type_name + "_" + num_spikes_string + "_" + repeat_num_string + "_" + target_index_string
LOG_FILE_NAME = "logs/" + net_type_name + "_" + num_spikes_string + "_" + repeat_num_string + "_" + target_index_string + ".log"
log = open(LOG_FILE_NAME, "w")
sys.stdout = log
def prepare_conditional_trains(calc_object, cond_set):
cond_trains = []
calc_object.clearConditionalIntervals()
if len(cond_set) > 0:
for key in cond_set.keys():
cond_trains.append(spikes[key])
teCalc.appendConditionalIntervals(JArray(JInt, 1)(cond_set[key]))
return cond_trains
def set_target_embeddings(embedding_list):
if len(embedding_list) > 0:
embedding_string = str(embedding_list[0])
for i in range(2, len(embedding_list)):
embedding_string += "," + str(embedding_list[i])
teCalc.setProperty("DEST_PAST_INTERVALS", embedding_string)
else:
teCalc.setProperty("DEST_PAST_INTERVALS", "")
target_index = int(target_index_string)
print("\n****** Network inference for target neuron", target_index, "******\n\n")
# Setup JIDT
jarLocation = os.path.join(os.getcwd(), "../jidt/infodynamics.jar");
if (not(os.path.isfile(jarLocation))):
exit("infodynamics.jar not found (expected at " + os.path.abspath(jarLocation) + ") - are you running from demos/python?")
startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntropyCalculatorSpikingIntegration
teCalc = teCalcClass()
teCalc.setProperty("knns", str(KNNS))
teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", str(NUM_SAMPLES_MULTIPLIER))
teCalc.setProperty("SURROGATE_NUM_SAMPLES_MULTIPLIER", str(SURROGATE_NUM_SAMPLES_MULTIPLIER))
teCalc.setProperty("K_PERM", str(K_PERM))
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", str(JITTERING_LEVEL))
# Load spikes and ground truth connectivity
spikes = pickle.load(open(SPIKES_FILE_NAME, 'rb'))
cons = pickle.load(open(GROUND_TRUTH_FILE_NAME, 'rb'))
if MAX_NUM_TARGET_SPIKES < len(spikes[target_index]):
spikes[target_index] = spikes[target_index][:MAX_NUM_TARGET_SPIKES]
print("Number of target spikes: ", len(spikes[target_index]), "\n\n")
# First determine the correct target embedding
target_embedding_set = [1]
next_target_interval = 2
still_significant = True
print("**** Determining target embedding set ****\n")
while still_significant:
set_target_embeddings(target_embedding_set)
teCalc.setProperty("SOURCE_PAST_INTERVALS", str(next_target_interval))
teCalc.startAddObservations()
teCalc.addObservations(JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 1)(spikes[target_index]))
teCalc.finaliseAddObservations();
TE = teCalc.computeAverageLocalOfObservations()
sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE)
print("candidate interval:", next_target_interval, " TE:", TE, " p val:", sig.pValue)
if sig.pValue > P_LEVEL:
print("Lost significance, end of target embedding determination")
still_significant = False
else:
target_embedding_set.append(next_target_interval)
next_target_interval += 1
print("target embedding set:", target_embedding_set, "\n\n")
# Now add the sources
# cond_set is a dictionary where keys are added sources and values are lists of included intervals for the
# source key.
cond_set = dict()
# next_interval_for_each_candidate will be a matrix with two columns
# first column has the source indices, second has the next interval that will be considered
next_interval_for_each_candidate = np.arange(0, len(spikes), dtype = np.intc)
next_interval_for_each_candidate = next_interval_for_each_candidate[next_interval_for_each_candidate != target_index]
next_interval_for_each_candidate = np.column_stack((next_interval_for_each_candidate, np.ones(len(next_interval_for_each_candidate), dtype = np.intc)))
still_significant = True
TE_vals_at_each_round = []
surrogate_vals_at_each_round = []
print("**** Adding Sources ****\n")
num_twos = 0
while still_significant:
print("Current conditioning set:")
for key in cond_set.keys():
print("source", key, "intervals", cond_set[key])
print("\nEstimating TE on candidate sources")
cond_trains = prepare_conditional_trains(teCalc, cond_set)
TE_vals = np.zeros(next_interval_for_each_candidate.shape[0])
debiased_TE_vals = -1 * np.ones(next_interval_for_each_candidate.shape[0])
surrogate_vals = -1 * np.ones((next_interval_for_each_candidate.shape[0], NUM_SURROGATES_PER_TE_VAL))
debiased_surrogate_vals = 1 - np.ones((next_interval_for_each_candidate.shape[0], NUM_SURROGATES_PER_TE_VAL))
is_con = np.zeros(next_interval_for_each_candidate.shape[0])
for i in range(next_interval_for_each_candidate.shape[0]):
if len(spikes[next_interval_for_each_candidate[i, 0]]) < 10:
continue
teCalc.startAddObservations()
teCalc.setProperty("SOURCE_PAST_INTERVALS", str(next_interval_for_each_candidate[i, 1]))
if len(cond_set) > 0:
teCalc.addObservations(JArray(JDouble, 1)(spikes[next_interval_for_each_candidate[i, 0]]),
JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 2)(cond_trains))
else:
teCalc.addObservations(JArray(JDouble, 1)(spikes[next_interval_for_each_candidate[i, 0]]),
JArray(JDouble, 1)(spikes[target_index]))
teCalc.finaliseAddObservations();
TE_vals[i] = teCalc.computeAverageLocalOfObservations()
is_con[i] = ([next_interval_for_each_candidate[i, 0], target_index] in cons)
sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE_vals[i])
surrogate_vals[i] = sig.distribution
debiased_TE_vals[i] = TE_vals[i] - np.mean(surrogate_vals[i])
debiased_surrogate_vals[i] = sig.distribution - np.mean(surrogate_vals[i])
print("Source", next_interval_for_each_candidate[i, 0], "Interval", next_interval_for_each_candidate[i, 1],
" TE:", str(debiased_TE_vals[i]))
log.flush()
TE_vals_at_each_round.append(TE_vals)
surrogate_vals_at_each_round.append(surrogate_vals)
sorted_TE_indices = np.argsort(debiased_TE_vals)
print("\nSorted order of sources:\n", next_interval_for_each_candidate[:, 0][sorted_TE_indices[:]])
print("Ground truth for sorted order:\n", is_con[sorted_TE_indices[:]])
index_of_max_candidate = sorted_TE_indices[-1]
samples_from_max_dist = np.max(debiased_surrogate_vals, axis = 0)
np.sort(samples_from_max_dist)
index_of_first_greater_than_estimate = np.searchsorted(samples_from_max_dist > debiased_TE_vals[index_of_max_candidate], 1)
p_val = (NUM_SURROGATES_PER_TE_VAL - index_of_first_greater_than_estimate)/float(NUM_SURROGATES_PER_TE_VAL)
print("\nMaximum candidate is source", next_interval_for_each_candidate[index_of_max_candidate, 0],
"interval", next_interval_for_each_candidate[index_of_max_candidate, 1])
print("p: ", p_val)
if p_val <= P_LEVEL:
if (next_interval_for_each_candidate[index_of_max_candidate, 0]) in cond_set:
cond_set[next_interval_for_each_candidate[index_of_max_candidate, 0]].append(next_interval_for_each_candidate[index_of_max_candidate, 1])
else:
cond_set[next_interval_for_each_candidate[index_of_max_candidate, 0]] = [next_interval_for_each_candidate[index_of_max_candidate, 1]]
if next_interval_for_each_candidate[index_of_max_candidate, 1] == 2:
num_twos += 1
if num_twos >= MAX_NUM_SECOND_INTERVALS:
print("\nMaximum number of second intervals reached\n\n")
still_significant = False
next_interval_for_each_candidate[index_of_max_candidate, 1] += 1
print("\nCandidate added\n\n")
else:
still_significant = False
print("\nLost Significance\n\n")
print("**** Pruning Sources ****\n")
# Repeatedly removes the connection that has the lowest TE out of all insignificant connections.
# Only considers the furthest intervals as candidates in each round.
everything_significant = False
while not everything_significant:
print("Current conditioning set:")
for key in cond_set.keys():
print("source", key, "intervals", cond_set[key])
print("\nEstimating TE on candidate sources")
everything_significant = True
insignificant_sources = []
insignificant_sources_TE = []
for candidate_source in cond_set:
cond_set_minus_candidate = copy.deepcopy(cond_set)
# If more than one interval, remove the last
if len(cond_set_minus_candidate[candidate_source]) > 1:
cond_set_minus_candidate[candidate_source] = cond_set_minus_candidate[candidate_source][:-1]
# Otherwise, remove source from dict
else:
cond_set_minus_candidate.pop(candidate_source)
teCalc.setProperty("SOURCE_PAST_INTERVALS", str(cond_set[candidate_source][-1]))
cond_trains = prepare_conditional_trains(teCalc, cond_set_minus_candidate)
teCalc.startAddObservations()
if len(cond_set_minus_candidate) > 0:
teCalc.addObservations(JArray(JDouble, 1)(spikes[candidate_source]), JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 2)(cond_trains))
else:
teCalc.addObservations(JArray(JDouble, 1)(spikes[candidate_source]), JArray(JDouble, 1)(spikes[target_index]))
teCalc.finaliseAddObservations();
TE = teCalc.computeAverageLocalOfObservations()
sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE)
print("Source", candidate_source, "Interval", cond_set[candidate_source][-1],
" TE:", str(round(TE, 2)), " p val:", sig.pValue)
if sig.pValue > P_LEVEL:
everything_significant = False
insignificant_sources.append(candidate_source)
insignificant_sources_TE.append(TE)
if not everything_significant:
min_TE_source = insignificant_sources[np.argmin(insignificant_sources_TE)]
print("removing source", min_TE_source, "interval", cond_set[min_TE_source][-1])
if len(cond_set[min_TE_source]) > 1:
cond_set[min_TE_source] = cond_set[min_TE_source][:-1]
else:
cond_set.pop(min_TE_source)
print("\n\n****** Final Inferred Source Set ******\n")
for key in cond_set.keys():
print("source", key, "intervals", cond_set[key])
print("\nTrue Sources:")
for con in cons:
if con[1] == target_index:
print(con[0], " ",)
output_file = open(OUTPUT_FILE_PREFIX + ".pk", 'wb')
pickle.dump(cond_set, output_file)
#pickle.dump(surrogate_vals_at_each_round, output_file)
#pickle.dump(TE_vals_at_each_round, output_file)
output_file.close()

View File

@ -0,0 +1,46 @@
# This script converts CSV files of spike times (e.g. from the Wagenaar data set) into
# pickle files of spike times in the format that the net_inf.py script expects
import numpy as np
import pickle
import sys
import ast
import matplotlib.pyplot as plt
RUN = "1-1-20.2"
spk_file = open('extracted_data_wagenaar/1-1/' + RUN + '.spk', 'r')
time_upper = 8 * 60 * 60 * 2.5e4
spikes = []
for line in spk_file:
line = line.strip()
line = line.split(",")
line = [float(time) for time in line if time != ""]
spikes.append(np.array(line))
start_times = [train[0] for train in spikes if len(train) > 0]
lowest_start_time = min(start_times)
cutoff_time = lowest_start_time + time_upper
for i in range(len(spikes)):
spikes[i] = spikes[i][spikes[i] < cutoff_time]
spikes[i] = spikes[i] - lowest_start_time
spikes[i] = spikes[i] + np.random.uniform(size = spikes[i].shape) - 0.5
spikes[i] = np.sort(spikes[i])
print(len(spikes))
for i in range(len(spikes)):
print(spikes[i].shape)
print(spikes[i][:10])
#plt.eventplot(spikes, linewidth = 0.5)
#plt.show()
spikes_file = open("spikes_LIF_" + RUN + "_" + sys.argv[1] + ".pk", "wb")
pickle.dump(spikes, spikes_file)
cons = [[0, 0]]
connections_file = open("connections_LIF_" + RUN + "_" + sys.argv[1] + ".pk", "wb")
pickle.dump(cons, connections_file)
spikes_file.close()
connections_file.close()

View File

@ -0,0 +1,192 @@
##
## Java Information Dynamics Toolkit (JIDT)
## Copyright (C) 2022, David P. Shorten, Joseph T. Lizier
##
## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program. If not, see <http://www.gnu.org/licenses/>.
##
# Transfer entropy (TE) calculation on generated spike train data using the continuous-time TE estimator.
from jpype import *
import random
import math
import os
import numpy as np
NUM_REPS = 2
NUM_SPIKES = int(2e3)
NUM_OBSERVATIONS = 2
NUM_SURROGATES = 10
# Params for canonical example generation
RATE_Y = 1.0
RATE_X_MAX = 10
def generate_canonical_example_processes(num_y_events):
event_train_x = []
event_train_x.append(0)
event_train_y = np.random.uniform(0, int(num_y_events / RATE_Y), int(num_y_events))
event_train_y.sort()
most_recent_y_index = 0
previous_x_candidate = 0
while most_recent_y_index < (len(event_train_y) - 1):
this_x_candidate = previous_x_candidate + random.expovariate(RATE_X_MAX)
while most_recent_y_index < (len(event_train_y) - 1) and this_x_candidate > event_train_y[most_recent_y_index + 1]:
most_recent_y_index += 1
delta_t = this_x_candidate - event_train_y[most_recent_y_index]
rate = 0
if delta_t > 1:
rate = 0.5
else:
rate = 0.5 + 5.0 * math.exp(-50 * (delta_t - 0.5)**2) - 5.0 * math.exp(-50 * (0.5)**2)
if random.random() < rate/float(RATE_X_MAX):
event_train_x.append(this_x_candidate)
previous_x_candidate = this_x_candidate
event_train_x.sort()
event_train_y.sort()
return event_train_x, event_train_y
# Change location of jar to match yours (we assume script is called from demos/python):
jarLocation = os.path.join(os.getcwd(), "infodynamics.jar");
if (not(os.path.isfile(jarLocation))):
exit("infodynamics.jar not found (expected at " + os.path.abspath(jarLocation) + ") - are you running from demos/python?")
# Start the JVM (add the "-Xmx" option with say 1024M if you get crashes due to not enough memory space)
startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntropyCalculatorSpikingIntegration
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Independent Poisson Processes")
teCalc.setProperty("DEST_PAST_INTERVALS", "1,2")
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1,2")
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
teCalc.appendConditionalIntervals(JArray(JInt, 1)([1, 2]))
teCalc.appendConditionalIntervals(JArray(JInt, 1)([1, 2]))
teCalc.setProperty("NORM_TYPE", "MAX_NORM")
results_poisson = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
teCalc.startAddObservations()
for j in range(NUM_OBSERVATIONS):
sourceArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
sourceArray.sort()
destArray = NUM_SPIKES*np.random.random(NUM_SPIKES)
destArray.sort()
condArray = NUM_SPIKES*np.random.random((2, NUM_SPIKES))
condArray.sort(axis = 1)
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
teCalc.finaliseAddObservations();
result = teCalc.computeAverageLocalOfObservations()
print("TE result %.4f nats" % (result,))
sig = teCalc.computeSignificance(NUM_SURROGATES, result)
print(sig.pValue)
results_poisson[i] = result
print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Noisy copy zero TE")
#teCalc.appendConditionalIntervals(JArray(JInt, 1)([1]))
teCalc.setProperty("DEST_PAST_INTERVALS", "1")
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1")
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
results_noisy_zero = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
teCalc.startAddObservations()
for j in range(NUM_OBSERVATIONS):
condArray = np.ones((1, NUM_SPIKES)) + 0.05 * np.random.random((1, NUM_SPIKES))
condArray = np.cumsum(condArray, axis = 1)
condArray.sort(axis = 1)
sourceArray = condArray[0, :] + 0.25 + 0.05 * np.random.normal(size = condArray.shape[1])
sourceArray.sort()
destArray = condArray[0, :] + 0.5 + 0.05 * np.random.normal(size = condArray.shape[1])
destArray.sort()
#teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray))
teCalc.finaliseAddObservations();
result = teCalc.computeAverageLocalOfObservations()
print("TE result %.4f nats" % (result,))
sig = teCalc.computeSignificance(NUM_SURROGATES, result)
print(sig.pValue)
results_noisy_zero[i] = result
print("Summary: mean ", np.mean(results_noisy_zero), " std dev ", np.std(results_noisy_zero))
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Noisy copy non-zero TE")
teCalc.appendConditionalIntervals(JArray(JInt, 1)([1]))
teCalc.setProperty("DEST_PAST_INTERVALS", "1,2")
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1")
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
results_noisy_non_zero = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
teCalc.startAddObservations()
for j in range(NUM_OBSERVATIONS):
sourceArray = np.ones((1, NUM_SPIKES)) + 0.05 * np.random.random((1, NUM_SPIKES))
sourceArray = np.cumsum(sourceArray)
sourceArray.sort()
condArray = sourceArray + 0.25 + 0.05 * np.random.normal(size = sourceArray.shape)
condArray.sort()
condArray = np.expand_dims(condArray, 0)
destArray = sourceArray + 0.5 + 0.05 * np.random.normal(size = sourceArray.shape)
destArray.sort()
teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
teCalc.finaliseAddObservations();
result = teCalc.computeAverageLocalOfObservations()
print("TE result %.4f nats" % (result,))
sig = teCalc.computeSignificance(NUM_SURROGATES, result)
print(sig.pValue)
results_noisy_non_zero[i] = result
print("Summary: mean ", np.mean(results_noisy_non_zero), " std dev ", np.std(results_noisy_zero))
print("Canonical example")
teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
teCalc.setProperty("DEST_PAST_INTERVALS", "1,2")
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1")
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
#teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", "1")
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
results_canonical = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
event_train_x, event_train_y = generate_canonical_example_processes(NUM_SPIKES)
teCalc.setObservations(JArray(JDouble, 1)(event_train_y), JArray(JDouble, 1)(event_train_x))
result = teCalc.computeAverageLocalOfObservations()
results_canonical[i] = result
print("TE result %.4f nats" % (result,))
sig = teCalc.computeSignificance(NUM_SURROGATES, result)
print(sig.pValue)
print("Summary: mean ", np.mean(results_canonical), " std dev ", np.std(results_canonical))

View File

@ -0,0 +1,305 @@
/*
* Java Information Dynamics Toolkit (JIDT)
* Copyright (C) 2012, Joseph T. Lizier
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package infodynamics.measures.spiking;
import infodynamics.utils.EmpiricalMeasurementDistribution;
/**
* <p>Interface for implementations of the <b>transfer entropy</b> (TE),
* which may be applied to spiking time-series data.
* That is, it is applied to <code>double[]</code> data, as an array
* of time stamps at which spikes were recorded.
* See Schreiber below for the definition of transfer entropy,
* and Lizier et al. for the definition of local transfer entropy,
* and (To be published) for how to measure transfer entropy on spike trains.
* Specifically, this class implements the pairwise or <i>apparent</i>
* transfer entropy; i.e. we compute the transfer that appears to
* come from a single source variable, without examining any other
* potential sources
* (see Lizier et al, PRE, 2008).</p>
*
* <p>
* Usage of the child classes implementing this interface is intended to follow this paradigm:
* </p>
* <ol>
* <li>Construct the calculator;</li>
* <li>Set properties using {@link #setProperty(String, String)}
* e.g. including properties describing
* the source and destination embedding;</li>
* <li>Initialise the calculator using
* {@link #initialise()} or {@link #initialise(int, int)};</li>
* <li>Provide the observations/samples for the calculator
* to set up the PDFs, using:
* <ul>
* <li>{@link #setObservations(double[], double[])}
* for calculations based on single recordings, OR</li>
* <li>The following sequence:<ol>
* <li>{@link #startAddObservations()}, then</li>
* <li>One or more calls to
* {@link #addObservations(double[], double[])}, then</li>
* <li>{@link #finaliseAddObservations()};</li>
* </ol></li>
* </ul></li>
* <li>Compute the required quantities, being one or more of:
* <ul>
* <li>the average TE: {@link #computeAverageLocalOfObservations()};</li>
* <li>the local TE values for these samples: {@link #computeLocalOfPreviousObservations()}</li>
* <li>the distribution of TE values under the null hypothesis
* of no relationship between source and
* destination values: {@link #computeSignificance(int)} or
* {@link #computeSignificance(int[][])}.</li>
* </ul>
* </li>
* <li>
* Return to step 2 or 3 to re-use the calculator on a new data set.
* </li>
* </ol>
* </p>
*
* <p><b>References:</b><br/>
* <ul>
* <li>T. Schreiber, <a href="http://dx.doi.org/10.1103/PhysRevLett.85.461">
* "Measuring information transfer"</a>,
* Physical Review Letters 85 (2) pp.461-464, 2000.</li>
* <li>J. T. Lizier, M. Prokopenko and A. Zomaya,
* <a href="http://dx.doi.org/10.1103/PhysRevE.77.026110">
* "Local information transfer as a spatiotemporal filter for complex systems"</a>
* Physical Review E 77, 026110, 2008.</li>
* <li>To be published</li>
* </ul>
*
* @author Joseph Lizier (<a href="joseph.lizier at gmail.com">email</a>,
* <a href="http://lizier.me/joseph/">www</a>)
*/
public interface TransferEntropyCalculatorSpiking {
/**
* Property name to specify the destination history embedding length k
* (default value 1)
*/
public static final String K_PROP_NAME = "k_HISTORY";
/**
* Property name for embedding length for the source past history vector
* (default value 1)
*/
public static final String L_PROP_NAME = "l_HISTORY";
/* Could try to do this one later. I think we would just consider the history
* of source as being up to this many units of time behind destination and
* no later.
*
* Property name for source-destination delay (default value is 0)
*
public static final String DELAY_PROP_NAME = "DELAY";
*/
/**
* Property name for whether each series of time stamps of spikes is sorted
* into temporal order. (default true)
*/
public static final String TIMESORTED_PROP_NAME = "TIME_SORTED";
/**
* Initialise the calculator for re-use with new observations.
* All parameters remain unchanged.
*
* @throws Exception
*/
public void initialise() throws Exception;
/**
* Initialise the calculator for re-use with new observations.
* A new history embedding length k can be supplied here; all other parameters
* remain unchanged.
*
* @param k destination history embedding length to be considered.
* @throws Exception
*/
public void initialise(int k) throws Exception;
/**
* Initialise the calculator for re-use with new observations.
* New history embedding lengths k and l can be supplied here; all other parameters
* remain unchanged.
*
* @param k destination history embedding length to be considered.
* @param l source history embedding length to be considered.
* @throws Exception
*/
public void initialise(int k, int l) throws Exception;
/**
* Sets properties for the TE calculator.
* New property values are not guaranteed to take effect until the next call
* to an initialise method.
*
* <p>Valid property names, and what their
* values should represent, include:</p>
* <ul>
* <li>{@link #K_PROP_NAME} -- destination history embedding length k
* (default value 1)</li>
* <li>{@link #L_PROP_NAME} -- embedding length for the source past history vector
* (default value 1)</li>
* <li>{@link #TIMESORTED_PROP_NAME} -- whether each series of time stamps of spikes is sorted
* into temporal order. (default "true")</li>
* </ul>
* <p><b>Note:</b> further properties may be defined by child classes.</p>
*
* <p>Unknown property values are ignored.</p>
*
* @param propertyName name of the property
* @param propertyValue value of the property.
* @throws Exception if there is a problem with the supplied value,
* or if the property is recognised but unsupported (eg some
* calculators do not support all of the embedding properties).
*/
public void setProperty(String propertyName, String propertyValue) throws Exception;
/**
* Get current property values for the calculator.
*
* <p>Valid property names, and what their
* values should represent, are the same as those for
* {@link #setProperty(String, String)}</p>
*
* <p>Unknown property values are responded to with a null return value.</p>
*
* @param propertyName name of the property
* @return current value of the property
* @throws Exception for invalid property values
*/
public String getProperty(String propertyName) throws Exception;
/**
* Sets a single set of spiking observations from which to compute the PDF for transfer entropy.
* Cannot be called in conjunction with other methods for setting/adding
* observations.
*
* @param source series of time stamps of spikes for the source variable.
* @param destination series of time stamps of spikes for the destination
* variable. Length will generally be different to the <code>source</code>,
* unlike other transfer entropy implementations, e.g. for {@link infodynamics.measures.continuous.TransferEntropyCalculator}.
* <code>source</code> and <code>destination</code> must have the same reference time point,
* and each array is assumed to be sorted into temporal order unless
* the property {@link #TIMESORTED_PROP_NAME} has been set to false.
*
* @throws Exception
*/
public void setObservations(double source[], double destination[]) throws Exception;
/**
* Signal that we will add in the samples for computing the PDF
* from several disjoint time-series or trials via calls to
* {@link #addObservations(double[], double[])} rather than
* {@link #setObservations(double[], double[])} type methods
* (defined by the child interfaces and classes).
*
*/
public void startAddObservations();
/**
* <p>Adds a new set of spiking observations to update the PDFs with.
* It is intended to be called multiple times, and must
* be called after {@link #startAddObservations()}. Call
* {@link #finaliseAddObservations()} once all observations have
* been supplied.</p>
*
* <p><b>Important:</b> this does not append or overlay these observations to the previously
* supplied observations, but treats them as independent trials - i.e. measurements
* such as the transfer entropy will not join them up to examine k
* consecutive values in time.</p>
*
* <p>Note that the arrays source and destination must not be over-written by the user
* until after {@link #finaliseAddObservations()} has been called
* (they are not copied by this method necessarily, the method
* may simply hold a pointer to them).</p>
*
* @param source series of time stamps of spikes for the source variable.
* Will be returned in ascending sorted order.
* @param destination series of time stamps of spikes for the destination
* variable. Length will generally be different to the <code>source</code>,
* unlike other transfer entropy implementations, e.g. for {@link infodynamics.measures.continuous.TransferEntropyCalculator}.
* <code>source</code> and <code>destination</code> must have the same reference time point,
* and each array is assumed to be sorted into temporal order unless
* the property {@link #TIMESORTED_PROP_NAME} has been set to false.
* Will be returned in ascending sorted order.
* @throws Exception
*/
public void addObservations(double[] source, double[] destination) throws Exception;
/**
* Signal that the observations are now all added via
* {@link #addObservations(double[], double[])}, PDFs can now be constructed.
*
* @throws Exception
*/
public void finaliseAddObservations() throws Exception;
/**
* Query whether the user has added more than a single observation set via the
* {@link #startAddObservations()}, "addObservations" (defined by child interfaces
* and classes), {@link #finaliseAddObservations()} sequence.
*
* @return true if more than a single observation set was supplied
*/
public boolean getAddedMoreThanOneObservationSet();
/**
* Compute the TE from the previously-supplied samples.
*
* @return the estimate of the channel measure
*/
public double computeAverageLocalOfObservations() throws Exception;
/**
* This interface serves to indicate the return type of {@link TransferEntropyCalculator#computeLocalOfPreviousObservations()}
* as each child implementation will return something specific
*
* @author Joseph Lizier
*
*/
public interface SpikingLocalInformationValues {
// Left empty intentionally
}
/**
* @return an object containing a representation of
* the of local TE values. The precise contents of this representation
* will vary depending on the underlying implementation
*/
public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception;
public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue) throws Exception;
public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck,
double estimatedValue, long randomSeed) throws Exception;
/**
* Set or clear debug mode for extra debug printing to stdout
*
* @param debug new setting for debug mode (on/off)
*/
public void setDebug(boolean debug);
/**
* Return the TE last calculated in a call to {@link #computeAverageLocalOfObservations()}
* or {@link #computeLocalOfPreviousObservations()} after the previous
* {@link #initialise()} call.
*
* @return the last computed channel measure value
*/
public double getLastAverage();
}

View File

@ -0,0 +1,965 @@
package infodynamics.measures.spiking.integration;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Vector;
import java.util.ArrayList;
import java.lang.Math;
//import infodynamics.measures.continuous.kraskov.EuclideanUtils;
import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking;
import infodynamics.utils.EmpiricalMeasurementDistribution;
import infodynamics.utils.KdTree;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import infodynamics.utils.NeighbourNodeData;
import infodynamics.utils.FirstIndexComparatorDouble;
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
import infodynamics.utils.EuclideanUtils;
import infodynamics.utils.ParsedProperties;
/**
* Computes the transfer entropy between a pair of spike trains, using an
* integration-based measure in order to match the theoretical form of TE
* between such spike trains.
*
* <p>
* Usage paradigm is as per the interface
* {@link TransferEntropyCalculatorSpiking}
* </p>
*
* @author Joseph Lizier (<a href="joseph.lizier at gmail.com">email</a>,
* <a href="http://lizier.me/joseph/">www</a>)
*/
/*
* TODO
* This implementation of the estimator does not implement dynamic exclusion windows. Such windows make sure
* that history embeddings that overlap are not considered in nearest-neighbour searches (as this breaks the
* independece assumption). Getting dynamic exclusion windows working will probably require modifications to the
* KdTree class.
*/
public class TransferEntropyCalculatorSpikingIntegration implements TransferEntropyCalculatorSpiking {
/**
* The past destination interspike intervals to consider
* (and associated property name and convenience length variable)
* The code assumes that the first interval is numbered 1, the next is numbered 2, etc.
* It also assumes that the intervals are sorted. The setter method performs sorting to ensure this.
*/
public static final String DEST_PAST_INTERVALS_PROP_NAME = "DEST_PAST_INTERVALS";
protected int[] destPastIntervals = new int[] {};
protected int numDestPastIntervals = 0;
/**
* As above but for the sources
*/
public static final String SOURCE_PAST_INTERVALS_PROP_NAME = "SOURCE_PAST_INTERVALS";
protected int[] sourcePastIntervals = new int[] {};
protected int numSourcePastIntervals = 0;
/**
* As above but for the conditioning processes. There is no property name for this variable,
* due to there not currently being a method for converting strings to 2d arrays. Instead,
* a separate setter method is implemented.
*/
protected Vector<int[]> vectorOfCondPastIntervals = new Vector<int[]>();
protected int numCondPastIntervals = 0;
/**
* Number of nearest neighbours to search for in the full joint space
*/
protected int Knns = 4;
/**
* Storage for source observations supplied via
* {@link #addObservations(double[], double[])} etc.
*/
protected Vector<double[]> vectorOfSourceSpikeTimes = null;
/**
* Storage for destination observations supplied via
* {@link #addObservations(double[], double[])} etc.
*/
protected Vector<double[]> vectorOfDestinationSpikeTimes = null;
/**
* Storage for conditional observations supplied via
* {@link #addObservations(double[], double[])} etc.
*/
protected Vector<double[][]> vectorOfConditionalSpikeTimes = null;
Vector<double[]> conditioningEmbeddingsFromSpikes = null;
Vector<double[]> jointEmbeddingsFromSpikes = null;
Vector<double[]> conditioningEmbeddingsFromSamples = null;
Vector<double[]> jointEmbeddingsFromSamples = null;
Vector<Double> processTimeLengths = null;
protected KdTree kdTreeJointAtSpikes = null;
protected KdTree kdTreeJointAtSamples = null;
protected KdTree kdTreeConditioningAtSpikes = null;
protected KdTree kdTreeConditioningAtSamples = null;
public static final String KNNS_PROP_NAME = "Knns";
/**
* Property name for an amount of random Gaussian noise to be added to the data
* (default is 1e-8, matching the MILCA toolkit).
*/
public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";
/**
* Whether to add an amount of random noise to the incoming data
*/
protected boolean addNoise = true;
/**
* Amount of random Gaussian noise to add to the incoming data
*/
protected double noiseLevel = (double) 1e-8;
/**
* Whether to use the jittered sampling approach. Useful for bursty spike trains. Explained in methods section of
* doi.org/10.1101/2021.06.29.450432
*/
protected boolean jitteredSamplesForSurrogates = false;
public static final String DO_JITTERED_SAMPLING_PROP_NAME = "DO_JITTERED_SAMPLING";
protected double jitteredSamplingNoiseLevel = 1;
public static final String JITTERED_SAMPLING_NOISE_LEVEL = "JITTERED_SAMPLING_NOISE_LEVEL";
/**
* Stores whether we are in debug mode
*/
protected boolean debug = false;
/**
* Property name for the number of random sample points to use as a multiple
* of the number of target spikes.
*/
public static final String PROP_SAMPLE_MULTIPLIER = "NUM_SAMPLES_MULTIPLIER";
protected double numSamplesMultiplier = 2.0;
/**
* Property name for the number of random sample points to use in the construction of the surrogates as a multiple
* of the number of target spikes.
*/
public static final String PROP_SURROGATE_SAMPLE_MULTIPLIER = "SURROGATE_NUM_SAMPLES_MULTIPLIER";
protected double surrogateNumSamplesMultiplier = 2.0;
/**
* Property for the number of nearest neighbours to use in the construction of the surrogates
*/
public static final String PROP_K_PERM = "K_PERM";
protected int kPerm = 10;
/**
* Property name for what type of norm to use between data points
* for each marginal variable -- Options are defined by
* {@link KdTree#setNormType(String)} and the
* default is {@link EuclideanUtils#NORM_EUCLIDEAN}.
*/
public final static String PROP_NORM_TYPE = "NORM_TYPE";
protected int normType = EuclideanUtils.NORM_EUCLIDEAN;
public TransferEntropyCalculatorSpikingIntegration() {
super();
}
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
* int)
*/
@Override
public void initialise() throws Exception {
initialise(0, 0);
}
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
* int)
*/
@Override
public void initialise(int k) throws Exception {
initialise(0, 0);
}
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
* int, int)
*/
@Override
public void initialise(int k, int l) throws Exception {
vectorOfSourceSpikeTimes = null;
vectorOfDestinationSpikeTimes = null;
}
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setProperty(
* java.lang.String, java.lang.String)
*/
@Override
public void setProperty(String propertyName, String propertyValue) throws Exception {
boolean propertySet = true;
if (propertyName.equalsIgnoreCase(DEST_PAST_INTERVALS_PROP_NAME)) {
if (propertyValue.length() == 0) {
destPastIntervals = new int[] {};
} else {
int[] destPastIntervalsTemp = ParsedProperties.parseStringArrayOfInts(propertyValue);
for (int interval : destPastIntervalsTemp) {
if (interval < 1) {
throw new Exception ("Invalid interval number less than 1.");
}
}
destPastIntervals = destPastIntervalsTemp;
Arrays.sort(destPastIntervals);
}
} else if (propertyName.equalsIgnoreCase(SOURCE_PAST_INTERVALS_PROP_NAME)) {
if (propertyValue.length() == 0) {
sourcePastIntervals = new int[] {};
} else {
int[] sourcePastIntervalsTemp = ParsedProperties.parseStringArrayOfInts(propertyValue);
for (int interval : sourcePastIntervalsTemp) {
if (interval < 1) {
throw new Exception ("Invalid interval number less than 1.");
}
}
sourcePastIntervals = sourcePastIntervalsTemp;
Arrays.sort(sourcePastIntervals);
}
} else if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) {
Knns = Integer.parseInt(propertyValue);
} else if (propertyName.equalsIgnoreCase(DO_JITTERED_SAMPLING_PROP_NAME)) {
jitteredSamplesForSurrogates = Boolean.parseBoolean(propertyValue);
} else if (propertyName.equalsIgnoreCase(JITTERED_SAMPLING_NOISE_LEVEL)) {
jitteredSamplingNoiseLevel = Double.parseDouble(propertyValue);
} else if (propertyName.equalsIgnoreCase(PROP_K_PERM)) {
kPerm = Integer.parseInt(propertyValue);
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
if (propertyValue.equals("0") || propertyValue.equalsIgnoreCase("false")) {
addNoise = false;
noiseLevel = 0;
} else {
addNoise = true;
noiseLevel = Double.parseDouble(propertyValue);
}
} else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
double tempNumSamplesMultiplier = Double.parseDouble(propertyValue);
if (tempNumSamplesMultiplier <= 0) {
throw new Exception ("Num samples multiplier must be greater than 0.");
} else {
numSamplesMultiplier = tempNumSamplesMultiplier;
}
} else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
normType = KdTree.validateNormType(propertyValue);
} else if (propertyName.equalsIgnoreCase(PROP_SURROGATE_SAMPLE_MULTIPLIER)) {
double tempSurrogateNumSamplesMultiplier = Double.parseDouble(propertyValue);
if (tempSurrogateNumSamplesMultiplier <= 0) {
throw new Exception ("Surrogate Num samples multiplier must be greater than 0.");
} else {
surrogateNumSamplesMultiplier = tempSurrogateNumSamplesMultiplier;
}
} else {
// No property was set on this class
propertySet = false;
}
if (debug && propertySet) {
System.out.println(
this.getClass().getSimpleName() + ": Set property " + propertyName + " to " + propertyValue);
}
}
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getProperty(
* java.lang.String)
*/
@Override
public String getProperty(String propertyName) throws Exception {
if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) {
return Integer.toString(Knns);
} else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
return Double.toString(noiseLevel);
} else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
return Double.toString(numSamplesMultiplier);
} else {
// No property matches for this class
return null;
}
}
public void appendConditionalIntervals(int[] intervals) throws Exception{
for (int interval : intervals) {
if (interval < 1) {
throw new Exception ("Invalid interval number less than 1.");
}
}
Arrays.sort(intervals);
vectorOfCondPastIntervals.add(intervals);
}
public void clearConditionalIntervals() {
vectorOfCondPastIntervals = new Vector<int[]>();
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* setObservations(double[], double[])
*/
@Override
public void setObservations(double[] source, double[] destination) throws Exception {
startAddObservations();
addObservations(source, destination);
finaliseAddObservations();
}
public void setObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
startAddObservations();
addObservations(source, destination, conditionals);
finaliseAddObservations();
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* startAddObservations()
*/
@Override
public void startAddObservations() {
vectorOfSourceSpikeTimes = new Vector<double[]>();
vectorOfDestinationSpikeTimes = new Vector<double[]>();
vectorOfConditionalSpikeTimes = new Vector<double[][]>();
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* addObservations(double[], double[])
*/
@Override
public void addObservations(double[] source, double[] destination) throws Exception {
// Store these observations in our vector for now
vectorOfSourceSpikeTimes.add(source);
vectorOfDestinationSpikeTimes.add(destination);
}
public void addObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
// Store these observations in our vector for now
vectorOfSourceSpikeTimes.add(source);
vectorOfDestinationSpikeTimes.add(destination);
vectorOfConditionalSpikeTimes.add(conditionals);
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* finaliseAddObservations()
*/
@Override
public void finaliseAddObservations() throws Exception {
// Set these conveniance variables as they are used quite a bit later on
numDestPastIntervals = destPastIntervals.length;
numSourcePastIntervals = sourcePastIntervals.length;
numCondPastIntervals = 0;
for (int[] intervals : vectorOfCondPastIntervals) {
numCondPastIntervals += intervals.length;
}
conditioningEmbeddingsFromSpikes = new Vector<double[]>();
jointEmbeddingsFromSpikes = new Vector<double[]>();
conditioningEmbeddingsFromSamples = new Vector<double[]>();
jointEmbeddingsFromSamples = new Vector<double[]>();
processTimeLengths = new Vector<Double>();
// Send all of the observations through:
Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
Iterator<double[][]> conditionalIterator = null;
if (vectorOfConditionalSpikeTimes.size() > 0) {
conditionalIterator = vectorOfConditionalSpikeTimes.iterator();
}
int timeSeriesIndex = 0;
for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
double[] sourceSpikeTimes = sourceIterator.next();
double[][] conditionalSpikeTimes = null;
if (vectorOfConditionalSpikeTimes.size() > 0) {
conditionalSpikeTimes = conditionalIterator.next();
} else {
conditionalSpikeTimes = new double[][] {};
}
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, conditioningEmbeddingsFromSpikes,
jointEmbeddingsFromSpikes, conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples,
numSamplesMultiplier, false);
}
// Convert the vectors to arrays so that they can be put in the trees
double[][] arrayedTargetEmbeddingsFromSpikes = new double[conditioningEmbeddingsFromSpikes.size()][numDestPastIntervals + numCondPastIntervals];
double[][] arrayedJointEmbeddingsFromSpikes = new double[conditioningEmbeddingsFromSpikes.size()][numDestPastIntervals +
numCondPastIntervals + numSourcePastIntervals];
for (int i = 0; i < conditioningEmbeddingsFromSpikes.size(); i++) {
arrayedTargetEmbeddingsFromSpikes[i] = conditioningEmbeddingsFromSpikes.elementAt(i);
arrayedJointEmbeddingsFromSpikes[i] = jointEmbeddingsFromSpikes.elementAt(i);
}
double[][] arrayedTargetEmbeddingsFromSamples = new double[conditioningEmbeddingsFromSamples.size()][numDestPastIntervals + numCondPastIntervals];
double[][] arrayedJointEmbeddingsFromSamples = new double[conditioningEmbeddingsFromSamples.size()][numDestPastIntervals +
numCondPastIntervals + numSourcePastIntervals];
for (int i = 0; i < conditioningEmbeddingsFromSamples.size(); i++) {
arrayedTargetEmbeddingsFromSamples[i] = conditioningEmbeddingsFromSamples.elementAt(i);
arrayedJointEmbeddingsFromSamples[i] = jointEmbeddingsFromSamples.elementAt(i);
}
kdTreeJointAtSpikes = new KdTree(arrayedJointEmbeddingsFromSpikes);
kdTreeJointAtSamples = new KdTree(arrayedJointEmbeddingsFromSamples);
kdTreeConditioningAtSpikes = new KdTree(arrayedTargetEmbeddingsFromSpikes);
kdTreeConditioningAtSamples = new KdTree(arrayedTargetEmbeddingsFromSamples);
kdTreeJointAtSpikes.setNormType(normType);
kdTreeJointAtSamples.setNormType(normType);
kdTreeConditioningAtSpikes.setNormType(normType);
kdTreeConditioningAtSamples.setNormType(normType);
}
protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, int indexOfFirstPointToUse,
double[] sourceSpikeTimes, double[] destSpikeTimes,
double[][] conditionalSpikeTimes,
Vector<double[]> conditioningEmbeddings,
Vector<double[]> jointEmbeddings) {
Random random = new Random();
// Initialise the starting points of all the tracking variables
int embeddingPointIndex = indexOfFirstPointToUse;
int mostRecentDestIndex = destPastIntervals[destPastIntervals.length - 1];
int mostRecentSourceIndex = sourcePastIntervals[sourcePastIntervals.length - 1] - 1;
int[] mostRecentConditioningIndices = new int[vectorOfCondPastIntervals.size()];
for (int i = 0; i < vectorOfCondPastIntervals.size(); i++) {
mostRecentConditioningIndices[i] = vectorOfCondPastIntervals.elementAt(i)[vectorOfCondPastIntervals.elementAt(i).length - 1] - 1;
}
// Loop through the points at which embeddings need to be made
for (; embeddingPointIndex < pointsAtWhichToMakeEmbeddings.length; embeddingPointIndex++) {
// Advance the tracker of the most recent dest index
while (mostRecentDestIndex < (destSpikeTimes.length - 1)) {
if (destSpikeTimes[mostRecentDestIndex + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) {
mostRecentDestIndex++;
} else {
break;
}
}
// Do the same for the most recent source index
while (mostRecentSourceIndex < (sourceSpikeTimes.length - 1)) {
if (sourceSpikeTimes[mostRecentSourceIndex + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) {
mostRecentSourceIndex++;
} else {
break;
}
}
// Now advance the trackers for the most recent conditioning indices
for (int j = 0; j < mostRecentConditioningIndices.length; j++) {
while (mostRecentConditioningIndices[j] < (conditionalSpikeTimes[j].length - 1)) {
if (conditionalSpikeTimes[j][mostRecentConditioningIndices[j] + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) {
mostRecentConditioningIndices[j]++;
} else {
break;
}
}
}
double[] conditioningPast = new double[numDestPastIntervals + numCondPastIntervals];
double[] jointPast = new double[numDestPastIntervals + numCondPastIntervals + numSourcePastIntervals];
// Add the embedding intervals from the target process
for (int i = 0; i < destPastIntervals.length; i++) {
// Case where we are inserting an interval from an observation point back to the most recent event in the target process
if (destPastIntervals[i] == 1) {
conditioningPast[i] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex] - destSpikeTimes[mostRecentDestIndex];
jointPast[i] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
- destSpikeTimes[mostRecentDestIndex];
// Case where we are inserting an inter-event intervvl from the target process
} else {
conditioningPast[i] = destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 2]
- destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 1];
jointPast[i] = destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 2]
- destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 1];
}
}
// Add the embeding intervals from the conditional processes
int indexOfNextEmbeddingInterval = numDestPastIntervals;
for (int i = 0; i < vectorOfCondPastIntervals.size(); i++) {
for (int j = 0; j < vectorOfCondPastIntervals.elementAt(i).length; j++) {
// Case where we are inserting an interval from an observation point back to the most recent event in the conditioning process
if (vectorOfCondPastIntervals.elementAt(i)[j] == 1) {
conditioningPast[indexOfNextEmbeddingInterval] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
- conditionalSpikeTimes[i][mostRecentConditioningIndices[i]];
jointPast[indexOfNextEmbeddingInterval] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
- conditionalSpikeTimes[i][mostRecentConditioningIndices[i]];
// Case where we are inserting an inter-event interval from the conditioning process
} else {
// Convenience variable
int intervalNumber = vectorOfCondPastIntervals.elementAt(i)[j];
conditioningPast[indexOfNextEmbeddingInterval] = conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 2]
- conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 1];
jointPast[indexOfNextEmbeddingInterval] = conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 2]
- conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 1];
}
indexOfNextEmbeddingInterval++;
}
}
// Add the embedding intervals from the source process (this only gets added to the joint embeddings)
for (int i = 0; i < sourcePastIntervals.length; i++) {
// Case where we are inserting an interval from an observation point back to the most recent event in the source process
if (sourcePastIntervals[i] == 1) {
jointPast[indexOfNextEmbeddingInterval] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
- sourceSpikeTimes[mostRecentSourceIndex];
// Case where we are inserting an inter-event interval from the source process
} else {
jointPast[indexOfNextEmbeddingInterval] = sourceSpikeTimes[mostRecentSourceIndex - sourcePastIntervals[i] + 2]
- sourceSpikeTimes[mostRecentSourceIndex - sourcePastIntervals[i] + 1];
}
indexOfNextEmbeddingInterval++;
}
// Add Gaussian noise, if necessary
if (addNoise) {
for (int i = 0; i < conditioningPast.length; i++) {
conditioningPast[i] = Math.log(conditioningPast[i] + 1.1);
conditioningPast[i] += random.nextGaussian() * noiseLevel;
}
for (int i = 0; i < jointPast.length; i++) {
if (jointPast[i] < 0) {
System.out.println("NEGATIVE");
}
jointPast[i] = Math.log(jointPast[i] + 1.1);
jointPast[i] += random.nextGaussian() * noiseLevel;
}
}
conditioningEmbeddings.add(conditioningPast);
jointEmbeddings.add(jointPast);
}
}
protected int getFirstDestIndex(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes, Boolean setProcessTimeLengths)
throws Exception{
// First sort the spike times in case they were not properly in ascending order:
Arrays.sort(sourceSpikeTimes);
Arrays.sort(destSpikeTimes);
int firstTargetIndexOfEmbedding = destPastIntervals[destPastIntervals.length - 1];
int furthestInterval = sourcePastIntervals[sourcePastIntervals.length - 1];
while (destSpikeTimes[firstTargetIndexOfEmbedding] < sourceSpikeTimes[furthestInterval - 1]) {
firstTargetIndexOfEmbedding++;
}
if (conditionalSpikeTimes.length != vectorOfCondPastIntervals.size()) {
throw new Exception("Number of conditional embedding lengths does not match the number of conditional processes");
}
for (int i = 0; i < conditionalSpikeTimes.length; i++) {
furthestInterval = vectorOfCondPastIntervals.elementAt(i)[vectorOfCondPastIntervals.elementAt(i).length - 1];
while (destSpikeTimes[firstTargetIndexOfEmbedding] <= conditionalSpikeTimes[i][furthestInterval - 1]) {
firstTargetIndexOfEmbedding++;
}
}
// We don't want to reset these lengths when resampling for surrogates
if (setProcessTimeLengths) {
processTimeLengths.add(destSpikeTimes[destSpikeTimes.length - 1] - destSpikeTimes[firstTargetIndexOfEmbedding]);
}
return firstTargetIndexOfEmbedding;
}
protected double[] generateRandomSampleTimes(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes,
double actualNumSamplesMultiplier, int firstTargetIndexOfEmbedding, boolean doJitteredSampling) {
double sampleLowerBound = destSpikeTimes[firstTargetIndexOfEmbedding];
double sampleUpperBound = destSpikeTimes[destSpikeTimes.length - 1];
int num_samples = (int) Math.round(actualNumSamplesMultiplier * (destSpikeTimes.length - firstTargetIndexOfEmbedding + 1));
double[] randomSampleTimes = new double[num_samples];
Random rand = new Random();
if (doJitteredSampling) {
//System.out.println("jittering " + jitteredSamplingNoiseLevel);
for (int i = 0; i < randomSampleTimes.length; i++) {
randomSampleTimes[i] = destSpikeTimes[firstTargetIndexOfEmbedding + (i % (destSpikeTimes.length - firstTargetIndexOfEmbedding - 1))]
+ jitteredSamplingNoiseLevel * (rand.nextDouble() - 0.5);
//randomSampleTimes[i] = -1.0;
if ((randomSampleTimes[i] > sampleUpperBound) || (randomSampleTimes[i] < sampleLowerBound)) {
randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound);
}
}
} else {
for (int i = 0; i < randomSampleTimes.length; i++) {
randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound);
}
}
Arrays.sort(randomSampleTimes);
/*System.out.println(sampleLowerBound + " " + sampleUpperBound);
for (int i = 0; i < randomSampleTimes.length; i += 100) {
System.out.print(randomSampleTimes[i] + " ");
}
System.out.println("\n\n\n");*/
return randomSampleTimes;
}
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes,
Vector<double[]> conditioningEmbeddingsFromSpikes, Vector<double[]> jointEmbeddingsFromSpikes,
Vector<double[]> conditioningEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples,
double actualNumSamplesMultiplier, boolean doJitteredSampling)
throws Exception {
int firstTargetIndexOfEmbedding = getFirstDestIndex(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, true);
double[] randomSampleTimes = generateRandomSampleTimes(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
actualNumSamplesMultiplier, firstTargetIndexOfEmbedding,
doJitteredSampling);
makeEmbeddingsAtPoints(destSpikeTimes, firstTargetIndexOfEmbedding, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
conditioningEmbeddingsFromSpikes, jointEmbeddingsFromSpikes);
makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples);
}
protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes, double[][] conditionalSpikeTimes,
Vector<double[]> conditioningEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples,
double actualNumSamplesMultiplier, boolean doJitteredSampling)
throws Exception {
int firstTargetIndexOfEmbedding = getFirstDestIndex(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, false);
double[] randomSampleTimes = generateRandomSampleTimes(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
actualNumSamplesMultiplier, firstTargetIndexOfEmbedding,
doJitteredSampling);
makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples);
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* getAddedMoreThanOneObservationSet()
*/
@Override
public boolean getAddedMoreThanOneObservationSet() {
return (vectorOfDestinationSpikeTimes != null) && (vectorOfDestinationSpikeTimes.size() > 1);
}
// Class to allow returning two values in the subsequent method
private static class distanceAndNumPoints {
public double distance;
public int numPoints;
public distanceAndNumPoints(double distance, int numPoints) {
this.distance = distance;
this.numPoints = numPoints;
}
}
private distanceAndNumPoints findMaxDistanceAndNumPointsFromIndices(double[] point, int[] indices, Vector<double[]> setOfPoints) {
double maxDistance = 0;
int i = 0;
for (; indices[i] != -1; i++) {
double distance = KdTree.norm(point, setOfPoints.elementAt(indices[i]), normType);
if (distance > maxDistance) {
maxDistance = distance;
}
}
return new distanceAndNumPoints(maxDistance, i);
}
@Override
public double computeAverageLocalOfObservations() throws Exception {
return computeAverageLocalOfObservations(kdTreeJointAtSpikes, jointEmbeddingsFromSpikes);
}
/*
* We take the actual joint tree at spikes (along with the associated embeddings) as an argument, as we will need to swap these out when
* computing surrogates.
*/
public double computeAverageLocalOfObservations(KdTree actualKdTreeJointAtSpikes, Vector<double[]> actualJointEmbeddingsFromSpikes) throws Exception {
double currentSum = 0;
for (int i = 0; i < conditioningEmbeddingsFromSpikes.size(); i++) {
double radiusJointSpikes = actualKdTreeJointAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
double radiusJointSamples = kdTreeJointAtSamples.findKNearestNeighbours(Knns,
new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];
/*
The algorithm specified in box 1 of doi.org/10.1371/journal.pcbi.1008054 specifies finding the maximum of the two radii
just calculated and then redoing the searches in both sets at this radius. In this implementation, however, we make use
of the fact that one radius is equal to the maximum, and so only one search needs to be redone.
*/
double eps = 0.0;
// Need variables for the number of neighbours as this is now variable within the maximum radius
int kJointSpikes = 0;
int kJointSamples = 0;
if (radiusJointSpikes >= radiusJointSamples) {
/*
The maximum was the radius in the set of embeddings at spikes, so redo search in the set of embeddings at randomly
sampled points, using this larger radius.
*/
kJointSpikes = Knns;
int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()];
boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()];
kdTreeJointAtSamples.findPointsWithinR(radiusJointSpikes + eps,
new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(actualJointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
jointEmbeddingsFromSamples);
kJointSamples = temp.numPoints;
radiusJointSamples = temp.distance;
} else {
/*
The maximum was the radius in the set of embeddings at randomly sampled points, so redo search in the set of embeddings
at spikes, using this larger radius.
*/
kJointSamples = Knns;
int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()];
boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()];
actualKdTreeJointAtSpikes.findPointsWithinR(radiusJointSamples + eps,
new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(actualJointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
actualJointEmbeddingsFromSpikes);
// -1 due to the point itself being in the set
kJointSpikes = temp.numPoints - 1;
radiusJointSpikes = temp.distance;
}
// Repeat the above steps, but in the conditioning (rather than joint) space.
double radiusConditioningSpikes = kdTreeConditioningAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
double radiusConditioningSamples = kdTreeConditioningAtSamples.findKNearestNeighbours(Knns,
new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];
int kConditioningSpikes = 0;
int kConditioningSamples = 0;
if (radiusConditioningSpikes >= radiusConditioningSamples) {
kConditioningSpikes = Knns;
int[] indicesWithinR = new int[conditioningEmbeddingsFromSamples.size()];
boolean[] isWithinR = new boolean[conditioningEmbeddingsFromSamples.size()];
kdTreeConditioningAtSamples.findPointsWithinR(radiusConditioningSpikes + eps,
new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(conditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
conditioningEmbeddingsFromSamples);
kConditioningSamples = temp.numPoints;
radiusConditioningSamples = temp.distance;
} else {
kConditioningSamples = Knns;
int[] indicesWithinR = new int[conditioningEmbeddingsFromSamples.size()];
boolean[] isWithinR = new boolean[conditioningEmbeddingsFromSamples.size()];
kdTreeConditioningAtSpikes.findPointsWithinR(radiusConditioningSamples + eps,
new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) },
true,
isWithinR,
indicesWithinR);
distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(conditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
conditioningEmbeddingsFromSpikes);
// -1 due to the point itself being in the set
kConditioningSpikes = temp.numPoints - 1;
radiusConditioningSpikes = temp.distance;
}
/*
* TODO
* The KdTree class defaults to the squared euclidean distance when the euclidean norm is specified. This is fine for Kraskov estimators
* (as the radii are never used, just the numbers of points within radii). It causes problems here though, as we do use the radii and the
* squared euclidean distance is not a distance metric. We get around this by just taking the square root here, but it might be better to
* fix this in the KdTree class.
*/
double tempRadiusJointSamples = radiusJointSamples;
if (normType == EuclideanUtils.NORM_EUCLIDEAN) {
radiusJointSpikes = Math.sqrt(radiusJointSpikes);
radiusJointSamples = Math.sqrt(radiusJointSamples);
radiusConditioningSpikes = Math.sqrt(radiusConditioningSpikes);
radiusConditioningSamples = Math.sqrt(radiusConditioningSamples);
}
currentSum += (MathsUtils.digamma(kJointSpikes) - MathsUtils.digamma(kJointSamples) +
((numDestPastIntervals + numCondPastIntervals + numSourcePastIntervals) * (-Math.log(radiusJointSpikes) + Math.log(radiusJointSamples))) -
MathsUtils.digamma(kConditioningSpikes) + MathsUtils.digamma(kConditioningSamples) +
+ ((numDestPastIntervals + numCondPastIntervals) * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples))));
if (Double.isNaN(currentSum)) {
for (double[] embed : jointEmbeddingsFromSamples) {
System.out.println(Arrays.toString(embed));
}
throw new Exception("NaNs in TE clac " + radiusJointSpikes + " " + radiusJointSamples + " " + tempRadiusJointSamples);
}
}
// Normalise by time
double timeSum = 0;
for (Double time : processTimeLengths) {
timeSum += time;
}
currentSum /= timeSum;
return currentSum;
}
@Override
public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue) throws Exception {
return computeSignificance(numPermutationsToCheck,
estimatedValue,
System.currentTimeMillis());
}
@Override
public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck,
double estimatedValue, long randomSeed) throws Exception{
Random random = new Random(randomSeed);
double[] surrogateTEValues = new double[numPermutationsToCheck];
for (int permutationNumber = 0; permutationNumber < numPermutationsToCheck; permutationNumber++) {
Vector<double[]> resampledConditioningEmbeddingsFromSamples = new Vector<double[]>();
Vector<double[]> resampledJointEmbeddingsFromSamples = new Vector<double[]>();
// Send all of the observations through:
Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
Iterator<double[][]> conditionalIterator = null;
if (vectorOfConditionalSpikeTimes.size() > 0) {
conditionalIterator = vectorOfConditionalSpikeTimes.iterator();
}
int timeSeriesIndex = 0;
for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
double[] sourceSpikeTimes = sourceIterator.next();
double[][] conditionalSpikeTimes = null;
if (vectorOfConditionalSpikeTimes.size() > 0) {
conditionalSpikeTimes = conditionalIterator.next();
} else {
conditionalSpikeTimes = new double[][] {};
}
if (jitteredSamplesForSurrogates) {
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
resampledConditioningEmbeddingsFromSamples, resampledJointEmbeddingsFromSamples,
surrogateNumSamplesMultiplier, true);
} else {
processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
resampledConditioningEmbeddingsFromSamples, resampledJointEmbeddingsFromSamples,
surrogateNumSamplesMultiplier, false);
}
}
// Convert the vectors to arrays so that they can be put in the trees
double[][] arrayedResampledConditioningEmbeddingsFromSamples = new double[resampledConditioningEmbeddingsFromSamples.size()][numDestPastIntervals];
for (int i = 0; i < resampledConditioningEmbeddingsFromSamples.size(); i++) {
arrayedResampledConditioningEmbeddingsFromSamples[i] = resampledConditioningEmbeddingsFromSamples.elementAt(i);
}
KdTree resampledKdTreeConditioningAtSamples = new KdTree(arrayedResampledConditioningEmbeddingsFromSamples);
resampledKdTreeConditioningAtSamples.setNormType(normType);
Vector<double[]> conditionallyPermutedJointEmbeddingsFromSpikes = new Vector(jointEmbeddingsFromSpikes);
Vector<Integer> usedIndices = new Vector<Integer>();
for (int i = 0; i < conditionallyPermutedJointEmbeddingsFromSpikes.size(); i++) {
PriorityQueue<NeighbourNodeData> neighbours =
resampledKdTreeConditioningAtSamples.findKNearestNeighbours(kPerm,
new double[][] {conditioningEmbeddingsFromSpikes.elementAt(i)});
ArrayList<Integer> foundIndices = new ArrayList<Integer>();
for (int j = 0; j < kPerm; j++) {
foundIndices.add(neighbours.poll().sampleIndex);
}
ArrayList<Integer> prunedIndices = new ArrayList<Integer>(foundIndices);
prunedIndices.removeAll(usedIndices);
int chosenIndex = 0;
if (prunedIndices.size() > 0) {
chosenIndex = prunedIndices.get(random.nextInt(prunedIndices.size()));
} else {
chosenIndex = foundIndices.get(random.nextInt(foundIndices.size()));
}
usedIndices.add(chosenIndex);
int embeddingLength = conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i).length;
for(int j = 0; j < numSourcePastIntervals; j++) {
conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i)[embeddingLength - numSourcePastIntervals + j] =
resampledJointEmbeddingsFromSamples.elementAt(chosenIndex)[embeddingLength - numSourcePastIntervals + j];
}
}
double[][] arrayedConditionallyPermutedJointEmbeddingsFromSpikes = new double[conditionallyPermutedJointEmbeddingsFromSpikes.size()][];
for (int i = 0; i < conditionallyPermutedJointEmbeddingsFromSpikes.size(); i++) {
arrayedConditionallyPermutedJointEmbeddingsFromSpikes[i] = conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i);
}
KdTree conditionallyPermutedKdTreeJointFromSpikes = new KdTree(arrayedConditionallyPermutedJointEmbeddingsFromSpikes);
conditionallyPermutedKdTreeJointFromSpikes.setNormType(normType);
surrogateTEValues[permutationNumber] = computeAverageLocalOfObservations(conditionallyPermutedKdTreeJointFromSpikes,
conditionallyPermutedJointEmbeddingsFromSpikes);
}
return new EmpiricalMeasurementDistribution(surrogateTEValues, estimatedValue);
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
* computeLocalOfPreviousObservations()
*/
@Override
public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception {
// TODO Auto-generated method stub
return null;
}
/*
* (non-Javadoc)
*
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setDebug(
* boolean)
*/
@Override
public void setDebug(boolean debug) {
this.debug = debug;
}
/*
* (non-Javadoc)
*
* @see
* infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getLastAverage
* ()
*/
@Override
public double getLastAverage() {
// TODO Auto-generated method stub
return 0;
}
}

View File

@ -0,0 +1,8 @@
/**
*
*/
/**
* @author joseph
*
*/
package infodynamics.measures.spiking.integration;

View File

@ -0,0 +1,8 @@
/**
*
*/
/**
* @author joseph
*
*/
package infodynamics.measures.spiking;

0
java/source/infodynamics/utils/KdTree.java Executable file → Normal file
View File

View File

View File