From a398c925a8310bf49ed3dd3aaea63c82fb9583a0 Mon Sep 17 00:00:00 2001
From: David Shorten
Date: Thu, 4 Aug 2022 16:03:12 +0930
Subject: [PATCH] added python for spiking network inference

---
 .../EffectiveNetworkInference/net_inf.py    | 243 ++++++++++++++++++
 .../EffectiveNetworkInference/spk_to_pk.py  |  43 ++++
 2 files changed, 286 insertions(+)
 create mode 100755 demos/python/EffectiveNetworkInference/net_inf.py
 create mode 100644 demos/python/EffectiveNetworkInference/spk_to_pk.py

diff --git a/demos/python/EffectiveNetworkInference/net_inf.py b/demos/python/EffectiveNetworkInference/net_inf.py
new file mode 100755
index 0000000..73cd878
--- /dev/null
+++ b/demos/python/EffectiveNetworkInference/net_inf.py
@@ -0,0 +1,243 @@
+# Infers the effective sources of a single target neuron from spike trains,
+# using the continuous-time spiking transfer entropy estimator in JIDT.
+# Argument order: network_type_name num_spikes sim_number target_index
+
+from jpype import *
+import random
+import math
+import os
+import numpy as np
+import pickle
+import copy
+import sys
+
+net_type_name = sys.argv[1]
+num_spikes_string = sys.argv[2]
+repeat_num_string = sys.argv[3]
+target_index_string = sys.argv[4]
+
+NUM_SURROGATES_PER_TE_VAL = 100  # surrogates generated per TE estimate
+P_LEVEL = 0.05                   # significance level for adding and pruning sources
+KNNS = 10                        # nearest neighbours used by the estimator
+NUM_SAMPLES_MULTIPLIER = 5.0
+SURROGATE_NUM_SAMPLES_MULTIPLIER = 5.0
+K_PERM = 20
+JITTERING_LEVEL = 2000
+
+# Stop adding sources once this many second intervals have been included
+MAX_NUM_SECOND_INTERVALS = 2
+MAX_NUM_TARGET_SPIKES = int(num_spikes_string)
+SPIKES_FILE_NAME = "spikes_LIF_" + net_type_name + "_" + repeat_num_string + ".pk"
+GROUND_TRUTH_FILE_NAME = "connections_LIF_" + net_type_name + "_" + repeat_num_string + ".pk"
+OUTPUT_FILE_PREFIX = "results/inferred_sources_target_2_" + net_type_name + "_" + num_spikes_string + "_" + repeat_num_string + "_" + target_index_string
+LOG_FILE_NAME = "logs/" + net_type_name + "_" + num_spikes_string + "_" + repeat_num_string + "_" + target_index_string + ".log"
+
+# Redirect all subsequent prints to the log file
+log = open(LOG_FILE_NAME, "w")
+sys.stdout = log
+
+def prepare_conditional_trains(calc_object, cond_set):
+    # Registers the selected intervals of the already-added sources with the
+    # calculator and returns their spike trains in matching order.
+    cond_trains = []
+    calc_object.clearConditionalIntervals()
+    for key in cond_set.keys():
+        cond_trains.append(spikes[key])
+        calc_object.appendConditionalIntervals(JArray(JInt, 1)(cond_set[key]))
+    return cond_trains
+
+def set_target_embeddings(embedding_list):
+    # The calculator expects the target's past intervals as a comma-separated
+    # string, e.g. [1, 2, 3] -> "1,2,3"; an empty list clears the property.
+    teCalc.setProperty("DEST_PAST_INTERVALS",
+                       ",".join(str(interval) for interval in embedding_list))
+
+
+target_index = int(target_index_string)
+print("\n****** Network inference for target neuron", target_index, "******\n\n")
+
+
+# Set up JIDT
+jarLocation = os.path.join(os.getcwd(), "../jidt/infodynamics.jar")
+if not os.path.isfile(jarLocation):
+    sys.exit("infodynamics.jar not found (expected at " + os.path.abspath(jarLocation) + ") - are you running from demos/python?")
+startJVM(getDefaultJVMPath(), "-ea", "-Djava.class.path=" + jarLocation)
+teCalcClass = JPackage("infodynamics.measures.spiking.integration").TransferEntropyCalculatorSpikingIntegration
+teCalc = teCalcClass()
+teCalc.setProperty("knns", str(KNNS))
+teCalc.setProperty("NUM_SAMPLES_MULTIPLIER", str(NUM_SAMPLES_MULTIPLIER))
+teCalc.setProperty("SURROGATE_NUM_SAMPLES_MULTIPLIER", str(SURROGATE_NUM_SAMPLES_MULTIPLIER))
+teCalc.setProperty("K_PERM", str(K_PERM))
+teCalc.setProperty("DO_JITTERED_SAMPLING", "true")
+teCalc.setProperty("JITTERED_SAMPLING_NOISE_LEVEL", str(JITTERING_LEVEL))
+
+# Load spikes and ground-truth connectivity
+spikes = pickle.load(open(SPIKES_FILE_NAME, 'rb'))
+cons = pickle.load(open(GROUND_TRUTH_FILE_NAME, 'rb'))
+if MAX_NUM_TARGET_SPIKES < len(spikes[target_index]):
+    spikes[target_index] = spikes[target_index][:MAX_NUM_TARGET_SPIKES]
+print("Number of target spikes: ", len(spikes[target_index]), "\n\n")
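
Before the inference stages below, it helps to see the input format the loading code above expects: the spikes file unpickles to a per-neuron list of spike-time arrays, and the ground-truth file to a list of [source, target] pairs (net_inf.py only checks membership of such pairs). A minimal sketch for generating compatible test inputs; the uniform random trains, the "test" network name and the connection pairs are placeholders, not data from this patch:

import pickle
import numpy as np

rng = np.random.default_rng(0)
# One sorted array of spike times per neuron; 5 placeholder neurons here
spikes = [np.sort(rng.uniform(0, 1e4, size=1000)) for _ in range(5)]
cons = [[0, 1], [2, 1]]  # hypothetical ground-truth [source, target] pairs

with open("spikes_LIF_test_0.pk", "wb") as f:
    pickle.dump(spikes, f)
with open("connections_LIF_test_0.pk", "wb") as f:
    pickle.dump(cons, f)

With these two files in place, "python net_inf.py test 1000 0 1" would infer the sources of neuron 1; the results/ and logs/ directories must already exist, since the script opens output files in both.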
+
+
+# First determine the target embedding: intervals of the target's own history
+# are added one at a time, for as long as the newest interval still makes a
+# statistically significant contribution.
+target_embedding_set = [1]
+next_target_interval = 2
+still_significant = True
+print("**** Determining target embedding set ****\n")
+while still_significant:
+    set_target_embeddings(target_embedding_set)
+    teCalc.setProperty("SOURCE_PAST_INTERVALS", str(next_target_interval))
+    teCalc.startAddObservations()
+    teCalc.addObservations(JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 1)(spikes[target_index]))
+    teCalc.finaliseAddObservations()
+    TE = teCalc.computeAverageLocalOfObservations()
+    sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE)
+    print("candidate interval:", next_target_interval, " TE:", TE, " p val:", sig.pValue)
+    if sig.pValue > P_LEVEL:
+        print("Lost significance, end of target embedding determination")
+        still_significant = False
+    else:
+        target_embedding_set.append(next_target_interval)
+        next_target_interval += 1
+print("target embedding set:", target_embedding_set, "\n\n")
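
The embedding set is handed to JIDT as the comma-separated DEST_PAST_INTERVALS property, which is easy to get wrong when the string is built by hand. A standalone sketch of the join-based construction used in set_target_embeddings above, with quick self-checks:

def embedding_string(embedding_list):
    # [1, 2, 3] -> "1,2,3"; [] -> "" (clears the property)
    return ",".join(str(interval) for interval in embedding_list)

assert embedding_string([1]) == "1"
assert embedding_string([1, 2, 3]) == "1,2,3"
assert embedding_string([]) == ""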
+
+
+# Now add the sources greedily.
+# cond_set is a dictionary whose keys are the added sources and whose values
+# are the lists of included intervals for that source.
+cond_set = dict()
+# next_interval_for_each_candidate is a matrix with two columns:
+# the first column holds the source indices, the second the next interval
+# that will be considered for that source.
+next_interval_for_each_candidate = np.arange(0, len(spikes), dtype=np.intc)
+next_interval_for_each_candidate = next_interval_for_each_candidate[next_interval_for_each_candidate != target_index]
+next_interval_for_each_candidate = np.column_stack((next_interval_for_each_candidate,
+                                                    np.ones(len(next_interval_for_each_candidate), dtype=np.intc)))
+still_significant = True
+TE_vals_at_each_round = []
+surrogate_vals_at_each_round = []
+print("**** Adding Sources ****\n")
+num_twos = 0
+while still_significant:
+    print("Current conditioning set:")
+    for key in cond_set.keys():
+        print("source", key, "intervals", cond_set[key])
+    print("\nEstimating TE on candidate sources")
+    cond_trains = prepare_conditional_trains(teCalc, cond_set)
+    TE_vals = np.zeros(next_interval_for_each_candidate.shape[0])
+    debiased_TE_vals = -1 * np.ones(next_interval_for_each_candidate.shape[0])
+    surrogate_vals = -1 * np.ones((next_interval_for_each_candidate.shape[0], NUM_SURROGATES_PER_TE_VAL))
+    debiased_surrogate_vals = -1 * np.ones((next_interval_for_each_candidate.shape[0], NUM_SURROGATES_PER_TE_VAL))
+    is_con = np.zeros(next_interval_for_each_candidate.shape[0])
+    for i in range(next_interval_for_each_candidate.shape[0]):
+        # Skip candidates with too few spikes for a meaningful estimate
+        if len(spikes[next_interval_for_each_candidate[i, 0]]) < 10:
+            continue
+        teCalc.startAddObservations()
+        teCalc.setProperty("SOURCE_PAST_INTERVALS", str(next_interval_for_each_candidate[i, 1]))
+        if len(cond_set) > 0:
+            teCalc.addObservations(JArray(JDouble, 1)(spikes[next_interval_for_each_candidate[i, 0]]),
+                                   JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 2)(cond_trains))
+        else:
+            teCalc.addObservations(JArray(JDouble, 1)(spikes[next_interval_for_each_candidate[i, 0]]),
+                                   JArray(JDouble, 1)(spikes[target_index]))
+        teCalc.finaliseAddObservations()
+        TE_vals[i] = teCalc.computeAverageLocalOfObservations()
+        is_con[i] = ([next_interval_for_each_candidate[i, 0], target_index] in cons)
+        sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE_vals[i])
+        surrogate_vals[i] = sig.distribution
+        # Debias the estimate by subtracting the mean of its surrogates
+        debiased_TE_vals[i] = TE_vals[i] - np.mean(surrogate_vals[i])
+        debiased_surrogate_vals[i] = sig.distribution - np.mean(surrogate_vals[i])
+        print("Source", next_interval_for_each_candidate[i, 0], "Interval", next_interval_for_each_candidate[i, 1],
+              " TE:", str(debiased_TE_vals[i]))
+        log.flush()
+
+    TE_vals_at_each_round.append(TE_vals)
+    surrogate_vals_at_each_round.append(surrogate_vals)
+    sorted_TE_indices = np.argsort(debiased_TE_vals)
+    print("\nSorted order of sources:\n", next_interval_for_each_candidate[:, 0][sorted_TE_indices])
+    print("Ground truth for sorted order:\n", is_con[sorted_TE_indices])
+
+    # Compare the strongest candidate against the distribution of the
+    # per-surrogate maximum across all candidates (a maximum-statistic test).
+    index_of_max_candidate = sorted_TE_indices[-1]
+    samples_from_max_dist = np.sort(np.max(debiased_surrogate_vals, axis=0))
+    index_of_first_greater_than_estimate = np.searchsorted(samples_from_max_dist > debiased_TE_vals[index_of_max_candidate], 1)
+    p_val = (NUM_SURROGATES_PER_TE_VAL - index_of_first_greater_than_estimate) / float(NUM_SURROGATES_PER_TE_VAL)
+    print("\nMaximum candidate is source", next_interval_for_each_candidate[index_of_max_candidate, 0],
+          "interval", next_interval_for_each_candidate[index_of_max_candidate, 1])
+    print("p: ", p_val)
+    if p_val <= P_LEVEL:
+        if next_interval_for_each_candidate[index_of_max_candidate, 0] in cond_set:
+            cond_set[next_interval_for_each_candidate[index_of_max_candidate, 0]].append(next_interval_for_each_candidate[index_of_max_candidate, 1])
+        else:
+            cond_set[next_interval_for_each_candidate[index_of_max_candidate, 0]] = [next_interval_for_each_candidate[index_of_max_candidate, 1]]
+
+        if next_interval_for_each_candidate[index_of_max_candidate, 1] == 2:
+            num_twos += 1
+            if num_twos >= MAX_NUM_SECOND_INTERVALS:
+                print("\nMaximum number of second intervals reached\n\n")
+                still_significant = False
+
+        next_interval_for_each_candidate[index_of_max_candidate, 1] += 1
+
+        print("\nCandidate added\n\n")
+    else:
+        still_significant = False
+        print("\nLost Significance\n\n")
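
The significance test in this loop is a maximum-statistic test: because the loop always selects the largest of many candidate TE values, comparing that value only against its own surrogates would inflate the false-positive rate, so it is instead compared against the distribution of the per-surrogate maximum taken across all candidates. A toy sketch of just this step, with made-up numbers shaped like debiased_surrogate_vals above:

import numpy as np

rng = np.random.default_rng(1)
# 20 hypothetical candidates x 100 debiased surrogate TE values each
debiased_surrogate_vals = rng.normal(0.0, 0.01, size=(20, 100))
best_debiased_te = 0.03  # debiased TE of the strongest candidate

max_dist = np.max(debiased_surrogate_vals, axis=0)  # one maximum per surrogate round
p_val = np.mean(max_dist > best_debiased_te)        # fraction of maxima above the estimate
print(p_val)

This fraction equals the sorted/searchsorted computation in the loop, and taking the maximum over candidates is what controls the family-wise error rate of the selection.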
+
+print("**** Pruning Sources ****\n")
+# Repeatedly remove the connection with the lowest TE out of all currently
+# insignificant connections, re-testing the remainder after each removal.
+# Only the furthest interval of each source is considered as a removal
+# candidate in each round.
+everything_significant = False
+while not everything_significant:
+    print("Current conditioning set:")
+    for key in cond_set.keys():
+        print("source", key, "intervals", cond_set[key])
+    print("\nEstimating TE on candidate sources")
+    everything_significant = True
+    insignificant_sources = []
+    insignificant_sources_TE = []
+    for candidate_source in cond_set:
+        cond_set_minus_candidate = copy.deepcopy(cond_set)
+        # If the source has more than one interval, drop its last interval
+        if len(cond_set_minus_candidate[candidate_source]) > 1:
+            cond_set_minus_candidate[candidate_source] = cond_set_minus_candidate[candidate_source][:-1]
+        # Otherwise, remove the source from the dict entirely
+        else:
+            cond_set_minus_candidate.pop(candidate_source)
+        teCalc.setProperty("SOURCE_PAST_INTERVALS", str(cond_set[candidate_source][-1]))
+        cond_trains = prepare_conditional_trains(teCalc, cond_set_minus_candidate)
+        teCalc.startAddObservations()
+        if len(cond_set_minus_candidate) > 0:
+            teCalc.addObservations(JArray(JDouble, 1)(spikes[candidate_source]),
+                                   JArray(JDouble, 1)(spikes[target_index]), JArray(JDouble, 2)(cond_trains))
+        else:
+            teCalc.addObservations(JArray(JDouble, 1)(spikes[candidate_source]),
+                                   JArray(JDouble, 1)(spikes[target_index]))
+        teCalc.finaliseAddObservations()
+        TE = teCalc.computeAverageLocalOfObservations()
+        sig = teCalc.computeSignificance(NUM_SURROGATES_PER_TE_VAL, TE)
+        print("Source", candidate_source, "Interval", cond_set[candidate_source][-1],
+              " TE:", str(round(TE, 2)), " p val:", sig.pValue)
+        if sig.pValue > P_LEVEL:
+            everything_significant = False
+            insignificant_sources.append(candidate_source)
+            insignificant_sources_TE.append(TE)
+    if not everything_significant:
+        min_TE_source = insignificant_sources[np.argmin(insignificant_sources_TE)]
+        print("removing source", min_TE_source, "interval", cond_set[min_TE_source][-1])
+        if len(cond_set[min_TE_source]) > 1:
+            cond_set[min_TE_source] = cond_set[min_TE_source][:-1]
+        else:
+            cond_set.pop(min_TE_source)
+
+
+print("\n\n****** Final Inferred Source Set ******\n")
+for key in cond_set.keys():
+    print("source", key, "intervals", cond_set[key])
+print("\nTrue Sources:")
+for con in cons:
+    if con[1] == target_index:
+        print(con[0], end=" ")
+
+output_file = open(OUTPUT_FILE_PREFIX + ".pk", 'wb')
+pickle.dump(cond_set, output_file)
+#pickle.dump(surrogate_vals_at_each_round, output_file)
+#pickle.dump(TE_vals_at_each_round, output_file)
+output_file.close()
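
Only the final conditioning set is pickled (the per-round TE and surrogate dumps are left commented out). A sketch of reading a result back, using the file name the hypothetical "test 1000 0 1" invocation from earlier would produce:

import pickle

# File name follows OUTPUT_FILE_PREFIX above with the example arguments
with open("results/inferred_sources_target_2_test_1000_0_1.pk", "rb") as f:
    cond_set = pickle.load(f)
for source, intervals in sorted(cond_set.items()):
    print("source", source, "uses past intervals", intervals)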

diff --git a/demos/python/EffectiveNetworkInference/spk_to_pk.py b/demos/python/EffectiveNetworkInference/spk_to_pk.py
new file mode 100644
index 0000000..2bef589
--- /dev/null
+++ b/demos/python/EffectiveNetworkInference/spk_to_pk.py
@@ -0,0 +1,43 @@
+# Converts a .spk file of recorded spike times (one comma-separated line of
+# timestamps per channel) into the pickle format read by net_inf.py.
+# Argument: a repeat/label string used in the output file names.
+
+import numpy as np
+import pickle
+import sys
+import matplotlib.pyplot as plt
+
+RUN = "1-1-20.2"
+
+spk_file = open('extracted_data_wagenaar/1-1/' + RUN + '.spk', 'r')
+time_upper = 8 * 60 * 60 * 2.5e4  # eight hours, assuming 25 kHz timestamps
+
+spikes = []
+for line in spk_file:
+    line = line.strip()
+    line = line.split(",")
+    line = [float(time) for time in line if time != ""]
+    spikes.append(np.array(line))
+
+# Align all trains to the earliest spike and keep the first time_upper ticks
+start_times = [train[0] for train in spikes if len(train) > 0]
+lowest_start_time = min(start_times)
+cutoff_time = lowest_start_time + time_upper
+
+for i in range(len(spikes)):
+    spikes[i] = spikes[i][spikes[i] < cutoff_time]
+    spikes[i] = spikes[i] - lowest_start_time
+    # Dither each timestamp by up to half a tick to break ties
+    spikes[i] = spikes[i] + np.random.uniform(size=spikes[i].shape) - 0.5
+    spikes[i] = np.sort(spikes[i])
+
+print(len(spikes))
+for i in range(len(spikes)):
+    print(spikes[i].shape)
+    print(spikes[i][:10])
+
+#plt.eventplot(spikes, linewidth = 0.5)
+#plt.show()
+
+spikes_file = open("spikes_LIF_" + RUN + "_" + sys.argv[1] + ".pk", "wb")
+pickle.dump(spikes, spikes_file)
+# No ground truth is available for recorded data, so write a dummy connection
+cons = [[0, 0]]
+connections_file = open("connections_LIF_" + RUN + "_" + sys.argv[1] + ".pk", "wb")
+pickle.dump(cons, connections_file)
+spikes_file.close()
+connections_file.close()
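
A quick sanity check for the pickles spk_to_pk.py writes (a sketch, assuming the script was run with "0" as its argument): every train should still be sorted after the shift and dithering, which the interval-based estimator in net_inf.py expects.

import pickle
import numpy as np

spikes = pickle.load(open("spikes_LIF_1-1-20.2_0.pk", "rb"))
for i, train in enumerate(spikes):
    assert np.all(np.diff(train) >= 0), "train %d is not sorted" % i
print("channels:", len(spikes), "total spikes:", sum(len(t) for t in spikes))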