mirror of https://github.com/jlizier/jidt
commit e747e8aa93

@ -0,0 +1,259 @@
# Argument order: network_type_name num_spikes sim_number target_index

from jpype import *
import random
import math
import os
import numpy as np
import pickle
import copy
import sys

# net_type_name is useful if you are iterating over multiple files with different network types.
# The definitions of SPIKES_FILE_NAME and OUTPUT_FILE_PREFIX below indicate how
# these command line arguments are used.
net_type_name = sys.argv[1]
num_spikes_string = sys.argv[2]
repeat_num_string = sys.argv[3]
target_index_string = sys.argv[4]

# The number of surrogates to create for each significance test of a TE value
NUM_SURROGATES_PER_TE_VAL = 100
# The p level below which the null hypothesis will be rejected
P_LEVEL = 0.05
# The number of nearest neighbours to consider in the TE estimation
KNNS = 10
# The number of random sample points laid down will be NUM_SAMPLES_MULTIPLIER * length_of_target_train
NUM_SAMPLES_MULTIPLIER = 5.0
# As above, but for the creation of surrogates
SURROGATE_NUM_SAMPLES_MULTIPLIER = 5.0
# The number of nearest neighbours to consider when using the local permutation method to create surrogates
K_PERM = 20
# The level of the noise to add to the random sample points used in creating surrogates
JITTERING_LEVEL = 2000

# Inference stops once MAX_NUM_SECOND_INTERVALS sources have had 2 or more history
# intervals added into the conditioning set
MAX_NUM_SECOND_INTERVALS = 2
# Exclude target spikes beyond this number
MAX_NUM_TARGET_SPIKES = int(num_spikes_string)
# The spikes file with the below name is expected to contain a single pickled Python list
# of numpy arrays, where each numpy array holds the spike times of one candidate neuron.
SPIKES_FILE_NAME = "spikes_LIF_" + net_type_name + "_" + repeat_num_string + ".pk"
# The ground truth file of the below name is expected to contain a single pickled Python list
# of tuples of the format (source, target), where source and target are the integer indices
# of a true connection.
GROUND_TRUTH_FILE_NAME = "connections_LIF_" + net_type_name + "_" + repeat_num_string + ".pk"
OUTPUT_FILE_PREFIX = "results/inferred_sources_target_2_" + net_type_name + "_" + num_spikes_string + "_" + repeat_num_string + "_" + target_index_string
LOG_FILE_NAME = "logs/" + net_type_name + "_" + num_spikes_string + "_" + repeat_num_string + "_" + target_index_string + ".log"

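# For reference, a minimal sketch of creating compatible input files (hypothetical
# toy data; illustrative only, not executed here):
#   toy_spikes = [np.sort(np.random.uniform(0, 1000, 500)) for _ in range(3)]  # one array per neuron
#   pickle.dump(toy_spikes, open("spikes_LIF_toy_0.pk", "wb"))
#   pickle.dump([(0, 1)], open("connections_LIF_toy_0.pk", "wb"))  # neuron 0 drives neuron 1
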
log = open(LOG_FILE_NAME, "w")
sys.stdout = log

def prepare_conditional_trains(calc_object, cond_set):
    # Collects the spike trains of the already-added sources and registers each
    # source's interval set with the calculator
    cond_trains = []
    calc_object.clearConditionalIntervals()
    if len(cond_set) > 0:
        for key in cond_set.keys():
            cond_trains.append(spikes[key])
            calc_object.appendConditionalIntervals(JArray(JInt, 1)(cond_set[key]))
    return cond_trains

def set_target_embeddings(embedding_list):
    if len(embedding_list) > 0:
        embedding_string = str(embedding_list[0])
        for i in range(1, len(embedding_list)):
            embedding_string += "," + str(embedding_list[i])
        teCalc.setProperty("DEST_PAST_INTERVALS", embedding_string)
    else:
        teCalc.setProperty("DEST_PAST_INTERVALS", "")

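# e.g. set_target_embeddings([1, 2, 3]) sets DEST_PAST_INTERVALS to "1,2,3",
# while set_target_embeddings([]) clears the target embedding.
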
target_index = int(target_index_string)
print("\n****** Network inference for target neuron", target_index, "******\n\n")


# Setup JIDT
jarLocation = os.path.join(os.getcwd(), "../jidt/infodynamics.jar")
if not os.path.isfile(jarLocation):
    exit("infodynamics.jar not found (expected at " + os.path.abspath(jarLocation) + ") - are you running from demos/python?")
startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntropyCalculatorSpikingIntegration
teCalc = teCalcClass()
teCalc.setProperty("knns", str(KNNS))
teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", str(NUM_SAMPLES_MULTIPLIER))
teCalc.setProperty("SURROGATE_NUM_SAMPLES_MULTIPLIER", str(SURROGATE_NUM_SAMPLES_MULTIPLIER))
teCalc.setProperty("K_PERM", str(K_PERM))
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", str(JITTERING_LEVEL))

# Load spikes and ground truth connectivity
spikes = pickle.load(open(SPIKES_FILE_NAME, 'rb'))
cons = pickle.load(open(GROUND_TRUTH_FILE_NAME, 'rb'))
if MAX_NUM_TARGET_SPIKES < len(spikes[target_index]):
    spikes[target_index] = spikes[target_index][:MAX_NUM_TARGET_SPIKES]
print("Number of target spikes: ", len(spikes[target_index]), "\n\n")


# First determine the correct target embedding
target_embedding_set = [1]
next_target_interval = 2
still_significant = True
print("**** Determining target embedding set ****\n")
while still_significant:
    set_target_embeddings(target_embedding_set)
    # The target is used as its own "source" here: the candidate interval is
    # significant if it adds predictive information about the target beyond the
    # intervals already in the embedding
    teCalc.setProperty("SOURCE_PAST_INTERVALS", str(next_target_interval))
    teCalc.startAddObservations()
    teCalc.addObservations(JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 1)(spikes[target_index]))
    teCalc.finaliseAddObservations()
    TE = teCalc.computeAverageLocalOfObservations()
    sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE)
    print("candidate interval:", next_target_interval, " TE:", TE, " p val:", sig.pValue)
    if sig.pValue > P_LEVEL:
        print("Lost significance, end of target embedding determination")
        still_significant = False
    else:
        target_embedding_set.append(next_target_interval)
        next_target_interval += 1
print("target embedding set:", target_embedding_set, "\n\n")


# Now add the sources
# cond_set is a dictionary where each key is an added source and each value is the
# list of included intervals for that source
cond_set = dict()
# next_interval_for_each_candidate will be a matrix with two columns:
# the first column holds the source indices, the second the next interval that will be considered
next_interval_for_each_candidate = np.arange(0, len(spikes), dtype = np.intc)
next_interval_for_each_candidate = next_interval_for_each_candidate[next_interval_for_each_candidate != target_index]
next_interval_for_each_candidate = np.column_stack((next_interval_for_each_candidate, np.ones(len(next_interval_for_each_candidate), dtype = np.intc)))
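# e.g. with 4 neurons and target_index == 2, next_interval_for_each_candidate
# starts as (illustrative):
#   [[0, 1],
#    [1, 1],
#    [3, 1]]
# i.e. interval 1 of each candidate source is the first to be tested.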
still_significant = True
TE_vals_at_each_round = []
surrogate_vals_at_each_round = []
print("**** Adding Sources ****\n")
num_twos = 0
while still_significant:
    print("Current conditioning set:")
    for key in cond_set.keys():
        print("source", key, "intervals", cond_set[key])
    print("\nEstimating TE on candidate sources")
    cond_trains = prepare_conditional_trains(teCalc, cond_set)
    TE_vals = np.zeros(next_interval_for_each_candidate.shape[0])
    debiased_TE_vals = -1 * np.ones(next_interval_for_each_candidate.shape[0])
    surrogate_vals = -1 * np.ones((next_interval_for_each_candidate.shape[0], NUM_SURROGATES_PER_TE_VAL))
    debiased_surrogate_vals = -1 * np.ones((next_interval_for_each_candidate.shape[0], NUM_SURROGATES_PER_TE_VAL))
    is_con = np.zeros(next_interval_for_each_candidate.shape[0])
    for i in range(next_interval_for_each_candidate.shape[0]):
        # Skip candidates with too few spikes for a meaningful estimate
        if len(spikes[next_interval_for_each_candidate[i, 0]]) < 10:
            continue
        teCalc.startAddObservations()
        teCalc.setProperty("SOURCE_PAST_INTERVALS", str(next_interval_for_each_candidate[i, 1]))
        if len(cond_set) > 0:
            teCalc.addObservations(JArray(JDouble, 1)(spikes[next_interval_for_each_candidate[i, 0]]),
                                   JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 2)(cond_trains))
        else:
            teCalc.addObservations(JArray(JDouble, 1)(spikes[next_interval_for_each_candidate[i, 0]]),
                                   JArray(JDouble, 1)(spikes[target_index]))
        teCalc.finaliseAddObservations()
        TE_vals[i] = teCalc.computeAverageLocalOfObservations()
        is_con[i] = ([next_interval_for_each_candidate[i, 0], target_index] in cons)
        sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE_vals[i])
        surrogate_vals[i] = sig.distribution
        # Debias by subtracting the mean of the surrogate distribution
        debiased_TE_vals[i] = TE_vals[i] - np.mean(surrogate_vals[i])
        debiased_surrogate_vals[i] = sig.distribution - np.mean(surrogate_vals[i])
        print("Source", next_interval_for_each_candidate[i, 0], "Interval", next_interval_for_each_candidate[i, 1],
              " TE:", str(debiased_TE_vals[i]))
        log.flush()

    TE_vals_at_each_round.append(TE_vals)
    surrogate_vals_at_each_round.append(surrogate_vals)
    sorted_TE_indices = np.argsort(debiased_TE_vals)
    print("\nSorted order of sources:\n", next_interval_for_each_candidate[:, 0][sorted_TE_indices[:]])
    print("Ground truth for sorted order:\n", is_con[sorted_TE_indices[:]])

    # Compare the strongest candidate against the distribution of the maximum
    # surrogate value across all candidates, to control for multiple comparisons
    index_of_max_candidate = sorted_TE_indices[-1]
    samples_from_max_dist = np.max(debiased_surrogate_vals, axis = 0)
    samples_from_max_dist = np.sort(samples_from_max_dist)
    index_of_first_greater_than_estimate = np.searchsorted(samples_from_max_dist > debiased_TE_vals[index_of_max_candidate], 1)
    p_val = (NUM_SURROGATES_PER_TE_VAL - index_of_first_greater_than_estimate) / float(NUM_SURROGATES_PER_TE_VAL)
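    # Equivalent, more direct formulation of the above p value (a sanity check;
    # both count the fraction of max-statistic surrogates exceeding the estimate):
    #   p_val == np.mean(samples_from_max_dist > debiased_TE_vals[index_of_max_candidate])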
print("\nMaximum candidate is source", next_interval_for_each_candidate[index_of_max_candidate, 0],
|
||||||
|
"interval", next_interval_for_each_candidate[index_of_max_candidate, 1])
|
||||||
|
print("p: ", p_val)
|
||||||
|
if p_val <= P_LEVEL:
|
||||||
|
if (next_interval_for_each_candidate[index_of_max_candidate, 0]) in cond_set:
|
||||||
|
cond_set[next_interval_for_each_candidate[index_of_max_candidate, 0]].append(next_interval_for_each_candidate[index_of_max_candidate, 1])
|
||||||
|
else:
|
||||||
|
cond_set[next_interval_for_each_candidate[index_of_max_candidate, 0]] = [next_interval_for_each_candidate[index_of_max_candidate, 1]]
|
||||||
|
|
||||||
|
if next_interval_for_each_candidate[index_of_max_candidate, 1] == 2:
|
||||||
|
num_twos += 1
|
||||||
|
if num_twos >= MAX_NUM_SECOND_INTERVALS:
|
||||||
|
print("\nMaximum number of second intervals reached\n\n")
|
||||||
|
still_significant = False
|
||||||
|
|
||||||
|
next_interval_for_each_candidate[index_of_max_candidate, 1] += 1
|
||||||
|
|
||||||
|
print("\nCandidate added\n\n")
|
||||||
|
else:
|
||||||
|
still_significant = False
|
||||||
|
print("\nLost Significance\n\n")
|
||||||
|
|
||||||
|
print("**** Pruning Sources ****\n")
|
||||||
|
# Repeatedly removes the connection that has the lowest TE out of all insignificant connections.
|
||||||
|
# Only considers the furthest intervals as candidates in each round.
|
||||||
|
everything_significant = False
|
||||||
|
while not everything_significant:
|
||||||
|
print("Current conditioning set:")
|
||||||
|
for key in cond_set.keys():
|
||||||
|
print("source", key, "intervals", cond_set[key])
|
||||||
|
print("\nEstimating TE on candidate sources")
|
||||||
|
everything_significant = True
|
||||||
|
insignificant_sources = []
|
||||||
|
insignificant_sources_TE = []
|
||||||
|
for candidate_source in cond_set:
|
||||||
|
cond_set_minus_candidate = copy.deepcopy(cond_set)
|
||||||
|
# If more than one interval, remove the last
|
||||||
|
if len(cond_set_minus_candidate[candidate_source]) > 1:
|
||||||
|
cond_set_minus_candidate[candidate_source] = cond_set_minus_candidate[candidate_source][:-1]
|
||||||
|
# Otherwise, remove source from dict
|
||||||
|
else:
|
||||||
|
cond_set_minus_candidate.pop(candidate_source)
|
||||||
|
teCalc.setProperty("SOURCE_PAST_INTERVALS", str(cond_set[candidate_source][-1]))
|
||||||
|
cond_trains = prepare_conditional_trains(teCalc, cond_set_minus_candidate)
|
||||||
|
teCalc.startAddObservations()
|
||||||
|
if len(cond_set_minus_candidate) > 0:
|
||||||
|
teCalc.addObservations(JArray(JDouble, 1)(spikes[candidate_source]), JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 2)(cond_trains))
|
||||||
|
else:
|
||||||
|
teCalc.addObservations(JArray(JDouble, 1)(spikes[candidate_source]), JArray(JDouble, 1)(spikes[target_index]))
|
||||||
|
teCalc.finaliseAddObservations();
|
||||||
|
TE = teCalc.computeAverageLocalOfObservations()
|
||||||
|
sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE)
|
||||||
|
print("Source", candidate_source, "Interval", cond_set[candidate_source][-1],
|
||||||
|
" TE:", str(round(TE, 2)), " p val:", sig.pValue)
|
||||||
|
if sig.pValue > P_LEVEL:
|
||||||
|
everything_significant = False
|
||||||
|
insignificant_sources.append(candidate_source)
|
||||||
|
insignificant_sources_TE.append(TE)
|
||||||
|
if not everything_significant:
|
||||||
|
min_TE_source = insignificant_sources[np.argmin(insignificant_sources_TE)]
|
||||||
|
print("removing source", min_TE_source, "interval", cond_set[min_TE_source][-1])
|
||||||
|
if len(cond_set[min_TE_source]) > 1:
|
||||||
|
cond_set[min_TE_source] = cond_set[min_TE_source][:-1]
|
||||||
|
else:
|
||||||
|
cond_set.pop(min_TE_source)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print("\n\n****** Final Inferred Source Set ******\n")
|
||||||
|
for key in cond_set.keys():
|
||||||
|
print("source", key, "intervals", cond_set[key])
|
||||||
|
print("\nTrue Sources:")
|
||||||
|
for con in cons:
|
||||||
|
if con[1] == target_index:
|
||||||
|
print(con[0], " ",)
|
||||||
|
|
||||||
|
|
||||||
|
output_file = open(OUTPUT_FILE_PREFIX + ".pk", 'wb')
|
||||||
|
pickle.dump(cond_set, output_file)
|
||||||
|
#pickle.dump(surrogate_vals_at_each_round, output_file)
|
||||||
|
#pickle.dump(TE_vals_at_each_round, output_file)
|
||||||
|
output_file.close()
|
|

@ -0,0 +1,46 @@
# This script converts CSV files of spike times (e.g. from the Wagenaar data set) into
# pickle files of spike times in the format that the net_inf.py script expects

import numpy as np
import pickle
import sys
import ast
import matplotlib.pyplot as plt

RUN = "1-1-20.2"

spk_file = open('extracted_data_wagenaar/1-1/' + RUN + '.spk', 'r')
# Keep the first 8 hours of the recording; spike times are in samples,
# at the 25 kHz sample rate of these recordings (8 * 60 * 60 seconds * 2.5e4 samples/s)
time_upper = 8 * 60 * 60 * 2.5e4

# Read one comma-separated spike train per line
spikes = []
for line in spk_file:
    line = line.strip()
    line = line.split(",")
    line = [float(time) for time in line if time != ""]
    spikes.append(np.array(line))

start_times = [train[0] for train in spikes if len(train) > 0]
lowest_start_time = min(start_times)
cutoff_time = lowest_start_time + time_upper

for i in range(len(spikes)):
    spikes[i] = spikes[i][spikes[i] < cutoff_time]
    spikes[i] = spikes[i] - lowest_start_time
    # Dither by uniform noise on [-0.5, 0.5) to break ties between spikes recorded
    # at the same sample, then restore temporal order
    spikes[i] = spikes[i] + np.random.uniform(size = spikes[i].shape) - 0.5
    spikes[i] = np.sort(spikes[i])

print(len(spikes))
for i in range(len(spikes)):
    print(spikes[i].shape)
    print(spikes[i][:10])

#plt.eventplot(spikes, linewidth = 0.5)
#plt.show()
spikes_file = open("spikes_LIF_" + RUN + "_" + sys.argv[1] + ".pk", "wb")
pickle.dump(spikes, spikes_file)
# Dummy ground-truth connections: the true connectivity of recorded data is
# unknown, but the inference script expects this file to exist
cons = [[0, 0]]
connections_file = open("connections_LIF_" + RUN + "_" + sys.argv[1] + ".pk", "wb")
pickle.dump(cons, connections_file)
spikes_file.close()
connections_file.close()

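# The files produced above can then be analysed with the inference script, e.g.
# (an illustrative invocation, following the argument order documented in
# net_inf.py: network_type_name num_spikes sim_number target_index):
#   python net_inf.py 1-1-20.2 2000 <sim_number> <target_index>
# where <sim_number> matches the sys.argv[1] value used when running this script.
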

@ -0,0 +1,192 @@
##
##  Java Information Dynamics Toolkit (JIDT)
##  Copyright (C) 2022, David P. Shorten, Joseph T. Lizier
##
##  This program is free software: you can redistribute it and/or modify
##  it under the terms of the GNU General Public License as published by
##  the Free Software Foundation, either version 3 of the License, or
##  (at your option) any later version.
##
##  This program is distributed in the hope that it will be useful,
##  but WITHOUT ANY WARRANTY; without even the implied warranty of
##  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##  GNU General Public License for more details.
##
##  You should have received a copy of the GNU General Public License
##  along with this program.  If not, see <http://www.gnu.org/licenses/>.
##

# Transfer entropy (TE) calculation on generated spike train data using the continuous-time TE estimator.

from jpype import *
import random
import math
import os
import numpy as np


NUM_REPS = 2
NUM_SPIKES = int(2e3)
NUM_OBSERVATIONS = 2
NUM_SURROGATES = 10

# Params for canonical example generation
RATE_Y = 1.0
RATE_X_MAX = 10

def generate_canonical_example_processes(num_y_events):
    event_train_x = []
    event_train_x.append(0)

    event_train_y = np.random.uniform(0, int(num_y_events / RATE_Y), int(num_y_events))
    event_train_y.sort()

    # Candidate X events are laid down as a Poisson process at the maximum rate
    # RATE_X_MAX and accepted with probability rate / RATE_X_MAX (thinning), where
    # the rate is modulated by the time since the most recent Y event
    most_recent_y_index = 0
    previous_x_candidate = 0
    while most_recent_y_index < (len(event_train_y) - 1):

        this_x_candidate = previous_x_candidate + random.expovariate(RATE_X_MAX)

        # Advance to the most recent Y event preceding this candidate
        while most_recent_y_index < (len(event_train_y) - 1) and this_x_candidate > event_train_y[most_recent_y_index + 1]:
            most_recent_y_index += 1

        delta_t = this_x_candidate - event_train_y[most_recent_y_index]

        # Gaussian bump in the rate, centred 0.5 time units after the last Y event
        if delta_t > 1:
            rate = 0.5
        else:
            rate = 0.5 + 5.0 * math.exp(-50 * (delta_t - 0.5)**2) - 5.0 * math.exp(-50 * (0.5)**2)
        if random.random() < rate / float(RATE_X_MAX):
            event_train_x.append(this_x_candidate)
        previous_x_candidate = this_x_candidate

    event_train_x.sort()
    event_train_y.sort()

    return event_train_x, event_train_y

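# Illustrative usage of the generator (a simple sanity check, not part of the
# experiments below): the empirical rate of Y should be close to RATE_Y.
#   x, y = generate_canonical_example_processes(1000)
#   print(len(y) / y[-1])
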
# Change location of jar to match yours (we assume script is called from demos/python):
jarLocation = os.path.join(os.getcwd(), "infodynamics.jar")
if not os.path.isfile(jarLocation):
    exit("infodynamics.jar not found (expected at " + os.path.abspath(jarLocation) + ") - are you running from demos/python?")
# Start the JVM (add the "-Xmx" option with say 1024M if you get crashes due to not enough memory space)
startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntropyCalculatorSpikingIntegration


teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Independent Poisson Processes")
teCalc.setProperty("DEST_PAST_INTERVALS", "1,2")
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1,2")
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
# Two conditioning processes, each embedded with intervals 1 and 2
teCalc.appendConditionalIntervals(JArray(JInt, 1)([1, 2]))
teCalc.appendConditionalIntervals(JArray(JInt, 1)([1, 2]))
teCalc.setProperty("NORM_TYPE", "MAX_NORM")

results_poisson = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
    teCalc.startAddObservations()
    for j in range(NUM_OBSERVATIONS):
        sourceArray = NUM_SPIKES * np.random.random(NUM_SPIKES)
        sourceArray.sort()
        destArray = NUM_SPIKES * np.random.random(NUM_SPIKES)
        destArray.sort()
        condArray = NUM_SPIKES * np.random.random((2, NUM_SPIKES))
        condArray.sort(axis = 1)
        teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
    teCalc.finaliseAddObservations()
    result = teCalc.computeAverageLocalOfObservations()
    print("TE result %.4f nats" % (result,))
    sig = teCalc.computeSignificance(NUM_SURROGATES, result)
    print(sig.pValue)
    results_poisson[i] = result
print("Summary: mean ", np.mean(results_poisson), " std dev ", np.std(results_poisson))

teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Noisy copy zero TE")
#teCalc.appendConditionalIntervals(JArray(JInt, 1)([1]))
teCalc.setProperty("DEST_PAST_INTERVALS", "1")
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1")
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")

results_noisy_zero = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
    teCalc.startAddObservations()
    for j in range(NUM_OBSERVATIONS):
        # Source and dest are noisy, time-shifted copies of the same regular train
        condArray = np.ones((1, NUM_SPIKES)) + 0.05 * np.random.random((1, NUM_SPIKES))
        condArray = np.cumsum(condArray, axis = 1)
        condArray.sort(axis = 1)
        sourceArray = condArray[0, :] + 0.25 + 0.05 * np.random.normal(size = condArray.shape[1])
        sourceArray.sort()
        destArray = condArray[0, :] + 0.5 + 0.05 * np.random.normal(size = condArray.shape[1])
        destArray.sort()
        #teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
        teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray))
    teCalc.finaliseAddObservations()
    result = teCalc.computeAverageLocalOfObservations()
    print("TE result %.4f nats" % (result,))
    sig = teCalc.computeSignificance(NUM_SURROGATES, result)
    print(sig.pValue)
    results_noisy_zero[i] = result
print("Summary: mean ", np.mean(results_noisy_zero), " std dev ", np.std(results_noisy_zero))


teCalc = teCalcClass()
teCalc.setProperty("knns", "4")
print("Noisy copy non-zero TE")
teCalc.appendConditionalIntervals(JArray(JInt, 1)([1]))
teCalc.setProperty("DEST_PAST_INTERVALS", "1,2")
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1")
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")

results_noisy_non_zero = np.zeros(NUM_REPS)
for i in range(NUM_REPS):
    teCalc.startAddObservations()
    for j in range(NUM_OBSERVATIONS):
        sourceArray = np.ones((1, NUM_SPIKES)) + 0.05 * np.random.random((1, NUM_SPIKES))
        sourceArray = np.cumsum(sourceArray)
        sourceArray.sort()
        condArray = sourceArray + 0.25 + 0.05 * np.random.normal(size = sourceArray.shape)
        condArray.sort()
        condArray = np.expand_dims(condArray, 0)
        destArray = sourceArray + 0.5 + 0.05 * np.random.normal(size = sourceArray.shape)
        destArray.sort()
        teCalc.addObservations(JArray(JDouble, 1)(sourceArray), JArray(JDouble, 1)(destArray), JArray(JDouble, 2)(condArray))
    teCalc.finaliseAddObservations()
    result = teCalc.computeAverageLocalOfObservations()
    print("TE result %.4f nats" % (result,))
    sig = teCalc.computeSignificance(NUM_SURROGATES, result)
    print(sig.pValue)
    results_noisy_non_zero[i] = result
print("Summary: mean ", np.mean(results_noisy_non_zero), " std dev ", np.std(results_noisy_non_zero))

print("Canonical example")
|
||||||
|
teCalc = teCalcClass()
|
||||||
|
teCalc.setProperty("knns", "4")
|
||||||
|
teCalc.setProperty("DEST_PAST_INTERVALS", "1,2")
|
||||||
|
teCalc.setProperty("SOURCE_PAST_INTERVALS", "1")
|
||||||
|
teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
|
||||||
|
teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", "0")
|
||||||
|
#teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", "1")
|
||||||
|
#teCalc.setProperty("NORM_TYPE", "MAX_NORM")
|
||||||
|
|
||||||
|
results_canonical = np.zeros(NUM_REPS)
|
||||||
|
for i in range(NUM_REPS):
|
||||||
|
event_train_x, event_train_y = generate_canonical_example_processes(NUM_SPIKES)
|
||||||
|
teCalc.setObservations(JArray(JDouble, 1)(event_train_y), JArray(JDouble, 1)(event_train_x))
|
||||||
|
result = teCalc.computeAverageLocalOfObservations()
|
||||||
|
results_canonical[i] = result
|
||||||
|
print("TE result %.4f nats" % (result,))
|
||||||
|
sig = teCalc.computeSignificance(NUM_SURROGATES, result)
|
||||||
|
print(sig.pValue)
|
||||||
|
print("Summary: mean ", np.mean(results_canonical), " std dev ", np.std(results_canonical))
|
|

@ -0,0 +1,305 @@
/*
 * Java Information Dynamics Toolkit (JIDT)
 * Copyright (C) 2012, Joseph T. Lizier
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */

package infodynamics.measures.spiking;

import infodynamics.utils.EmpiricalMeasurementDistribution;

/**
 * <p>Interface for implementations of the <b>transfer entropy</b> (TE),
 * which may be applied to spiking time-series data.
 * That is, it is applied to <code>double[]</code> data, as an array
 * of time stamps at which spikes were recorded.
 * See Schreiber below for the definition of transfer entropy,
 * Lizier et al. for the definition of local transfer entropy,
 * and (To be published) for how to measure transfer entropy on spike trains.
 * Specifically, this class implements the pairwise or <i>apparent</i>
 * transfer entropy; i.e. we compute the transfer that appears to
 * come from a single source variable, without examining any other
 * potential sources
 * (see Lizier et al, PRE, 2008).</p>
 *
 * <p>
 * Usage of the child classes implementing this interface is intended to follow this paradigm:
 * </p>
 * <ol>
 * <li>Construct the calculator;</li>
 * <li>Set properties using {@link #setProperty(String, String)},
 * e.g. including properties describing
 * the source and destination embedding;</li>
 * <li>Initialise the calculator using
 * {@link #initialise()} or {@link #initialise(int, int)};</li>
 * <li>Provide the observations/samples for the calculator
 * to set up the PDFs, using:
 * <ul>
 * <li>{@link #setObservations(double[], double[])}
 * for calculations based on single recordings, OR</li>
 * <li>The following sequence:<ol>
 * <li>{@link #startAddObservations()}, then</li>
 * <li>One or more calls to
 * {@link #addObservations(double[], double[])}, then</li>
 * <li>{@link #finaliseAddObservations()};</li>
 * </ol></li>
 * </ul></li>
 * <li>Compute the required quantities, being one or more of:
 * <ul>
 * <li>the average TE: {@link #computeAverageLocalOfObservations()};</li>
 * <li>the local TE values for these samples: {@link #computeLocalOfPreviousObservations()};</li>
 * <li>the distribution of TE values under the null hypothesis
 * of no relationship between source and
 * destination values: {@link #computeSignificance(int, double)} or
 * {@link #computeSignificance(int, double, long)}.</li>
 * </ul>
 * </li>
 * <li>
 * Return to step 2 or 3 to re-use the calculator on a new data set.
 * </li>
 * </ol>
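 *
 * <p>A minimal usage sketch in code, assuming the
 * {@link infodynamics.measures.spiking.integration.TransferEntropyCalculatorSpikingIntegration}
 * child class and hypothetical data arrays:</p>
 * <pre>
 * TransferEntropyCalculatorSpiking teCalc = new TransferEntropyCalculatorSpikingIntegration();
 * teCalc.setProperty("knns", "4");
 * teCalc.startAddObservations();
 * teCalc.addObservations(sourceSpikeTimes, destSpikeTimes); // double[] arrays of spike time stamps
 * teCalc.finaliseAddObservations();
 * double te = teCalc.computeAverageLocalOfObservations();
 * </pre>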
 *
 * <p><b>References:</b><br/>
 * <ul>
 * <li>T. Schreiber, <a href="http://dx.doi.org/10.1103/PhysRevLett.85.461">
 * "Measuring information transfer"</a>,
 * Physical Review Letters 85 (2) pp.461-464, 2000.</li>
 * <li>J. T. Lizier, M. Prokopenko and A. Zomaya,
 * <a href="http://dx.doi.org/10.1103/PhysRevE.77.026110">
 * "Local information transfer as a spatiotemporal filter for complex systems"</a>,
 * Physical Review E 77, 026110, 2008.</li>
 * <li>To be published</li>
 * </ul>
 *
 * @author Joseph Lizier (<a href="joseph.lizier at gmail.com">email</a>,
 * <a href="http://lizier.me/joseph/">www</a>)
 */
public interface TransferEntropyCalculatorSpiking {

    /**
     * Property name to specify the destination history embedding length k
     * (default value 1)
     */
    public static final String K_PROP_NAME = "k_HISTORY";
    /**
     * Property name for embedding length for the source past history vector
     * (default value 1)
     */
    public static final String L_PROP_NAME = "l_HISTORY";
    /* Could try to do this one later. I think we would just consider the history
     * of source as being up to this many units of time behind destination and
     * no later.
     *
     * Property name for source-destination delay (default value is 0)
     *
    public static final String DELAY_PROP_NAME = "DELAY";
    */
    /**
     * Property name for whether each series of time stamps of spikes is sorted
     * into temporal order. (default true)
     */
    public static final String TIMESORTED_PROP_NAME = "TIME_SORTED";

    /**
     * Initialise the calculator for re-use with new observations.
     * All parameters remain unchanged.
     *
     * @throws Exception
     */
    public void initialise() throws Exception;

    /**
     * Initialise the calculator for re-use with new observations.
     * A new history embedding length k can be supplied here; all other parameters
     * remain unchanged.
     *
     * @param k destination history embedding length to be considered.
     * @throws Exception
     */
    public void initialise(int k) throws Exception;

    /**
     * Initialise the calculator for re-use with new observations.
     * New history embedding lengths k and l can be supplied here; all other parameters
     * remain unchanged.
     *
     * @param k destination history embedding length to be considered.
     * @param l source history embedding length to be considered.
     * @throws Exception
     */
    public void initialise(int k, int l) throws Exception;

    /**
     * Sets properties for the TE calculator.
     * New property values are not guaranteed to take effect until the next call
     * to an initialise method.
     *
     * <p>Valid property names, and what their
     * values should represent, include:</p>
     * <ul>
     * <li>{@link #K_PROP_NAME} -- destination history embedding length k
     * (default value 1)</li>
     * <li>{@link #L_PROP_NAME} -- embedding length for the source past history vector
     * (default value 1)</li>
     * <li>{@link #TIMESORTED_PROP_NAME} -- whether each series of time stamps of spikes is sorted
     * into temporal order. (default "true")</li>
     * </ul>
     * <p><b>Note:</b> further properties may be defined by child classes.</p>
     *
     * <p>Unknown property values are ignored.</p>
     *
     * @param propertyName name of the property
     * @param propertyValue value of the property.
     * @throws Exception if there is a problem with the supplied value,
     * or if the property is recognised but unsupported (e.g. some
     * calculators do not support all of the embedding properties).
     */
    public void setProperty(String propertyName, String propertyValue) throws Exception;

    /**
     * Get current property values for the calculator.
     *
     * <p>Valid property names, and what their
     * values should represent, are the same as those for
     * {@link #setProperty(String, String)}</p>
     *
     * <p>Unknown property values are responded to with a null return value.</p>
     *
     * @param propertyName name of the property
     * @return current value of the property
     * @throws Exception for invalid property values
     */
    public String getProperty(String propertyName) throws Exception;

    /**
     * Sets a single set of spiking observations from which to compute the PDF for transfer entropy.
     * Cannot be called in conjunction with other methods for setting/adding
     * observations.
     *
     * @param source series of time stamps of spikes for the source variable.
     * @param destination series of time stamps of spikes for the destination
     *  variable. Length will generally be different to that of the <code>source</code>,
     *  unlike other transfer entropy implementations, e.g. for {@link infodynamics.measures.continuous.TransferEntropyCalculator}.
     *  <code>source</code> and <code>destination</code> must have the same reference time point,
     *  and each array is assumed to be sorted into temporal order unless
     *  the property {@link #TIMESORTED_PROP_NAME} has been set to false.
     * @throws Exception
     */
    public void setObservations(double source[], double destination[]) throws Exception;

    /**
     * Signal that we will add in the samples for computing the PDF
     * from several disjoint time-series or trials via calls to
     * {@link #addObservations(double[], double[])} rather than
     * {@link #setObservations(double[], double[])} type methods
     * (defined by the child interfaces and classes).
     */
    public void startAddObservations();

    /**
     * <p>Adds a new set of spiking observations to update the PDFs with.
     * It is intended to be called multiple times, and must
     * be called after {@link #startAddObservations()}. Call
     * {@link #finaliseAddObservations()} once all observations have
     * been supplied.</p>
     *
     * <p><b>Important:</b> this does not append or overlay these observations to the previously
     * supplied observations, but treats them as independent trials - i.e. measurements
     * such as the transfer entropy will not join them up to examine k
     * consecutive values in time.</p>
     *
     * <p>Note that the arrays source and destination must not be over-written by the user
     * until after {@link #finaliseAddObservations()} has been called
     * (they are not necessarily copied by this method; it may simply hold a
     * reference to them).</p>
     *
     * @param source series of time stamps of spikes for the source variable.
     *  Will be returned in ascending sorted order.
     * @param destination series of time stamps of spikes for the destination
     *  variable. Length will generally be different to that of the <code>source</code>,
     *  unlike other transfer entropy implementations, e.g. for {@link infodynamics.measures.continuous.TransferEntropyCalculator}.
     *  <code>source</code> and <code>destination</code> must have the same reference time point,
     *  and each array is assumed to be sorted into temporal order unless
     *  the property {@link #TIMESORTED_PROP_NAME} has been set to false.
     *  Will be returned in ascending sorted order.
     * @throws Exception
     */
    public void addObservations(double[] source, double[] destination) throws Exception;

    /**
     * Signal that the observations are now all added via
     * {@link #addObservations(double[], double[])}; PDFs can now be constructed.
     *
     * @throws Exception
     */
    public void finaliseAddObservations() throws Exception;

    /**
     * Query whether the user has added more than a single observation set via the
     * {@link #startAddObservations()}, "addObservations" (defined by child interfaces
     * and classes), {@link #finaliseAddObservations()} sequence.
     *
     * @return true if more than a single observation set was supplied
     */
    public boolean getAddedMoreThanOneObservationSet();

    /**
     * Compute the TE from the previously-supplied samples.
     *
     * @return the estimate of the channel measure
     */
    public double computeAverageLocalOfObservations() throws Exception;

    /**
     * This interface serves to indicate the return type of {@link TransferEntropyCalculator#computeLocalOfPreviousObservations()},
     * as each child implementation will return something specific.
     *
     * @author Joseph Lizier
     */
    public interface SpikingLocalInformationValues {
        // Left empty intentionally
    }

    /**
     * @return an object containing a representation of
     *  the local TE values. The precise contents of this representation
     *  will vary depending on the underlying implementation
     */
    public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception;

    /**
     * Compute the distribution of TE values obtained under the null hypothesis
     * of no relationship between source and destination, via surrogate
     * recomputations of the measure.
     *
     * @param numPermutationsToCheck number of surrogates to compute
     * @param estimatedValue the TE value estimated on the original observations
     * @return the distribution of surrogate TE measurements
     * @throws Exception
     */
    public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue) throws Exception;

    /**
     * As {@link #computeSignificance(int, double)}, but with a supplied random
     * seed for reproducible surrogate generation.
     *
     * @throws Exception
     */
    public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck,
            double estimatedValue, long randomSeed) throws Exception;

    /**
     * Set or clear debug mode for extra debug printing to stdout
     *
     * @param debug new setting for debug mode (on/off)
     */
    public void setDebug(boolean debug);

    /**
     * Return the TE last calculated in a call to {@link #computeAverageLocalOfObservations()}
     * or {@link #computeLocalOfPreviousObservations()} after the previous
     * {@link #initialise()} call.
     *
     * @return the last computed channel measure value
     */
    public double getLastAverage();
}

@ -0,0 +1,965 @@
package infodynamics.measures.spiking.integration;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Vector;
import java.util.ArrayList;
import java.lang.Math;

//import infodynamics.measures.continuous.kraskov.EuclideanUtils;
import infodynamics.measures.spiking.TransferEntropyCalculatorSpiking;
import infodynamics.utils.EmpiricalMeasurementDistribution;
import infodynamics.utils.KdTree;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import infodynamics.utils.NeighbourNodeData;
import infodynamics.utils.FirstIndexComparatorDouble;
import infodynamics.utils.UnivariateNearestNeighbourSearcher;
import infodynamics.utils.EuclideanUtils;
import infodynamics.utils.ParsedProperties;

/**
 * Computes the transfer entropy between a pair of spike trains, using an
 * integration-based measure in order to match the theoretical form of TE
 * between such spike trains.
 *
 * <p>
 * Usage paradigm is as per the interface
 * {@link TransferEntropyCalculatorSpiking}
 * </p>
 *
 * @author Joseph Lizier (<a href="joseph.lizier at gmail.com">email</a>,
 *         <a href="http://lizier.me/joseph/">www</a>)
 */

/*
 * TODO
 * This implementation of the estimator does not implement dynamic exclusion windows. Such windows make sure
 * that history embeddings that overlap are not considered in nearest-neighbour searches (as this breaks the
 * independence assumption). Getting dynamic exclusion windows working will probably require modifications to the
 * KdTree class.
 */
public class TransferEntropyCalculatorSpikingIntegration implements TransferEntropyCalculatorSpiking {

    /**
     * The past destination interspike intervals to consider
     * (and associated property name and convenience length variable).
     * The code assumes that the first interval is numbered 1, the next is numbered 2, etc.
     * It also assumes that the intervals are sorted. The setter method performs sorting to ensure this.
     */
    public static final String DEST_PAST_INTERVALS_PROP_NAME = "DEST_PAST_INTERVALS";
    protected int[] destPastIntervals = new int[] {};
    protected int numDestPastIntervals = 0;
    /**
     * As above, but for the sources
     */
    public static final String SOURCE_PAST_INTERVALS_PROP_NAME = "SOURCE_PAST_INTERVALS";
    protected int[] sourcePastIntervals = new int[] {};
    protected int numSourcePastIntervals = 0;

    /**
     * As above, but for the conditioning processes. There is no property name for this variable,
     * as there is currently no method for converting strings to 2d arrays. Instead,
     * a separate setter method is implemented.
     */
    protected Vector<int[]> vectorOfCondPastIntervals = new Vector<int[]>();
    protected int numCondPastIntervals = 0;

    /**
     * Number of nearest neighbours to search for in the full joint space
     */
    protected int Knns = 4;

    /**
     * Storage for source observations supplied via
     * {@link #addObservations(double[], double[])} etc.
     */
    protected Vector<double[]> vectorOfSourceSpikeTimes = null;

    /**
     * Storage for destination observations supplied via
     * {@link #addObservations(double[], double[])} etc.
     */
    protected Vector<double[]> vectorOfDestinationSpikeTimes = null;

    /**
     * Storage for conditional observations supplied via
     * {@link #addObservations(double[], double[])} etc.
     */
    protected Vector<double[][]> vectorOfConditionalSpikeTimes = null;

    Vector<double[]> conditioningEmbeddingsFromSpikes = null;
    Vector<double[]> jointEmbeddingsFromSpikes = null;
    Vector<double[]> conditioningEmbeddingsFromSamples = null;
    Vector<double[]> jointEmbeddingsFromSamples = null;
    Vector<Double> processTimeLengths = null;

    protected KdTree kdTreeJointAtSpikes = null;
    protected KdTree kdTreeJointAtSamples = null;
    protected KdTree kdTreeConditioningAtSpikes = null;
    protected KdTree kdTreeConditioningAtSamples = null;

    public static final String KNNS_PROP_NAME = "Knns";

    /**
     * Property name for an amount of random Gaussian noise to be added to the data
     * (default is 1e-8, matching the MILCA toolkit).
     */
    public static final String PROP_ADD_NOISE = "NOISE_LEVEL_TO_ADD";

    /**
     * Whether to add an amount of random noise to the incoming data
     */
    protected boolean addNoise = true;
    /**
     * Amount of random Gaussian noise to add to the incoming data
     */
    protected double noiseLevel = (double) 1e-8;

    /**
     * Whether to use the jittered sampling approach. Useful for bursty spike trains.
     * Explained in the methods section of doi.org/10.1101/2021.06.29.450432
     */
    protected boolean jitteredSamplesForSurrogates = false;
    public static final String DO_JITTERED_SAMPLING_PROP_NAME = "DO_JITTERED_SAMPLING";
    protected double jitteredSamplingNoiseLevel = 1;
    public static final String JITTERED_SAMPLING_NOISE_LEVEL = "JITTERED_SAMPLING_NOISE_LEVEL";

    /**
     * Stores whether we are in debug mode
     */
    protected boolean debug = false;

    /**
     * Property name for the number of random sample points to use, as a multiple
     * of the number of target spikes.
     */
    public static final String PROP_SAMPLE_MULTIPLIER = "NUM_SAMPLES_MULTIPLIER";
    protected double numSamplesMultiplier = 2.0;
    /**
     * Property name for the number of random sample points to use in the construction
     * of the surrogates, as a multiple of the number of target spikes.
     */
    public static final String PROP_SURROGATE_SAMPLE_MULTIPLIER = "SURROGATE_NUM_SAMPLES_MULTIPLIER";
    protected double surrogateNumSamplesMultiplier = 2.0;

    /**
     * Property for the number of nearest neighbours to use in the construction of the surrogates
     */
    public static final String PROP_K_PERM = "K_PERM";
    protected int kPerm = 10;


    /**
     * Property name for what type of norm to use between data points
     * for each marginal variable -- Options are defined by
     * {@link KdTree#setNormType(String)} and the
     * default is {@link EuclideanUtils#NORM_EUCLIDEAN}.
     */
    public final static String PROP_NORM_TYPE = "NORM_TYPE";
    protected int normType = EuclideanUtils.NORM_EUCLIDEAN;

    public TransferEntropyCalculatorSpikingIntegration() {
        super();
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise()
     */
    @Override
    public void initialise() throws Exception {
        initialise(0, 0);
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
     * int)
     */
    @Override
    public void initialise(int k) throws Exception {
        initialise(0, 0);
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#initialise(
     * int, int)
     */
    @Override
    public void initialise(int k, int l) throws Exception {
        vectorOfSourceSpikeTimes = null;
        vectorOfDestinationSpikeTimes = null;
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setProperty(
     * java.lang.String, java.lang.String)
     */
    @Override
    public void setProperty(String propertyName, String propertyValue) throws Exception {
        boolean propertySet = true;
        if (propertyName.equalsIgnoreCase(DEST_PAST_INTERVALS_PROP_NAME)) {
            if (propertyValue.length() == 0) {
                destPastIntervals = new int[] {};
            } else {
                int[] destPastIntervalsTemp = ParsedProperties.parseStringArrayOfInts(propertyValue);
                for (int interval : destPastIntervalsTemp) {
                    if (interval < 1) {
                        throw new Exception("Invalid interval number less than 1.");
                    }
                }
                destPastIntervals = destPastIntervalsTemp;
                Arrays.sort(destPastIntervals);
            }
        } else if (propertyName.equalsIgnoreCase(SOURCE_PAST_INTERVALS_PROP_NAME)) {
            if (propertyValue.length() == 0) {
                sourcePastIntervals = new int[] {};
            } else {
                int[] sourcePastIntervalsTemp = ParsedProperties.parseStringArrayOfInts(propertyValue);
                for (int interval : sourcePastIntervalsTemp) {
                    if (interval < 1) {
                        throw new Exception("Invalid interval number less than 1.");
                    }
                }
                sourcePastIntervals = sourcePastIntervalsTemp;
                Arrays.sort(sourcePastIntervals);
            }
        } else if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) {
            Knns = Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(DO_JITTERED_SAMPLING_PROP_NAME)) {
            jitteredSamplesForSurrogates = Boolean.parseBoolean(propertyValue);
        } else if (propertyName.equalsIgnoreCase(JITTERED_SAMPLING_NOISE_LEVEL)) {
            jitteredSamplingNoiseLevel = Double.parseDouble(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_K_PERM)) {
            kPerm = Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
            if (propertyValue.equals("0") || propertyValue.equalsIgnoreCase("false")) {
                addNoise = false;
                noiseLevel = 0;
            } else {
                addNoise = true;
                noiseLevel = Double.parseDouble(propertyValue);
            }
        } else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
            double tempNumSamplesMultiplier = Double.parseDouble(propertyValue);
            if (tempNumSamplesMultiplier <= 0) {
                throw new Exception("Num samples multiplier must be greater than 0.");
            } else {
                numSamplesMultiplier = tempNumSamplesMultiplier;
            }
        } else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
            normType = KdTree.validateNormType(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_SURROGATE_SAMPLE_MULTIPLIER)) {
            double tempSurrogateNumSamplesMultiplier = Double.parseDouble(propertyValue);
            if (tempSurrogateNumSamplesMultiplier <= 0) {
                throw new Exception("Surrogate num samples multiplier must be greater than 0.");
            } else {
                surrogateNumSamplesMultiplier = tempSurrogateNumSamplesMultiplier;
            }
        } else {
            // No property was set on this class
            propertySet = false;
        }
        if (debug && propertySet) {
            System.out.println(
                    this.getClass().getSimpleName() + ": Set property " + propertyName + " to " + propertyValue);
        }
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getProperty(
     * java.lang.String)
     */
    @Override
    public String getProperty(String propertyName) throws Exception {
        if (propertyName.equalsIgnoreCase(KNNS_PROP_NAME)) {
            return Integer.toString(Knns);
        } else if (propertyName.equalsIgnoreCase(PROP_ADD_NOISE)) {
            return Double.toString(noiseLevel);
        } else if (propertyName.equalsIgnoreCase(PROP_SAMPLE_MULTIPLIER)) {
            return Double.toString(numSamplesMultiplier);
        } else {
            // No property matches for this class
            return null;
        }
    }

    /**
     * Append the set of past intervals to include for one conditioning process.
     * Note that the supplied array is sorted in place.
     */
    public void appendConditionalIntervals(int[] intervals) throws Exception {
        for (int interval : intervals) {
            if (interval < 1) {
                throw new Exception("Invalid interval number less than 1.");
            }
        }
        Arrays.sort(intervals);
        vectorOfCondPastIntervals.add(intervals);
    }

    public void clearConditionalIntervals() {
        vectorOfCondPastIntervals = new Vector<int[]>();
    }
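
    // Illustrative usage of the conditional-interval setters (hypothetical values,
    // mirroring how the Python demos drive this class via JPype):
    //   calc.clearConditionalIntervals();
    //   calc.appendConditionalIntervals(new int[] {1, 2}); // intervals for the 1st conditioning train
    //   calc.appendConditionalIntervals(new int[] {1});    // intervals for the 2nd conditioning train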
|
||||||
|
/*
|
||||||
|
* (non-Javadoc)
|
||||||
|
*
|
||||||
|
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||||
|
* setObservations(double[], double[])
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void setObservations(double[] source, double[] destination) throws Exception {
|
||||||
|
startAddObservations();
|
||||||
|
addObservations(source, destination);
|
||||||
|
finaliseAddObservations();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
|
||||||
|
startAddObservations();
|
||||||
|
addObservations(source, destination, conditionals);
|
||||||
|
finaliseAddObservations();
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* (non-Javadoc)
|
||||||
|
*
|
||||||
|
* @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
|
||||||
|
* startAddObservations()
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public void startAddObservations() {
|
||||||
|
vectorOfSourceSpikeTimes = new Vector<double[]>();
|
||||||
|
vectorOfDestinationSpikeTimes = new Vector<double[]>();
|
||||||
|
vectorOfConditionalSpikeTimes = new Vector<double[][]>();
|
||||||
|
}
|
||||||
|
|
||||||
|
    /*
     * (non-Javadoc)
     *
     * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
     * addObservations(double[], double[])
     */
    @Override
    public void addObservations(double[] source, double[] destination) throws Exception {
        // Store these observations in our vectors for now
        vectorOfSourceSpikeTimes.add(source);
        vectorOfDestinationSpikeTimes.add(destination);
    }

    public void addObservations(double[] source, double[] destination, double[][] conditionals) throws Exception {
        // Store these observations in our vectors for now
        vectorOfSourceSpikeTimes.add(source);
        vectorOfDestinationSpikeTimes.add(destination);
        vectorOfConditionalSpikeTimes.add(conditionals);
    }

    /*
     * (non-Javadoc)
     *
     * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
     * finaliseAddObservations()
     */
    @Override
    public void finaliseAddObservations() throws Exception {

        // Set these convenience variables as they are used quite a bit later on
        numDestPastIntervals = destPastIntervals.length;
        numSourcePastIntervals = sourcePastIntervals.length;
        numCondPastIntervals = 0;
        for (int[] intervals : vectorOfCondPastIntervals) {
            numCondPastIntervals += intervals.length;
        }

        conditioningEmbeddingsFromSpikes = new Vector<double[]>();
        jointEmbeddingsFromSpikes = new Vector<double[]>();
        conditioningEmbeddingsFromSamples = new Vector<double[]>();
        jointEmbeddingsFromSamples = new Vector<double[]>();
        processTimeLengths = new Vector<Double>();

        // Send all of the observations through:
        Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
        Iterator<double[][]> conditionalIterator = null;
        if (vectorOfConditionalSpikeTimes.size() > 0) {
            conditionalIterator = vectorOfConditionalSpikeTimes.iterator();
        }
        for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
            double[] sourceSpikeTimes = sourceIterator.next();
            double[][] conditionalSpikeTimes = null;
            if (vectorOfConditionalSpikeTimes.size() > 0) {
                conditionalSpikeTimes = conditionalIterator.next();
            } else {
                conditionalSpikeTimes = new double[][] {};
            }
            processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
                    conditioningEmbeddingsFromSpikes, jointEmbeddingsFromSpikes,
                    conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples,
                    numSamplesMultiplier, false);
        }

        // Convert the vectors to arrays so that they can be put in the trees
        double[][] arrayedTargetEmbeddingsFromSpikes =
                new double[conditioningEmbeddingsFromSpikes.size()][numDestPastIntervals + numCondPastIntervals];
        double[][] arrayedJointEmbeddingsFromSpikes =
                new double[conditioningEmbeddingsFromSpikes.size()][numDestPastIntervals + numCondPastIntervals + numSourcePastIntervals];
        for (int i = 0; i < conditioningEmbeddingsFromSpikes.size(); i++) {
            arrayedTargetEmbeddingsFromSpikes[i] = conditioningEmbeddingsFromSpikes.elementAt(i);
            arrayedJointEmbeddingsFromSpikes[i] = jointEmbeddingsFromSpikes.elementAt(i);
        }
        double[][] arrayedTargetEmbeddingsFromSamples =
                new double[conditioningEmbeddingsFromSamples.size()][numDestPastIntervals + numCondPastIntervals];
        double[][] arrayedJointEmbeddingsFromSamples =
                new double[conditioningEmbeddingsFromSamples.size()][numDestPastIntervals + numCondPastIntervals + numSourcePastIntervals];
        for (int i = 0; i < conditioningEmbeddingsFromSamples.size(); i++) {
            arrayedTargetEmbeddingsFromSamples[i] = conditioningEmbeddingsFromSamples.elementAt(i);
            arrayedJointEmbeddingsFromSamples[i] = jointEmbeddingsFromSamples.elementAt(i);
        }

        kdTreeJointAtSpikes = new KdTree(arrayedJointEmbeddingsFromSpikes);
        kdTreeJointAtSamples = new KdTree(arrayedJointEmbeddingsFromSamples);
        kdTreeConditioningAtSpikes = new KdTree(arrayedTargetEmbeddingsFromSpikes);
        kdTreeConditioningAtSamples = new KdTree(arrayedTargetEmbeddingsFromSamples);

        kdTreeJointAtSpikes.setNormType(normType);
        kdTreeJointAtSamples.setNormType(normType);
        kdTreeConditioningAtSpikes.setNormType(normType);
        kdTreeConditioningAtSamples.setNormType(normType);
    }

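    /*
     * After finaliseAddObservations() the calculator holds four point sets, with a
     * KD-tree over each:
     *   - conditioning embeddings (target plus conditional history intervals),
     *     made at target spikes and at randomly sampled times;
     *   - joint embeddings (conditioning plus source history intervals), made at
     *     the same two kinds of points.
     * Contrasting neighbour counts between the spike-anchored and randomly sampled
     * sets is what lets the TE rate be estimated directly in continuous time,
     * without discretising the spike trains into time bins.
     */
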
    protected void makeEmbeddingsAtPoints(double[] pointsAtWhichToMakeEmbeddings, int indexOfFirstPointToUse,
            double[] sourceSpikeTimes, double[] destSpikeTimes,
            double[][] conditionalSpikeTimes,
            Vector<double[]> conditioningEmbeddings,
            Vector<double[]> jointEmbeddings) {

        Random random = new Random();

        // Initialise the starting points of all the tracking variables
        int embeddingPointIndex = indexOfFirstPointToUse;
        int mostRecentDestIndex = destPastIntervals[destPastIntervals.length - 1];
        int mostRecentSourceIndex = sourcePastIntervals[sourcePastIntervals.length - 1] - 1;
        int[] mostRecentConditioningIndices = new int[vectorOfCondPastIntervals.size()];
        for (int i = 0; i < vectorOfCondPastIntervals.size(); i++) {
            mostRecentConditioningIndices[i] =
                    vectorOfCondPastIntervals.elementAt(i)[vectorOfCondPastIntervals.elementAt(i).length - 1] - 1;
        }

        // Loop through the points at which embeddings need to be made
        for (; embeddingPointIndex < pointsAtWhichToMakeEmbeddings.length; embeddingPointIndex++) {

            // Advance the tracker of the most recent dest index
            while (mostRecentDestIndex < (destSpikeTimes.length - 1)) {
                if (destSpikeTimes[mostRecentDestIndex + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) {
                    mostRecentDestIndex++;
                } else {
                    break;
                }
            }
            // Do the same for the most recent source index
            while (mostRecentSourceIndex < (sourceSpikeTimes.length - 1)) {
                if (sourceSpikeTimes[mostRecentSourceIndex + 1] < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) {
                    mostRecentSourceIndex++;
                } else {
                    break;
                }
            }
            // Now advance the trackers for the most recent conditioning indices
            for (int j = 0; j < mostRecentConditioningIndices.length; j++) {
                while (mostRecentConditioningIndices[j] < (conditionalSpikeTimes[j].length - 1)) {
                    if (conditionalSpikeTimes[j][mostRecentConditioningIndices[j] + 1]
                            < pointsAtWhichToMakeEmbeddings[embeddingPointIndex]) {
                        mostRecentConditioningIndices[j]++;
                    } else {
                        break;
                    }
                }
            }

            double[] conditioningPast = new double[numDestPastIntervals + numCondPastIntervals];
            double[] jointPast = new double[numDestPastIntervals + numCondPastIntervals + numSourcePastIntervals];

            // Add the embedding intervals from the target process
            for (int i = 0; i < destPastIntervals.length; i++) {
                if (destPastIntervals[i] == 1) {
                    // Case where we are inserting an interval from an observation point
                    // back to the most recent event in the target process
                    conditioningPast[i] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
                            - destSpikeTimes[mostRecentDestIndex];
                    jointPast[i] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
                            - destSpikeTimes[mostRecentDestIndex];
                } else {
                    // Case where we are inserting an inter-event interval from the target process
                    conditioningPast[i] = destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 2]
                            - destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 1];
                    jointPast[i] = destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 2]
                            - destSpikeTimes[mostRecentDestIndex - destPastIntervals[i] + 1];
                }
            }

            // Add the embedding intervals from the conditional processes
            int indexOfNextEmbeddingInterval = numDestPastIntervals;
            for (int i = 0; i < vectorOfCondPastIntervals.size(); i++) {
                for (int j = 0; j < vectorOfCondPastIntervals.elementAt(i).length; j++) {
                    if (vectorOfCondPastIntervals.elementAt(i)[j] == 1) {
                        // Case where we are inserting an interval from an observation point
                        // back to the most recent event in the conditioning process
                        conditioningPast[indexOfNextEmbeddingInterval] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
                                - conditionalSpikeTimes[i][mostRecentConditioningIndices[i]];
                        jointPast[indexOfNextEmbeddingInterval] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
                                - conditionalSpikeTimes[i][mostRecentConditioningIndices[i]];
                    } else {
                        // Case where we are inserting an inter-event interval from the conditioning process
                        // Convenience variable
                        int intervalNumber = vectorOfCondPastIntervals.elementAt(i)[j];
                        conditioningPast[indexOfNextEmbeddingInterval] =
                                conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 2]
                                - conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 1];
                        jointPast[indexOfNextEmbeddingInterval] =
                                conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 2]
                                - conditionalSpikeTimes[i][mostRecentConditioningIndices[i] - intervalNumber + 1];
                    }
                    indexOfNextEmbeddingInterval++;
                }
            }

            // Add the embedding intervals from the source process (these only get added to the joint embeddings)
            for (int i = 0; i < sourcePastIntervals.length; i++) {
                if (sourcePastIntervals[i] == 1) {
                    // Case where we are inserting an interval from an observation point
                    // back to the most recent event in the source process
                    jointPast[indexOfNextEmbeddingInterval] = pointsAtWhichToMakeEmbeddings[embeddingPointIndex]
                            - sourceSpikeTimes[mostRecentSourceIndex];
                } else {
                    // Case where we are inserting an inter-event interval from the source process
                    jointPast[indexOfNextEmbeddingInterval] = sourceSpikeTimes[mostRecentSourceIndex - sourcePastIntervals[i] + 2]
                            - sourceSpikeTimes[mostRecentSourceIndex - sourcePastIntervals[i] + 1];
                }
                indexOfNextEmbeddingInterval++;
            }

            // Log-transform the intervals and add Gaussian noise, if requested
            if (addNoise) {
                for (int i = 0; i < conditioningPast.length; i++) {
                    conditioningPast[i] = Math.log(conditioningPast[i] + 1.1);
                    conditioningPast[i] += random.nextGaussian() * noiseLevel;
                }
                for (int i = 0; i < jointPast.length; i++) {
                    if (jointPast[i] < 0) {
                        System.out.println("Warning: negative interval in joint embedding");
                    }
                    jointPast[i] = Math.log(jointPast[i] + 1.1);
                    jointPast[i] += random.nextGaussian() * noiseLevel;
                }
            }

            conditioningEmbeddings.add(conditioningPast);
            jointEmbeddings.add(jointPast);
        }
    }

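    /*
     * Layout of the embedding vectors produced above, for a hypothetical
     * configuration with destPastIntervals = {1, 2}, one conditional process with
     * intervals {1}, and sourcePastIntervals = {1}, at observation time t:
     *
     *   conditioningPast = [ t - lastDestSpike, previousDestInterSpikeInterval,
     *                        t - lastCondSpike ]
     *   jointPast        = [ the same three entries, t - lastSourceSpike ]
     *
     * That is: target intervals first, then conditional intervals, then (in the
     * joint embedding only) source intervals. computeSignificance() relies on this
     * ordering when it swaps out the trailing source coordinates to build
     * surrogates.
     */
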
    protected int getFirstDestIndex(double[] sourceSpikeTimes, double[] destSpikeTimes,
            double[][] conditionalSpikeTimes, boolean setProcessTimeLengths) throws Exception {

        // First sort the spike times in case they were not properly in ascending order:
        Arrays.sort(sourceSpikeTimes);
        Arrays.sort(destSpikeTimes);

        int firstTargetIndexOfEmbedding = destPastIntervals[destPastIntervals.length - 1];

        int furthestInterval = sourcePastIntervals[sourcePastIntervals.length - 1];
        while (destSpikeTimes[firstTargetIndexOfEmbedding] < sourceSpikeTimes[furthestInterval - 1]) {
            firstTargetIndexOfEmbedding++;
        }
        if (conditionalSpikeTimes.length != vectorOfCondPastIntervals.size()) {
            throw new Exception("Number of conditional embedding lengths does not match the number of conditional processes");
        }
        for (int i = 0; i < conditionalSpikeTimes.length; i++) {
            furthestInterval = vectorOfCondPastIntervals.elementAt(i)[vectorOfCondPastIntervals.elementAt(i).length - 1];
            while (destSpikeTimes[firstTargetIndexOfEmbedding] <= conditionalSpikeTimes[i][furthestInterval - 1]) {
                firstTargetIndexOfEmbedding++;
            }
        }

        // We don't want to reset these lengths when resampling for surrogates
        if (setProcessTimeLengths) {
            processTimeLengths.add(destSpikeTimes[destSpikeTimes.length - 1] - destSpikeTimes[firstTargetIndexOfEmbedding]);
        }

        return firstTargetIndexOfEmbedding;
    }

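    /*
     * The method above finds the first target spike preceded by enough history for
     * every requested interval (target, source and conditional) to exist. For
     * example, with sourcePastIntervals = {1, 2}, at least two source spikes must
     * occur before the first target spike at which an embedding is made.
     */
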
    protected double[] generateRandomSampleTimes(double[] sourceSpikeTimes, double[] destSpikeTimes,
            double[][] conditionalSpikeTimes, double actualNumSamplesMultiplier,
            int firstTargetIndexOfEmbedding, boolean doJitteredSampling) {

        double sampleLowerBound = destSpikeTimes[firstTargetIndexOfEmbedding];
        double sampleUpperBound = destSpikeTimes[destSpikeTimes.length - 1];
        int numSamples = (int) Math.round(actualNumSamplesMultiplier * (destSpikeTimes.length - firstTargetIndexOfEmbedding + 1));
        double[] randomSampleTimes = new double[numSamples];
        Random rand = new Random();
        if (doJitteredSampling) {
            // Jitter the sample times around the actual target spike times
            for (int i = 0; i < randomSampleTimes.length; i++) {
                randomSampleTimes[i] = destSpikeTimes[firstTargetIndexOfEmbedding + (i % (destSpikeTimes.length - firstTargetIndexOfEmbedding - 1))]
                        + jitteredSamplingNoiseLevel * (rand.nextDouble() - 0.5);
                // Redraw uniformly if the jittered time fell outside the valid range
                if ((randomSampleTimes[i] > sampleUpperBound) || (randomSampleTimes[i] < sampleLowerBound)) {
                    randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound);
                }
            }
        } else {
            for (int i = 0; i < randomSampleTimes.length; i++) {
                randomSampleTimes[i] = sampleLowerBound + rand.nextDouble() * (sampleUpperBound - sampleLowerBound);
            }
        }
        Arrays.sort(randomSampleTimes);
        return randomSampleTimes;
    }

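    /*
     * The sample times generated above provide the estimator's random sampling of
     * the "unconditional" embedding distribution: uniform draws over the valid
     * recording interval for the TE estimate itself or, when doJitteredSampling is
     * set (used for surrogates), draws jittered around the actual target spike
     * times, with any out-of-range draw replaced by a uniform one.
     */
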
    protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes,
            double[][] conditionalSpikeTimes,
            Vector<double[]> conditioningEmbeddingsFromSpikes, Vector<double[]> jointEmbeddingsFromSpikes,
            Vector<double[]> conditioningEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples,
            double actualNumSamplesMultiplier, boolean doJitteredSampling) throws Exception {

        int firstTargetIndexOfEmbedding = getFirstDestIndex(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, true);
        double[] randomSampleTimes = generateRandomSampleTimes(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
                actualNumSamplesMultiplier, firstTargetIndexOfEmbedding, doJitteredSampling);

        makeEmbeddingsAtPoints(destSpikeTimes, firstTargetIndexOfEmbedding, sourceSpikeTimes, destSpikeTimes,
                conditionalSpikeTimes, conditioningEmbeddingsFromSpikes, jointEmbeddingsFromSpikes);
        makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes,
                conditionalSpikeTimes, conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples);
    }

    protected void processEventsFromSpikingTimeSeries(double[] sourceSpikeTimes, double[] destSpikeTimes,
            double[][] conditionalSpikeTimes,
            Vector<double[]> conditioningEmbeddingsFromSamples, Vector<double[]> jointEmbeddingsFromSamples,
            double actualNumSamplesMultiplier, boolean doJitteredSampling) throws Exception {

        int firstTargetIndexOfEmbedding = getFirstDestIndex(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes, false);
        double[] randomSampleTimes = generateRandomSampleTimes(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
                actualNumSamplesMultiplier, firstTargetIndexOfEmbedding, doJitteredSampling);
        makeEmbeddingsAtPoints(randomSampleTimes, 0, sourceSpikeTimes, destSpikeTimes,
                conditionalSpikeTimes, conditioningEmbeddingsFromSamples, jointEmbeddingsFromSamples);
    }

    /*
     * (non-Javadoc)
     *
     * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
     * getAddedMoreThanOneObservationSet()
     */
    @Override
    public boolean getAddedMoreThanOneObservationSet() {
        return (vectorOfDestinationSpikeTimes != null) && (vectorOfDestinationSpikeTimes.size() > 1);
    }

    // Class to allow returning two values from the subsequent method
    private static class distanceAndNumPoints {
        public double distance;
        public int numPoints;

        public distanceAndNumPoints(double distance, int numPoints) {
            this.distance = distance;
            this.numPoints = numPoints;
        }
    }

    private distanceAndNumPoints findMaxDistanceAndNumPointsFromIndices(double[] point, int[] indices,
            Vector<double[]> setOfPoints) {
        double maxDistance = 0;
        int i = 0;
        // The indices array is terminated by a -1 entry
        for (; indices[i] != -1; i++) {
            double distance = KdTree.norm(point, setOfPoints.elementAt(indices[i]), normType);
            if (distance > maxDistance) {
                maxDistance = distance;
            }
        }
        return new distanceAndNumPoints(maxDistance, i);
    }

    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        return computeAverageLocalOfObservations(kdTreeJointAtSpikes, jointEmbeddingsFromSpikes);
    }

    /*
     * We take the actual joint tree at spikes (along with the associated embeddings) as an argument,
     * as we will need to swap these out when computing surrogates.
     */
    public double computeAverageLocalOfObservations(KdTree actualKdTreeJointAtSpikes,
            Vector<double[]> actualJointEmbeddingsFromSpikes) throws Exception {

        double currentSum = 0;
        for (int i = 0; i < conditioningEmbeddingsFromSpikes.size(); i++) {

            double radiusJointSpikes = actualKdTreeJointAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
            double radiusJointSamples = kdTreeJointAtSamples.findKNearestNeighbours(Knns,
                    new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];

            /*
             * The algorithm specified in box 1 of doi.org/10.1371/journal.pcbi.1008054 specifies finding
             * the maximum of the two radii just calculated and then redoing the searches in both sets at
             * this radius. In this implementation, however, we make use of the fact that one radius is
             * equal to the maximum, and so only one search needs to be redone.
             */
            double eps = 0.0;
            // Need variables for the numbers of neighbours, as these are now variable within the maximum radius
            int kJointSpikes = 0;
            int kJointSamples = 0;
            if (radiusJointSpikes >= radiusJointSamples) {
                /*
                 * The maximum was the radius in the set of embeddings at spikes, so redo the search
                 * in the set of embeddings at randomly sampled points, using this larger radius.
                 */
                kJointSpikes = Knns;
                int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()];
                boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()];
                kdTreeJointAtSamples.findPointsWithinR(radiusJointSpikes + eps,
                        new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) },
                        true, isWithinR, indicesWithinR);
                distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(
                        actualJointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
                        jointEmbeddingsFromSamples);
                kJointSamples = temp.numPoints;
                radiusJointSamples = temp.distance;
            } else {
                /*
                 * The maximum was the radius in the set of embeddings at randomly sampled points,
                 * so redo the search in the set of embeddings at spikes, using this larger radius.
                 */
                kJointSamples = Knns;
                int[] indicesWithinR = new int[jointEmbeddingsFromSamples.size()];
                boolean[] isWithinR = new boolean[jointEmbeddingsFromSamples.size()];
                actualKdTreeJointAtSpikes.findPointsWithinR(radiusJointSamples + eps,
                        new double[][] { actualJointEmbeddingsFromSpikes.elementAt(i) },
                        true, isWithinR, indicesWithinR);
                distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(
                        actualJointEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
                        actualJointEmbeddingsFromSpikes);
                // -1 due to the point itself being in the set
                kJointSpikes = temp.numPoints - 1;
                radiusJointSpikes = temp.distance;
            }

            // Repeat the above steps, but in the conditioning (rather than joint) space.
            double radiusConditioningSpikes = kdTreeConditioningAtSpikes.findKNearestNeighbours(Knns, i).poll().norms[0];
            double radiusConditioningSamples = kdTreeConditioningAtSamples.findKNearestNeighbours(Knns,
                    new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) }).poll().norms[0];
            int kConditioningSpikes = 0;
            int kConditioningSamples = 0;
            if (radiusConditioningSpikes >= radiusConditioningSamples) {
                kConditioningSpikes = Knns;
                int[] indicesWithinR = new int[conditioningEmbeddingsFromSamples.size()];
                boolean[] isWithinR = new boolean[conditioningEmbeddingsFromSamples.size()];
                kdTreeConditioningAtSamples.findPointsWithinR(radiusConditioningSpikes + eps,
                        new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) },
                        true, isWithinR, indicesWithinR);
                distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(
                        conditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
                        conditioningEmbeddingsFromSamples);
                kConditioningSamples = temp.numPoints;
                radiusConditioningSamples = temp.distance;
            } else {
                kConditioningSamples = Knns;
                int[] indicesWithinR = new int[conditioningEmbeddingsFromSamples.size()];
                boolean[] isWithinR = new boolean[conditioningEmbeddingsFromSamples.size()];
                kdTreeConditioningAtSpikes.findPointsWithinR(radiusConditioningSamples + eps,
                        new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) },
                        true, isWithinR, indicesWithinR);
                distanceAndNumPoints temp = findMaxDistanceAndNumPointsFromIndices(
                        conditioningEmbeddingsFromSpikes.elementAt(i), indicesWithinR,
                        conditioningEmbeddingsFromSpikes);
                // -1 due to the point itself being in the set
                kConditioningSpikes = temp.numPoints - 1;
                radiusConditioningSpikes = temp.distance;
            }

            /*
             * TODO
             * The KdTree class defaults to the squared Euclidean distance when the Euclidean norm is
             * specified. This is fine for Kraskov estimators (as the radii are never used, just the
             * numbers of points within radii). It causes problems here though, as we do use the radii
             * and the squared Euclidean distance is not a distance metric. We get around this by just
             * taking the square root here, but it might be better to fix this in the KdTree class.
             */
            // Keep the pre-square-root value for the debug message below
            double tempRadiusJointSamples = radiusJointSamples;
            if (normType == EuclideanUtils.NORM_EUCLIDEAN) {
                radiusJointSpikes = Math.sqrt(radiusJointSpikes);
                radiusJointSamples = Math.sqrt(radiusJointSamples);
                radiusConditioningSpikes = Math.sqrt(radiusConditioningSpikes);
                radiusConditioningSamples = Math.sqrt(radiusConditioningSamples);
            }

            currentSum += (MathsUtils.digamma(kJointSpikes) - MathsUtils.digamma(kJointSamples)
                    + ((numDestPastIntervals + numCondPastIntervals + numSourcePastIntervals)
                            * (-Math.log(radiusJointSpikes) + Math.log(radiusJointSamples)))
                    - MathsUtils.digamma(kConditioningSpikes) + MathsUtils.digamma(kConditioningSamples)
                    + ((numDestPastIntervals + numCondPastIntervals)
                            * (Math.log(radiusConditioningSpikes) - Math.log(radiusConditioningSamples))));
            if (Double.isNaN(currentSum)) {
                for (double[] embed : jointEmbeddingsFromSamples) {
                    System.out.println(Arrays.toString(embed));
                }
                throw new Exception("NaNs in TE calc " + radiusJointSpikes + " " + radiusJointSamples
                        + " " + tempRadiusJointSamples);
            }
        }
        // Normalise by time
        double timeSum = 0;
        for (Double time : processTimeLengths) {
            timeSum += time;
        }
        currentSum /= timeSum;
        return currentSum;
    }

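    /*
     * Schematically, the sum accumulated above implements the continuous-time TE
     * rate estimator of doi.org/10.1371/journal.pcbi.1008054: with psi() the
     * digamma function, d_J and d_C the joint and conditioning embedding dimensions,
     * r the search radii and T the total valid observation time,
     *
     *   TE ~ (1/T) * sum over target spikes of
     *          [   psi(k_J,spikes) - psi(k_J,samples) + d_J * ln(r_J,samples / r_J,spikes)
     *            - psi(k_C,spikes) + psi(k_C,samples) + d_C * ln(r_C,spikes / r_C,samples) ]
     *
     * where T is accumulated in processTimeLengths by getFirstDestIndex().
     */
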
    @Override
    public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck, double estimatedValue)
            throws Exception {
        return computeSignificance(numPermutationsToCheck, estimatedValue, System.currentTimeMillis());
    }

    @Override
    public EmpiricalMeasurementDistribution computeSignificance(int numPermutationsToCheck,
            double estimatedValue, long randomSeed) throws Exception {

        Random random = new Random(randomSeed);
        double[] surrogateTEValues = new double[numPermutationsToCheck];

        for (int permutationNumber = 0; permutationNumber < numPermutationsToCheck; permutationNumber++) {
            Vector<double[]> resampledConditioningEmbeddingsFromSamples = new Vector<double[]>();
            Vector<double[]> resampledJointEmbeddingsFromSamples = new Vector<double[]>();

            // Send all of the observations through:
            Iterator<double[]> sourceIterator = vectorOfSourceSpikeTimes.iterator();
            Iterator<double[][]> conditionalIterator = null;
            if (vectorOfConditionalSpikeTimes.size() > 0) {
                conditionalIterator = vectorOfConditionalSpikeTimes.iterator();
            }
            for (double[] destSpikeTimes : vectorOfDestinationSpikeTimes) {
                double[] sourceSpikeTimes = sourceIterator.next();
                double[][] conditionalSpikeTimes = null;
                if (vectorOfConditionalSpikeTimes.size() > 0) {
                    conditionalSpikeTimes = conditionalIterator.next();
                } else {
                    conditionalSpikeTimes = new double[][] {};
                }
                processEventsFromSpikingTimeSeries(sourceSpikeTimes, destSpikeTimes, conditionalSpikeTimes,
                        resampledConditioningEmbeddingsFromSamples, resampledJointEmbeddingsFromSamples,
                        surrogateNumSamplesMultiplier, jitteredSamplesForSurrogates);
            }
            // Convert the vectors to arrays so that they can be put in the trees
            double[][] arrayedResampledConditioningEmbeddingsFromSamples =
                    new double[resampledConditioningEmbeddingsFromSamples.size()][];
            for (int i = 0; i < resampledConditioningEmbeddingsFromSamples.size(); i++) {
                arrayedResampledConditioningEmbeddingsFromSamples[i] = resampledConditioningEmbeddingsFromSamples.elementAt(i);
            }

            KdTree resampledKdTreeConditioningAtSamples = new KdTree(arrayedResampledConditioningEmbeddingsFromSamples);
            resampledKdTreeConditioningAtSamples.setNormType(normType);

            // Deep copy the joint embeddings at spikes, so that permuting the source
            // coordinates below cannot mutate the arrays used for the original estimate
            Vector<double[]> conditionallyPermutedJointEmbeddingsFromSpikes = new Vector<double[]>();
            for (double[] embedding : jointEmbeddingsFromSpikes) {
                conditionallyPermutedJointEmbeddingsFromSpikes.add(embedding.clone());
            }

            Vector<Integer> usedIndices = new Vector<Integer>();
            for (int i = 0; i < conditionallyPermutedJointEmbeddingsFromSpikes.size(); i++) {
                PriorityQueue<NeighbourNodeData> neighbours =
                        resampledKdTreeConditioningAtSamples.findKNearestNeighbours(kPerm,
                                new double[][] { conditioningEmbeddingsFromSpikes.elementAt(i) });
                ArrayList<Integer> foundIndices = new ArrayList<Integer>();
                for (int j = 0; j < kPerm; j++) {
                    foundIndices.add(neighbours.poll().sampleIndex);
                }

                ArrayList<Integer> prunedIndices = new ArrayList<Integer>(foundIndices);
                prunedIndices.removeAll(usedIndices);
                int chosenIndex = 0;
                if (prunedIndices.size() > 0) {
                    chosenIndex = prunedIndices.get(random.nextInt(prunedIndices.size()));
                } else {
                    chosenIndex = foundIndices.get(random.nextInt(foundIndices.size()));
                }
                usedIndices.add(chosenIndex);
                int embeddingLength = conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i).length;
                for (int j = 0; j < numSourcePastIntervals; j++) {
                    conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i)[embeddingLength - numSourcePastIntervals + j] =
                            resampledJointEmbeddingsFromSamples.elementAt(chosenIndex)[embeddingLength - numSourcePastIntervals + j];
                }
            }

            double[][] arrayedConditionallyPermutedJointEmbeddingsFromSpikes =
                    new double[conditionallyPermutedJointEmbeddingsFromSpikes.size()][];
            for (int i = 0; i < conditionallyPermutedJointEmbeddingsFromSpikes.size(); i++) {
                arrayedConditionallyPermutedJointEmbeddingsFromSpikes[i] = conditionallyPermutedJointEmbeddingsFromSpikes.elementAt(i);
            }
            KdTree conditionallyPermutedKdTreeJointFromSpikes = new KdTree(arrayedConditionallyPermutedJointEmbeddingsFromSpikes);
            conditionallyPermutedKdTreeJointFromSpikes.setNormType(normType);
            surrogateTEValues[permutationNumber] = computeAverageLocalOfObservations(conditionallyPermutedKdTreeJointFromSpikes,
                    conditionallyPermutedJointEmbeddingsFromSpikes);
        }
        return new EmpiricalMeasurementDistribution(surrogateTEValues, estimatedValue);
    }

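    /*
     * The surrogate scheme above is a conditional local-permutation test: for each
     * target spike, its conditioning embedding is matched to one of its kPerm
     * nearest neighbours in a freshly resampled set of embeddings, and the source
     * coordinates of its joint embedding are replaced by that neighbour's source
     * coordinates. This destroys the source-target relationship while approximately
     * preserving any dependence of the source history on the conditioning history,
     * giving the null distribution returned in the EmpiricalMeasurementDistribution.
     */
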
    /*
     * (non-Javadoc)
     *
     * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#
     * computeLocalOfPreviousObservations()
     */
    @Override
    public SpikingLocalInformationValues computeLocalOfPreviousObservations() throws Exception {
        // TODO: local values are not yet implemented for this estimator
        return null;
    }

    /*
     * (non-Javadoc)
     *
     * @see infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#setDebug(
     * boolean)
     */
    @Override
    public void setDebug(boolean debug) {
        this.debug = debug;
    }

    /*
     * (non-Javadoc)
     *
     * @see
     * infodynamics.measures.spiking.TransferEntropyCalculatorSpiking#getLastAverage()
     */
    @Override
    public double getLastAverage() {
        // TODO: storing and returning the last average is not yet implemented
        return 0;
    }
}
(Diffs of two further files in this commit are suppressed because they are too large.)
@@ -0,0 +1,8 @@
/**
 *
 */
/**
 * @author joseph
 *
 */
package infodynamics.measures.spiking.integration;
@@ -0,0 +1,8 @@
/**
 *
 */
/**
 * @author joseph
 *
 */
package infodynamics.measures.spiking;
java/source/infodynamics/utils/UnivariateNearestNeighbourSearcher.java: mode changed from Executable file to Normal file (0 lines changed)