From ddc3350c6c812961971f2a61d27cf645bd2cf247 Mon Sep 17 00:00:00 2001 From: ago109 Date: Sun, 21 Jul 2024 19:44:56 -0400 Subject: [PATCH 1/3] added tistdp exhibit --- exhibits/time-integrated-stdp/README.md | 75 ++++ .../time-integrated-stdp/analyze_models.sh | 74 ++++ exhibits/time-integrated-stdp/bind.sh | 23 ++ exhibits/time-integrated-stdp/bind_labels.py | 137 ++++++++ .../time-integrated-stdp/custom/__init__.py | 1 + .../custom/ti_stdp_synapse.py | 98 ++++++ exhibits/time-integrated-stdp/eval.py | 158 +++++++++ exhibits/time-integrated-stdp/eval.sh | 23 ++ .../time-integrated-stdp/extract_codes.py | 86 +++++ .../time-integrated-stdp/fig/tistdp_snn.jpg | Bin 0 -> 28511 bytes .../time-integrated-stdp/harvest_latents.sh | 36 ++ .../json_files/config.json | 3 + .../json_files/modules.json | 16 + .../patch_model/assemble_patterns.py | 159 +++++++++ .../patch_model/custom/LCNSynapse.py | 114 ++++++ .../patch_model/custom/__init__.py | 3 + .../patch_model/custom/patch_utils.py | 95 +++++ .../patch_model/custom/ti_STDP_LCNSynapse.py | 111 ++++++ .../patch_model/custom/ti_STDP_Synapse.py | 97 ++++++ .../patch_model/patch_tistdp_snn.py | 235 +++++++++++++ .../patch_model/patched_train.py | 192 +++++++++++ .../patch_model/sample_model.sh | 29 ++ .../patch_model/train_by_patch.py | 167 +++++++++ .../patch_model/train_patch_models.sh | 40 +++ exhibits/time-integrated-stdp/snn_case1.py | 321 +++++++++++++++++ exhibits/time-integrated-stdp/snn_case2.py | 324 ++++++++++++++++++ exhibits/time-integrated-stdp/train.py | 206 +++++++++++ exhibits/time-integrated-stdp/train_models.sh | 56 +++ exhibits/time-integrated-stdp/viz_codes.py | 29 ++ 29 files changed, 2908 insertions(+) create mode 100755 exhibits/time-integrated-stdp/README.md create mode 100755 exhibits/time-integrated-stdp/analyze_models.sh create mode 100755 exhibits/time-integrated-stdp/bind.sh create mode 100755 exhibits/time-integrated-stdp/bind_labels.py create mode 100755 
exhibits/time-integrated-stdp/custom/__init__.py create mode 100755 exhibits/time-integrated-stdp/custom/ti_stdp_synapse.py create mode 100755 exhibits/time-integrated-stdp/eval.py create mode 100755 exhibits/time-integrated-stdp/eval.sh create mode 100755 exhibits/time-integrated-stdp/extract_codes.py create mode 100755 exhibits/time-integrated-stdp/fig/tistdp_snn.jpg create mode 100755 exhibits/time-integrated-stdp/harvest_latents.sh create mode 100644 exhibits/time-integrated-stdp/json_files/config.json create mode 100644 exhibits/time-integrated-stdp/json_files/modules.json create mode 100755 exhibits/time-integrated-stdp/patch_model/assemble_patterns.py create mode 100755 exhibits/time-integrated-stdp/patch_model/custom/LCNSynapse.py create mode 100755 exhibits/time-integrated-stdp/patch_model/custom/__init__.py create mode 100755 exhibits/time-integrated-stdp/patch_model/custom/patch_utils.py create mode 100755 exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_LCNSynapse.py create mode 100644 exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_Synapse.py create mode 100755 exhibits/time-integrated-stdp/patch_model/patch_tistdp_snn.py create mode 100755 exhibits/time-integrated-stdp/patch_model/patched_train.py create mode 100755 exhibits/time-integrated-stdp/patch_model/sample_model.sh create mode 100755 exhibits/time-integrated-stdp/patch_model/train_by_patch.py create mode 100755 exhibits/time-integrated-stdp/patch_model/train_patch_models.sh create mode 100755 exhibits/time-integrated-stdp/snn_case1.py create mode 100755 exhibits/time-integrated-stdp/snn_case2.py create mode 100755 exhibits/time-integrated-stdp/train.py create mode 100755 exhibits/time-integrated-stdp/train_models.sh create mode 100755 exhibits/time-integrated-stdp/viz_codes.py diff --git a/exhibits/time-integrated-stdp/README.md b/exhibits/time-integrated-stdp/README.md new file mode 100755 index 0000000..36475e8 --- /dev/null +++ b/exhibits/time-integrated-stdp/README.md @@ 
-0,0 +1,75 @@ +# Time-Integrated Spike-Timing-Dependent Plasticity + +Version: ngclearn>=1.2.beta1, ngcsimlib==0.3.beta4 + +This exhibit contains an implementation of the spiking neuronal model +and credit assignment process proposed and studied in: + +Gebhardt, William, and Alexander G. Ororbia. "Time-Integrated Spike-Timing- +Dependent Plasticity." arXiv preprint arXiv:2407.10028 (2024). + +

+
+ Visual depiction of the TI-STDP-adapted SNN architecture. +

+ + + +## Running and Analyzing the Model Simulations + +### Unsupervised Digit-Level Biophysical Model + +To run the main TI-STDP SNN experiments of the paper, simply execute: + +```console +$ ./train_models.sh 0 tistdp snn_case1 +``` + +which will trigger three experimental trials for adaptation of the +`Case 1` model described in the paper on MNIST. If you want to train +the online, `Case 2` model described in the paper on MNIST, you simply +need to change the third argument to the Bash script like so: + +```console +$ ./train_models.sh 0 tistdp snn_case2 +``` + +Independent of whichever case-study you select above, you can analyze the +trained models, in accordance with what was done in the paper, by executing +the analysis bash script as follows: + +```console +$ ./analyze_models.sh 0 tistdp ## run on GPU 0 the "tistdp" config +``` + +Task: Models under this section engage in unsupervised representation +learning and jointly learn, through spike-timing driven credit assignment, +low-level and higher-level abstract, distributed, discrete representations +of sensory input data. In this exhibit, this is particularly focused on +using patterns in the MNIST database. + +### Part-Whole Assembly SNN Model + +To run the patch-model SNN adapted with TI-STDP, enter the `patch_model/` +sub-directory and then execute: + +```console +$ ./train_patch_models.sh +``` + +which will run a single trial to produce the SNN generative assembly +(or part-whole hierarchical) model constructed in the paper. + +Task: This biophysical model engages in a form of unsupervised +representation learning that is focused on learning a simple bi-level +part-whole hierarchy of sensory input data (in this exhibit, the focus is on +using data from the MNIST database). + +## Model Descriptions, Hyperparameters, and Configuration Details + +Model explanations, meta-parameter settings and experimental details are +provided in the above reference paper. 
diff --git a/exhibits/time-integrated-stdp/analyze_models.sh b/exhibits/time-integrated-stdp/analyze_models.sh new file mode 100755 index 0000000..5d4db3b --- /dev/null +++ b/exhibits/time-integrated-stdp/analyze_models.sh @@ -0,0 +1,74 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "trstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'evstdp', 'trstdp', 'tistdp' models supported!" + exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +SEEDS=(1234 77 811) + +PARAM_SUBDIR="/custom" +DISABLE_ADAPT_AT_EVAL=False ## set to true to turn off eval-time adaptive thresholds +MAKE_CLUSTER_PLOT=False #True +REBIND_LABELS=0 ## rebind labels to train model? + +N_SAMPLES=50000 +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" +DEV_X="../../data/mnist/testX.npy" # validX.npy +DEV_Y="../../data/mnist/testY.npy" # validY.npy +EXTRACT_TRAINING_SPIKES=0 # set to 1 if you want to extract training set codes + +for seed in "${SEEDS[@]}" +do + EXP_DIR="exp_$MODEL""_$seed/" + echo " > Running Simulation/Model: $EXP_DIR" + + CODEBOOK=$EXP_DIR"training_codes.npy" + TEST_CODEBOOK=$EXP_DIR"test_codes.npy" + PLOT_FNAME=$EXP_DIR"codes.jpg" + + if [[ $REBIND_LABELS == 1 ]]; then + CUDA_VISIBLE_DEVICES=$GPU_ID python bind_labels.py --dataX=$DATA_X --dataY=$DATA_Y \ + --model_type=$MODEL \ + --model_dir=$EXP_DIR$MODEL \ + --n_samples=$N_SAMPLES \ + --exp_dir=$EXP_DIR \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR + fi + + ## eval model +# CUDA_VISIBLE_DEVICES=$GPU_ID python eval.py --dataX=$DEV_X --dataY=$DEV_Y \ +# --model_type=$MODEL --model_dir=$EXP_DIR$MODEL \ +# --label_fname=$EXP_DIR"binded_labels.npy" \ +# --exp_dir=$EXP_DIR \ +# --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ +# --param_subdir=$PARAM_SUBDIR \ +# --make_cluster_plot=$MAKE_CLUSTER_PLOT + ## call codebook extraction processes + if [[ 
$EXTRACT_TRAINING_SPIKES == 1 ]]; then + CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DATA_X \ + --n_samples=$N_SAMPLES \ + --codebook_fname=$CODEBOOK \ + --model_type=$MODEL \ + --model_fname=$EXP_DIR$MODEL \ + --disable_adaptation=False \ + --param_subdir=$PARAM_SUBDIR + fi + CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DEV_X \ + --codebook_fname=$TEST_CODEBOOK \ + --model_type=$MODEL \ + --model_fname=$EXP_DIR$MODEL \ + --disable_adaptation=False \ + --param_subdir=$PARAM_SUBDIR + ## visualize latent codes + CUDA_VISIBLE_DEVICES=$GPU_ID python viz_codes.py --plot_fname=$PLOT_FNAME \ + --codes_fname=$TEST_CODEBOOK \ + --labels_fname=$DEV_Y +done diff --git a/exhibits/time-integrated-stdp/bind.sh b/exhibits/time-integrated-stdp/bind.sh new file mode 100755 index 0000000..d08f116 --- /dev/null +++ b/exhibits/time-integrated-stdp/bind.sh @@ -0,0 +1,23 @@ +#!/bin/bash +GPU_ID=1 #0 + +N_SAMPLES=10000 +DISABLE_ADAPT_AT_EVAL=False + +EXP_DIR="exp_trstdp/" +MODEL="trstdp" +#EXP_DIR="exp_evstdp/" +#MODEL="evstdp" +DEV_X="../../data/mnist/trainX.npy" # validX.npy +DEV_Y="../../data/mnist/trainY.npy" # validY.npy +PARAM_SUBDIR="/custom_snapshot2" +#PARAM_SUBDIR="/custom" + +## eval model +CUDA_VISIBLE_DEVICES=$GPU_ID python bind_labels.py --dataX=$DEV_X --dataY=$DEV_Y \ + --model_type=$MODEL \ + --model_dir=$EXP_DIR$MODEL \ + --n_samples=$N_SAMPLES \ + --exp_dir=$EXP_DIR \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR diff --git a/exhibits/time-integrated-stdp/bind_labels.py b/exhibits/time-integrated-stdp/bind_labels.py new file mode 100755 index 0000000..2e7a339 --- /dev/null +++ b/exhibits/time-integrated-stdp/bind_labels.py @@ -0,0 +1,137 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + 
+################################################################################ +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", + "model_dir=", "model_type=", + "exp_dir=", "n_samples=", + "disable_adaptation=", + "param_subdir="]) + +model_case = "snn_case1" +disable_adaptation = True +exp_dir = "exp/" +param_subdir = "/custom" +model_type = "tistdp" +model_dir = "exp/tistdp" +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +n_samples = 10000 +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ('--model_dir'): + model_dir = arg.strip() + elif opt in ('--model_type'): + model_type = arg.strip() + elif opt in ('--exp_dir'): + exp_dir = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ('--n_samples'): + n_samples = int(arg.strip()) + elif opt in ('--disable_adaptation'): + disable_adaptation = (arg.strip().lower() == "true") + print(" > Disable short-term adaptation? ", disable_adaptation) + +if model_case == "snn_case1": + print(" >> Setting up Case 1 model!") + from snn_case1 import load_from_disk, get_nodes +elif model_case == "snn_case2": + print(" >> Setting up Case 2 model!") + from snn_case2 import load_from_disk, get_nodes +else: + print("Error: No other model case studies supported! 
(", model_case, " invalid)") + exit() + +print(">> X: {} Y: {}".format(dataX, dataY)) + +dkey = random.PRNGKey(1234) +dkey, *subkeys = random.split(dkey, 3) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if 0 < n_samples < _X.shape[0]: + ptrs = random.permutation(subkeys[0], _X.shape[0])[0:n_samples] + _X = _X[ptrs, :] + _Y = _Y[ptrs, :] + # _X = _X[0:n_samples, :] + # _Y = _Y[0:n_samples, :] + print("-> Binding {} first randomly selected samples to model".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 1000 #10000 +mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (28, 28) +in_dim = patch_shape[0] * patch_shape[1] + +T = 250 #300 ## num time steps to simulate (stimulus presentation window length) +dt = 1. ## integration time constant + +################################################################################ +print("--- Loading Model ---") + +## Load in model +model = load_from_disk(model_dir, param_dir=param_subdir, + disable_adaptation=disable_adaptation) +nodes, node_map = get_nodes(model) + +################################################################################ +print("--- Starting Binding Process ---") + +print("------------------------------------") +model.showStats(-1) + +## enter main adaptation loop over data patterns +class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units)) +num_bound = 0 +n_total_samp_seen = 0 +tstart = time.time() +n_samps_seen = 0 +for j in range(n_batches): + idx = j + Xb = _X[idx: idx + mb_size, :] + Yb = _Y[idx: idx + mb_size, :] + + model.reset() + model.clamp(Xb) + spikes1, spikes2 = model.infer( + jnp.array([[dt * k, dt] for k in range(T)])) + ## bind output spike train(s) + responses = Yb.T * jnp.sum(spikes2, axis=0) + class_responses = class_responses + responses + num_bound += 1 + + n_samps_seen += Xb.shape[0] + n_total_samp_seen += Xb.shape[0] 
+ print("\r Binding {} images...".format(n_samps_seen), end="") +tend = time.time() +print() +sim_time = tend - tstart +sim_time_hr = (sim_time/3600.0) # convert time to hours +print(" -> Binding.Time = {} s".format(sim_time_hr)) +print("------------------------------------") + +## compute max-frequency (~firing rate) spike responses +class_responses = jnp.argmax(class_responses, axis=0, keepdims=True) +print("---- Max Class Responses ----") +print(class_responses) +print(class_responses.shape) +bind_fname = "{}binded_labels.npy".format(exp_dir) +print(" >> Saving label bindings to: ", bind_fname) +jnp.save(bind_fname, class_responses) + + + + diff --git a/exhibits/time-integrated-stdp/custom/__init__.py b/exhibits/time-integrated-stdp/custom/__init__.py new file mode 100755 index 0000000..0129e54 --- /dev/null +++ b/exhibits/time-integrated-stdp/custom/__init__.py @@ -0,0 +1 @@ +from .ti_stdp_synapse import TI_STDP_Synapse diff --git a/exhibits/time-integrated-stdp/custom/ti_stdp_synapse.py b/exhibits/time-integrated-stdp/custom/ti_stdp_synapse.py new file mode 100755 index 0000000..439db30 --- /dev/null +++ b/exhibits/time-integrated-stdp/custom/ti_stdp_synapse.py @@ -0,0 +1,98 @@ +import time + +from ngclearn import Component, Compartment, resolver +from ngclearn.components import DenseSynapse +from ngclearn.utils.model_utils import normalize_matrix + +from jax import numpy as jnp, random +import os.path + +class TI_STDP_Synapse(DenseSynapse): + def __init__(self, name, shape, alpha=0.0075, beta=0.5, pre_decay=0.5, + resist_scale=1., p_conn=1., weight_init=None, **kwargs): + super().__init__(name=name, shape=shape, resist_scale=resist_scale, + p_conn=p_conn, weight_init=weight_init, **kwargs) + + #Params + self.batch_size = 1 + self.alpha = alpha + self.beta = beta + self.pre_decay = pre_decay + self.Aplus = 1 + self.Aminus = 1 + + #Compartments + self.pre = Compartment(None) + self.post = Compartment(None) + + self.reset() + + @staticmethod + def _norm(weights, 
norm_scale): + return normalize_matrix(weights, wnorm=100, scale=norm_scale) + + @resolver(_norm) + def norm(self, weights): + self.weights.set(weights) + + + @staticmethod + def _reset(batch_size, shape): + return jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])), \ + jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])) + + @resolver(_reset) + def reset(self, inputs, outputs, pre, post): + self.inputs.set(inputs) + self.outputs.set(outputs) + self.pre.set(pre) + self.post.set(post) + + @staticmethod + def _evolve(pre, post, weights, + shape, alpha, beta, pre_decay, Aplus, Aminus, + t, dt): + pre_size = shape[0] + post_size = shape[1] + + pre_synaptic_binary_mask = jnp.where(pre > 0, 1., 0.) + post_synaptic_binary_mask = jnp.where(post > 0, 1., 0.) + + broadcast_pre_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(pre_synaptic_binary_mask, (pre_size, 1)) + broadcast_post_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(post_synaptic_binary_mask, (1, post_size)) + + broadcast_pre_synaptic_time = jnp.zeros(shape) + jnp.reshape(pre, (pre_size, 1)) + broadcast_post_synaptic_time = jnp.zeros(shape) + jnp.reshape(post, (1, post_size)) + + no_pre_synaptic_spike_mask = broadcast_post_synaptic_binary_mask * (1. 
- broadcast_pre_synaptic_binary_mask) + no_pre_synaptic_weight_update = (-alpha) * (pre_decay / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + # no_pre_synaptic_weight_update = no_pre_synaptic_weight_update + + # Both have spiked + both_spike_mask = broadcast_post_synaptic_binary_mask * broadcast_pre_synaptic_binary_mask + both_spike_update = (-alpha / (broadcast_pre_synaptic_time - broadcast_post_synaptic_time - (0.5 * dt))) * \ + (beta / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + + masked_no_pre_synaptic_weight_update = no_pre_synaptic_spike_mask * no_pre_synaptic_weight_update + masked_both_spike_update = both_spike_mask * both_spike_update + + plasticity = jnp.where(masked_both_spike_update > 0, + Aplus * masked_both_spike_update, + Aminus * masked_both_spike_update) + + decay = masked_no_pre_synaptic_weight_update + + plasticity = plasticity * (1 - weights) + + decay = decay * weights + + update = plasticity + decay + _W = weights + update + + return jnp.clip(_W, 0., 1.) 
+ + @resolver(_evolve) + def evolve(self, weights): + self.weights.set(weights) diff --git a/exhibits/time-integrated-stdp/eval.py b/exhibits/time-integrated-stdp/eval.py new file mode 100755 index 0000000..e4abb5c --- /dev/null +++ b/exhibits/time-integrated-stdp/eval.py @@ -0,0 +1,158 @@ +from jax import numpy as jnp, random, nn, jit +import numpy as np, time +import sys, getopt as gopt, optparse +## bring in ngc-learn analysis tools +from ngclearn.utils.viz.raster import create_raster_plot +import ngclearn.utils.metric_utils as metrics + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +def plot_clusters(W1, W2, figdim=50): ## generalized dual-layer plotting code + n_input = W1.weights.value.shape[0] ## input layer + n_hid1 = W1.weights.value.shape[1] ## layer size 1 + n_hid2 = W2.weights.value.shape[1] ## layer size 2 + ndim0 = int(jnp.sqrt(n_input)) ## sqrt(layer size 0) + ndim1 = int(jnp.sqrt(n_hid1)) ## sqrt(layer size 1) + ndim2 = int(jnp.sqrt(n_hid2)) ## sqrt(layer size 2) + plt.figure(figsize=(figdim, figdim)) + plt.subplots_adjust(hspace=0.1, wspace=0.1) + + for q in range(W2.weights.value.shape[1]): + masked = W1.weights.value * W2.weights.value[:, q] + + dim = ((ndim0 * ndim1) + (ndim1 - 1)) #(28 * 10) + (10 - 1) + + full = jnp.ones((dim, dim)) * jnp.amax(masked) + + for k in range(n_hid1): + r = k // ndim1 #k // 10 #sqrt(hidden layer size) + c = k % ndim1 # k % 10 + + full = full.at[(r * (ndim0 + 1)):(r + 1) * ndim0 + r, + (c * (ndim0 + 1)):(c + 1) * ndim0 + c].set( + jnp.reshape(masked[:, k], (ndim0, ndim0))) + + plt.subplot(ndim2, ndim2, q + 1) # 5 = sqrt(output layer size) + plt.imshow(full, cmap=plt.cm.bone, interpolation='nearest') + plt.axis("off") + + plt.subplots_adjust(top=0.9) + plt.savefig("{}clusters.jpg".format(exp_dir), bbox_inches='tight') + plt.clf() + plt.close() + +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", 
+ "model_dir=", "model_type=", + "label_fname=", "exp_dir=", + "param_subdir=", + "disable_adaptation=", + "make_cluster_plot="]) + +model_case = "snn_case1" +exp_dir = "exp/" +label_fname = "exp/labs.npy" +model_type = "tistdp" +model_dir = "exp/tistdp" +param_subdir = "/custom" +disable_adaptation = True +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +make_cluster_plot = True +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ('--model_dir'): + model_dir = arg.strip() + elif opt in ('--model_type'): + model_type = arg.strip() + elif opt in ('--label_fname'): + label_fname = arg.strip() + elif opt in ('--exp_dir'): + exp_dir = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ('--disable_adaptation'): + disable_adaptation = (arg.strip().lower() == "true") + print(" > Disable short-term adaptation? ", disable_adaptation) + elif opt in ('--make_cluster_plot'): + make_cluster_plot = (arg.strip().lower() == "true") + print(" > Make cluster plot? ", make_cluster_plot) + +if model_case == "snn_case1": + print(" >> Setting up Case 1 model!") + from snn_case1 import load_from_disk, get_nodes +elif model_case == "snn_case2": + print(" >> Setting up Case 2 model!") + from snn_case2 import load_from_disk, get_nodes +else: + print("Error: No other model case studies supported! (", model_case, " invalid)") + exit() + +print(">> X: {} Y: {}".format(dataX, dataY)) + +T = 250 # 300 +dt = 1. 
+ +## load dataset +batch_size = 1 #100 +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +n_batches = int(_X.shape[0]/batch_size) +patch_shape = (28, 28) + +dkey = random.PRNGKey(time.time_ns()) +model = load_from_disk(model_dir, param_dir=param_subdir, + disable_adaptation=disable_adaptation) +nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i") +W1, W1ie, W1ei, z0, z1e, z1i, W2, W2ie, W2ei, z2e, z2i = nodes + +if make_cluster_plot: + ## plot clusters formed by 2nd layer of model's spikes + print(" >> Creating model cluster plot...") + plot_clusters(W1, W2) + +## extract label bindings +bindings = jnp.load(label_fname) + +acc = 0. +Ns = 0. +yMu = [] +for i in range(n_batches): + Xb = _X[i * batch_size:(i+1) * batch_size, :] + Yb = _Y[i * batch_size:(i+1) * batch_size, :] + Ns += Xb.shape[0] + model.reset() + model.clamp(Xb) + spikes1, spikes2 = model.infer( + jnp.array([[dt * k, dt] for k in range(T)])) + winner = jnp.argmax(jnp.sum(spikes2, axis=0)) + yHat = nn.one_hot(bindings[:, winner], num_classes=Yb.shape[1]) + yMu.append(yHat) + acc = metrics.measure_ACC(yHat, Yb) + acc + print("\r Acc = {} (over {} samples)".format(acc/Ns, i+1), end="") +print() +print("===============================================") +yMu = jnp.concatenate(yMu, axis=0) +conf_matrix, precision, recall, misses, acc, adj_acc = metrics.analyze_scores(yMu, _Y) +print(conf_matrix) +print("---") +print(" >> Number of Misses = {}".format(misses)) +print(" >> Acc = {} Precision = {} Recall = {}".format(acc, precision, recall)) +msg = "{}".format(conf_matrix) +jnp.save("{}confusion.npy".format(exp_dir), conf_matrix) +msg = ("{}\n---\n" + "Number of Misses = {}\n" + "Acc = {} Adjusted-Acc = {}\n" + "Precision = {}\nRecall = {}\n").format(msg, misses, acc, adj_acc, precision, recall) +fd = open("{}scores.txt".format(exp_dir), "w") +fd.write(msg) +fd.close() + diff --git a/exhibits/time-integrated-stdp/eval.sh b/exhibits/time-integrated-stdp/eval.sh 
new file mode 100755 index 0000000..0fd21f7 --- /dev/null +++ b/exhibits/time-integrated-stdp/eval.sh @@ -0,0 +1,23 @@ +#!/bin/bash +GPU_ID=1 #0 + +EXP_DIR="exp_trstdp/" +MODEL="trstdp" +#EXP_DIR="exp_evstdp/" +#MODEL="evstdp" +PARAM_SUBDIR="/custom_snapshot2" +#PARAM_SUBDIR="/custom" +DISABLE_ADAPT_AT_EVAL=False ## set to true to turn off eval-time adaptive thresholds +MAKE_CLUSTER_PLOT=False + +DEV_X="../../data/mnist/testX.npy" # validX.npy +DEV_Y="../../data/mnist/testY.npy" # validY.npy + +# eval model +CUDA_VISIBLE_DEVICES=$GPU_ID python eval.py --dataX=$DEV_X --dataY=$DEV_Y \ + --model_type=$MODEL --model_dir=$EXP_DIR$MODEL \ + --label_fname=$EXP_DIR"binded_labels.npy" \ + --exp_dir=$EXP_DIR \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR \ + --make_cluster_plot=$MAKE_CLUSTER_PLOT diff --git a/exhibits/time-integrated-stdp/extract_codes.py b/exhibits/time-integrated-stdp/extract_codes.py new file mode 100755 index 0000000..43f3e42 --- /dev/null +++ b/exhibits/time-integrated-stdp/extract_codes.py @@ -0,0 +1,86 @@ +from jax import numpy as jnp, random, nn, jit +import numpy as np, time +import sys, getopt as gopt, optparse +## bring in ngc-learn analysis tools +from ngclearn.utils.viz.raster import create_raster_plot +from ngclearn.utils.metric_utils import measure_ACC + +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", + "codebook_fname=", + "model_fname=", + "n_samples=", "model_type=", + "param_subdir=", + "disable_adaptation="]) + +model_case = "snn_case1" +n_samples = -1 +model_type = "tistdp" +model_fname = "exp/tistdp" +param_subdir = "/custom" +disable_adaptation = True +codebook_fname = "codes.npy" +dataX = "../../data/mnist/trainX.npy" +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--codebook_fname"): + codebook_fname = arg.strip() + elif opt in ("--model_fname"): + model_fname = arg.strip() + elif opt in ("--n_samples"): + 
n_samples = int(arg.strip()) + elif opt in ('--model_type'): + model_type = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ('--disable_adaptation'): + disable_adaptation = (arg.strip().lower() == "true") + print(" > Disable short-term adaptation? ", disable_adaptation) + +if model_case == "snn_case1": + print(" >> Setting up Case 1 model!") + from snn_case1 import load_from_disk, get_nodes +elif model_case == "snn_case2": + print(" >> Setting up Case 2 model!") + from snn_case2 import load_from_disk, get_nodes +else: + print("Error: No other model case studies supported! (", model_case, " invalid)") + exit() + +print(">> X: {}".format(dataX)) + +## load dataset +batch_size = 1 #100 +_X = jnp.load(dataX) +if 0 < n_samples < _X.shape[0]: + _X = _X[0:n_samples, :] +n_batches = int(_X.shape[0]/batch_size) +patch_shape = (28, 28) + +dkey = random.PRNGKey(time.time_ns()) +model = load_from_disk(model_directory=model_fname, param_dir=param_subdir, + disable_adaptation=disable_adaptation) + +T = 250 # 300 +dt = 1. + +codes = [] ## get latent (spike) codes +acc = 0. +Ns = 0. 
+for i in range(n_batches): + Xb = _X[i * batch_size:(i + 1) * batch_size, :] + Ns += Xb.shape[0] + model.reset() + model.clamp(Xb) + spikes1, spikes2 = model.infer( + jnp.array([[dt * k, dt] for k in range(T)])) + counts = jnp.sum(spikes2, axis=0) ## get counts + codes.append(counts) + print("\r > Processed ({} samples)".format(Ns), end="") +print() + +print(" >> Saving code-book to disk: {}".format(codebook_fname)) +codes = jnp.concatenate(codes, axis=0) +print(" >> Code.shape = ", codes.shape) +jnp.save(codebook_fname, codes) diff --git a/exhibits/time-integrated-stdp/fig/tistdp_snn.jpg b/exhibits/time-integrated-stdp/fig/tistdp_snn.jpg new file mode 100755 index 0000000000000000000000000000000000000000..d0db7b20165a15e9fcb7cd1f8ca093c20a7d35ee GIT binary patch literal 28511 zcmb??1yEa2w{DQ)Qrw+Fad!#O;suI3l;Tj_HAr!X0)@7?dvTZI?obHsPJ#vtkNK1Kt79P*6}& zUZJ6)qN1atVPFwrV_{-qk>L~I5>k=VP*agpQqnSVGSkwtGf+~#6=GrM=Hch(r(qV6 z5at!<78zI14K4M0RdLPA7Fdi4q!`K5Qj%l80eyjS>iJklrx z>Sn0)E`+?nNx5hYGQWF>G^Qbpe4kxI&@qTfNXf|GFfqSnVdWPP6cQE@efM5gPF_J# zNmEN(M^{hZz}&*p%G$=(&duG!)63h(H}q>*ctm7WbaKkK)U@>P8JT(c1%*Y$C8cGx zb@dI6P0cN>e|r1+2L^|RM`mW{<`)+KE-i0vZSU;v?H?Q-K`$<^u5WH(clZCmg#bYM zpRis&|0l42!G-sN3lSL^2^sYtxDXJ%UJMBj`4t@x3cj>Ds+kJ`J#R1?p-fWl?;dmp zJ`D)bXV+;AVn+VWH_(4T`zNyh9a_Wo^N5XOIM z71(|QS=5l(3WK6TN?q*RPf`|%`^6(U*%S4JN`?BanlVnbb1H=`U9wuU0%#t_f|w6x zwBa6WA@AS*y+uvZIeJ4vymkBTVw9XPcJRI$b!)I8KG?JOB6^o@(dqMMh9%8u0;Cen zUmnwipEF`!RaSt1D2=$l*9e?f_O4Wa@|Kf8GS(Sn(4_14UAUq3W$3~GZrJ|{<3BjY z;vHbb%OF(HYMk-uK5{u6DgqlDu7Ty2rQO|qBqBl&p$(S8qQ+notydjbebr@JN z?~taBb8#cy^b}0FRrM|jQ8?_I(6HMP)t()%MmM_TJb@eGPU-CBV%&YUB>m(g*)VCR zQ>@YwM(;{7it4?Ql9E}I88G&-*M*~p4&q;{YBynFib_k)0_a9eYf*D?DAAf)T`_zH zlx58Rr3SKwMspxLJFY#Di#D*5PUrK7pvUOxvH zc$vHTVHSurtC_qGC)djfUbWdx6*cr2r%}4{M-RAbZ7*wfXwct3KhwjSi%igrMZQu= zpk|(2D5SD1Ela3L`!+8c@L$=kziq{lFTc*^U%E7+cw5Kq;FMED0i@f=D#2!>i*|yf zP<`Djv6j{(`q|25!Wf-~I-ecO(eO-u51#A{*XqOo3AVPhHL{z#sX?aM)U}+A2di?^ zXMnZ7!O9WIuoEQ@^Dps<(4bJU1gdZwq!3BL1&($EQ28C-s+}cMCc8meQ<#oU!g*Pc 
zXBfM*i12OsazxBEHfA4*pfs?dVgZv_Uq6jEc~T_8*3#&;93_@@^=F<&aYZvrOG^gZ zMbE~?WLZko;g1C)M=^=Gqsd3k&OV1d6MU2X{coPr*o+`I)XHO%?xg~w|( z7`fHI@HT~2-10L9X7Zpo4*LBRDho{?TM^`eFm$Yi+++Tv?#EYb#3D!rwaLubh)k@z z4x`I>3Y1BqsO9S+w$Ren_i)3cR=H?WOA!hHAgVPa@Ni?V)x~8q7h?szKKpH2tYWJ~ z!7~7VAG+f=U*18W(pjIm?DzJACY}6p5Prd@*oSzDK|(;LV6IE)--hx26a%4O>^8A| zXR@xjcYvU7M@o14dNwiT8TWOZ!$~J|WSz#caW-V__oyv@g34!W5L;BLOv1KLcv>qmS4=eA(D1m0p#76Zm90 zata<&H`Z6veg-63JOiX3`(3pcj=CLNS?&~T$to^<;HEI^A248CfxcmD3#M}bU6xc7$zQqO>~8HCf`riEk? z#wOQk%FuZgtES#_N|ra`0m#MhOdDtwh^fNXUkr+V&h)zxit@wrGfjs_-7*PSvhz=^ z`Zv|Xg+g`xBKLF^;zCAO!r|DB(=VYXc2-<$=ie^>;)Jx&oJ>ZhvC==7gbhk_w3la( zqyX_!n!+hZJe;4Oh{27I#~W5j^ndg=QPV-=>)gBE^4K<>)}S~N`2_`dGb5#_ zK%FA1IK|!aR*pqE`hOC0cl=o%IqDx;SweQ1Ck!M?**Ygx+kS`xD$jn|jUpN)T@B`h z9l89)f?G9y{gMF0(;m?=BsU?1+?0E4&MiE$R2Ul=%bg@xiw7r-+)My=@On(1{1$V) znFq2VNfq~jxu{D5<&vDV3>3e6KH+w$gA)ukBq_YRlp*_J)dZT&EI-Buq6;%>Lxric zIh4J&IQTSSNt+$kGQ8;4%?}(CJK$zZUFOAIw>w-V*JzteB1`l)Q@0(caG5ChF^KhO zLW&V9TWQxF_YQ;WA-pv|`6V@3NR^1`e+InXuTLO9Sgy)2_rB#>5;*h!&RL+H)aA1q zk04=)1f{MAt5MkMHZpntxGi<(C}{LC*MFTch7hlZMXVt&!VtvkqW{~Ue>ShawvjyS zP+7RoMs+Na6Mw^c_qYRaZ{5Xz1#>Ab2Y+tJvlG3j-_7 z@jv*L7~(67N-dnT+cR43@PyMgeJD34VDYQ%`lQq_DBqidC zC%PQ#)jW#MY}#8Jme<}~$Uy6eHsr-rNzSe~_az{J?POyI$fZ(4a{t2CY3J}RGS=uN zTA{5uDu(2-$o*{&b_|H%K-U7(!DQr1x$woF&d^7)$esfv!~k9vL!&T%ZR+41CF|tz zYUK7WXjQM{TPE~8E80M)(zd7(Ul_*HGWp0De(ivH913L;NnRMJk|U^by2DT+xy7hHeDLM>p75tTN?0TAhTnVzewBzLGI~%=Mux z#o$|CYc1JU_^gosu_T&E%N^iO2srwW(u3l1+qC6{B%i`HHVXS@12V*)R2AILe&l_y zP{dC6Krzl8qlBUW%5{7`VLe&uN~Cxl@Q2Q;>)y%pEekH65sPa18j6P)Fj$_qbu%< zW8e66H>IsqKE6I7Zi@^qkeWg1(R`@gxp3BPdZ#g|0Zk)5x}ex(&kXo}p)p5@pRsBs zJUAlrMFN5&)5s%*k_)-NSFXevK{t6Ib=X#C~TF#A4~gEvxM zK`SJ(;=_+{Xe@{A=DQg5C{!h-5x|LNMOiprhE^EY_ImWdXP>^2Kf}?H-;b=?{Nik< zDXv(hd3jgbEPiFLN~gYDE1%QO)%Fg(;8u7N#jZ8Av$Vup&G_fa5xHwcjd4<_$Cnj0 znw#K_hx-frx+)6wr}cq2vc>Rp#VjGoGRoIL(KGve%Di3jD8nq4$7{87yRC%JS*6ba z+TmxwD!gHEgpUB@q=C_Tc27%VOg9Kc#d|TnUANTMAS@xGp28#DpljE^r%!Qc0yhxX z1fZ#>X`i}MSC}Bz+!Ev5MxppnEUvAhGg= 
z26U|{njq!p?)uN%9(L{x->a|Nd)t( zIn`{|X1!SpUE`|U;1fB!ulEA6OOm>20Ux)o4h@LjX)N;pBA2+&Dz>(0o^WK=yS&Tn zFBl{n5#tjCVK)zlI0yVRzb5)6neK<8+}-C}JqA=NcY%Wbd^UF~8V#seNw_mF(r+J! zQ361sF>w560Oe7;%^H%Q-lTGz$?5081)j^m2F^@ThM1D3IUiDC{$a~_G!t7RMw`Nn z1SwQab>&UdNb+;h!D#nY0o=71r5c>T5JgEd_3V($SKZiUcA^92Tt@50X+{sMr`yZL zOF!5@2jJ2}18&NzDidN)t}J6d%x2%ZD${bgP;c-KFS#W>0}v(3ryfO0Y%=$(?(@xhYhtb69hk)doBKaFq3rv`!h9V?lLc4ufEe=8d4-J%dS$U z3_z=bFD$GLEQs0bX^_$zmy%<>qW)8)sZ%SEAC{cXKVFFhmob0pHP&`rqj`y(U(i)q z-|K2>tC_Q#Leb6uf-@m)!S61%MYfCBMz6TAy<>Vi@n+B1C8wzi#d+R=9Lus>@%rCt zBylP*=Z^msF04G!xX#jh>zCrJagjiMfB{k*AG(aV>!darFI@)-sFezpW@Q~;BU z($u!@M~Nfjx4*`JYKVfjeI#eZD~DWhK4VTcEsA;(c%Q5oqumyOY^4erVowuJz3+H& zcHT6`M))TR`4_Y>2`fn~jmO z$UOa|j3;q2u_-Ctx26BGZwi~@ao4~~H(+eH(XjHN2W+hsri3up^?q{7 zO#c~v-f(>I3iDcs_O4L?FiBu+33jvHo9li~WuR>7LNnC*O`Q*H6hEFy!y?f`ldg2t z@$I*}*-67sqRgXtunK3}64>4-=Ay}MimktsYk3W3+J7(9r>sZU*3_7A;w3oKxw0RK zc0S}uM)eZ0W!_L=J{YSGP=Hy|>}}KCUAg3_8+N+L{vN{_A)t_L#u#vqi*tT0Qa#Pk z2ngdc)5X)5U%wmH`vL0J$054kXWe^oa4(ns#VTa&bN}wJ@mLEW`%ubumv3?7j$g!=Css~vPhr@phy{=MOZ*rxpibw^&$l@)v0Zf3PKw9b z-FDug-tnT;RfJp^{_@)*ZmM`8FXkQcU&nA4_}Kr%=R2&&g%R5y*mE$e&0Y=hA1f>5 z)deHUVpL-QCw-yS(V|iiwb-j{DqAS;(ub`@&G-AE(_KU~AM)XSGX%p*py##cgwaGc*Kk-tSBHv5c!oUDQvxWjV9{qRDCbgFfzKM zqcz+)mM0|{FuBm4-7RE7l(l8dY~0Y4!WlT|;oRn0@HDWGG#8*Ssr09A?!%tHbd_=? 
zEGxjW7i|r+iuCxld0Db~Ah6*$^ESSnxx&!IhIf*Ez`vV{{*Bi@AoS#~$W0Y!Yaq(3 z&?Zx1_RCgNkQ&o_rsbZ z%rs1NcMOgey?#<>Bv3JlQM2g@>9)9c9GoRud*2Cc*QAbI=KN))HYqXXV}*TCCTV0& zWskHu5c3qV@#8I3YS9lN{BI%@KN{t8x*yr*-0|XsA6nq4?Qmqi7w;10QEBH7ix_NO z{R{ME!){d1{!Ai51X6)%Z&dQC3Gh?+Wz?mmT{ZObT#F6U6nbjuAksHAhRon$*W>l# zUnxwnCUk4n|1Y)~8IfB~mhj^%maB4>3xijhN7-B#6^D<4BHXI-bqPmozR~G~s&uMMg_@uXfez%n-iqSkpJzMSJ{(t5W{&~K@LRX_I_uDnH z1-Zdngx>>6atv`quidptks#I7kn=}v(|Wm%a+<6~>81xp=Bg#hO`7aqCgqiOE;g-H z@KQnXHf6705hdEpG=)Bc-736#?L;1xtCm#_F~iK(68c)FG_adX*c58nV&C_YZ8 zjDO?(eft*u$xRi?6l%W(!ZEd3480X{AzY39=2+Id@)~#3sJKVv-RUy`qyFd5j;;(O zP`B~#XFz;vwNwr+3yiJWSl+V4n^{lHLj^-8PAj_u$7QfYb0C1sex8bG(Z0k|B68e& zzn@M2k!ah2W=rrF1!-tyRWUh;94|<1-j&yvNa^}PxrD1?qpdN?vMEa-S;!=|zyc3O z4k(q3u@+P+jqCXM3{V`N-wuP^_G{g07EA zu#aMi^d^d38Y#pF`9Xcs!wNOo&V&ML9mc_7sNLGpFq0)KZ*Q2q#kY7CjwBu+eEdS~ zVm%V&dl#*1bS}g;Y4!uZ(ec)%l1|*v?xYOSoMkR$sS!LHD`T@xN;l8C=Z#9N?kxS4 z1BOKcu7{9BwsXiz0$E@z?Id5srSm8KiWOaW!Q+&?L~kSm>XHCN7skaP==d^K6dV7SrnS(6fefa@v9KohP7QKE+?f>59EIdKShDGW}xt z7xA+*CY^G2pywt(^090`KAgd=s_dqazCso%%i}QMyJzj=s=$UAg{tDDz8}^BZLQl!G+b#AcQAIN-4nlB($pZL8x!ZG!y7p*adl}UP(jD{7r%3B;^QGTlf$3D_w zd+xIj6Ph`I;MYT=TP&g{H{e!7bY7(0AwTq!11Cdk{Lz-Ma{ebd#<(3_krQoffE6(6 zQKr$iD36_^bF|nr(xM4Z>;nWbY&Gx;F~&F;)w)9;%CVPg@3F}{+saD{^=Y{I?B+yn{Q?!-2O*zJ0U5;B)8{q+#sg^p018y5`_H5 zp}()p&^gYNElVC1<49HPweLq%s;zN*b^E7471j%gR2CJVq0fizfSeDdbeJH{@46!Sw<6e2I|^BgLli#_`STBikdm$yJB5?fKj`N!V?o&oP0 z6~4!x09HT*DX_)wFfg1oA-_(7CBhrf!M**PZGJ?=c#a40oJGo#g`se)#^~e4u{PI?Dbsy`QB;cP; zj?wum$*9{#s68quJGNeMp84=oU6XI4EfH+QilkRqu7SZeNOPaU*=!g>2=%t&Wa!^^ zYa;UEek#uZOwNQ`PBS3v23QI3B2FVt&(tnRK*NGC zIGW(HvdJvJYHZT`m|Q-U)nS(&VLApGNMXLM*aT)`%oKas19Xrj@n&Gc86sF_n-G2a@Qxor6W+x zOoyhybfv&5#-&1gEQMjohYx^c^UV%e=4>=pO)F8A5 zyS=zngldy&N;Oznm#vUfm+dB)9rx_+D@bA{`(Z0c!gqO%O-#=r$2jp%4EfuWoKLtT zg?S-TWXyn%2qvc6lTl5?%ZZTeP&)WDlyQI5Di%fA)aZs?0%ihF{jt>*DcTg%kPh+c z?OGcPG#uAgWeD@7s{9LTfN`o~duAcQ{0hAZRCtyX>g@<{ugaHDCt0p&R*h~RFKm#8 z#9bhT{z8z>TcbJ0?2k~qmT14W#?Lb04u!fBaz++es$%Zd{!?5pQLVp7k)6ooawYu_ 
z;5jXDp*E2I@Wsi4aQvn1pZ zMZ7L7KlE%}OYkaOD|)am$z(zOO(oxaByqf_tW1MvXV8$|^U#7;Upu#^UxbR>jk1~W zh1l$Dd!{oLiqmAj*@e4qPZIIk;JpQdFuUK|TV_Cj{a&;CB)+Y$cJ=_Om*21eQ>M2l zlH;I_0Q*Ydhu(V)9k-MZN4YWr^k-f< z;eZe{iGB{}sZ}*cH>|3xZs?s(VOL}o0iw-<`jMojOi63Jw_kwTRqKeefRJ z(`nAnKnd1d5|Ouqo+!Y`c6;00wf@58Wh1-HvC&pBOzAYtr+o>$fN^8U^;uhh6MyDe z3Qh*o=B_MN_xIF6vB+(K1jwbf_#EC`@{$v!J(PH>p$ zg(DeelY4ZXn1++54T3qKhQfEpUFV?f;A8n-p1#1gX$}aNu=cur`q*=GUPbsAiDxn!S=@`*Ri_8}=1xV^-_Jb36>0Mp!B9NYnuBy$d z>0zSt4)e^DswGhmK)1FrrXo;tn29IH77k?W$ltp^9+4a^~oJTkp3!>n%s zw9Pc%YH<%WhUsMqh8d9GzsmT&0T4|w%I-ZUKggnhk)U_ZImpiO82flydEgIyh?m0r zd6_i9mIRHj-hzn}F5ZVNWKdNVgmBfqkLwh0NeZ82rUH>q*qu^+w~NNJs4RuWY{})i zP$&5?rEDJ1m^DkiPoe2Wg$x0x-Hg(ba%@_g;?DRLxEi>pBkQe?JUmzJX@y>>T`8ixMbgx(IpMp-U{Cy z-)R>ljf#voo@Pbi6jwirE;w*_{F1(wKsd`^X0ZHT~wa&g()-8j%>NCGYitH_y)Dy&C?!O@~cn~$n zd*VAP4jBAW>5Amc>;A@`zw!y5TIy7-KN`$*@Z%{r=`MQLrc~2r!Q0OD(7Y{PZ_jwT zwe`MqT(X%zY0F`{jREJ##en$!X+;m`@a;3ezIu;!UUi=x68B}IHhQQDie2%YfbD&N zZ}e>R7TKb^Mt;JLs@<%bd6w2bZWR8@<~;OIjpkoy-G4hbK6m{jrqI&R*pc}@^FYDu zeTX9FiShWY))nPK>ZfC0?1jQAoBLmtJ0BH`?sjz)Vj_6u>*{cR1w54?$&BjP%=y$c z;}4j{Tc&E=eM$ILXi5rnhv3MQ9kF^lIXDX43E?U3w!}5o2!C>xpQC1*oZVD1SmLNm zE19lvHs~jeqc$ez665sgMY$YemE%S`Pb#*Dz|lNn?1$GDw$Qko9Q(Mwi10Ogpfbkh zQd0W5_XFH@$l*UqS2`N%C-(>OKTHo}4)jpgXoe$I4Qwx0j{Dg=z8-hMN2fAZ8I!bf z4D!p|be{~j`m?W=rXnXIIm8jJgI)QMo(jQHMe=E8nD*O`yvh#ewJo(_oXhzRR663W+Fn3-2LP z?Lq!d2Sc~iT|?ylEh9oocQscRG$GR8uza!*#x4IUwawv&<~Iuh_=I%#)kex<^qPkH zN?vt;rR%D=?B)~ag}k8PT5n9@qgcJ>V9pLhu)b(?KgIGK8Yuw9eC_-d=*w<1$1`-VAr|Hwp_Z2*rqNZyL*h zm0l#qj`rf$cCSkSH$59z?nekzJ{1J9g(V3&SI4KEn0l2C+8kz4GV-rKMn{Vx!~-e? 
zK8=vQ6wjixhpH?(Vh?ccaJ<$`{(X6`vqH^KA67@SE$0`@4Np~#J|LDLj^5#Dw|w25 zytgKO)}vx?2+LOf?!U-05bzi8E_!@6nJ6w3db|!x37l=j515<3`ctu1z=j3ddGpxs zB|6yWgl41C^l|YN5Be)9g)LA4y_yOZvHe5h-9YDSD0N)1oyMEE9=T*ct1wOkOFFzC zEfn|hTRFIS;Kdr>JS0CeqXQ&0fVC z`3%3Im|krBeH`>xBo#mpdPxw3@KTWm#PP@7QX}G6k2`x*b}KYGn%H>Ce09l~;JhR9 z5R()HILKcPl2~3?2CLMBFRT(%7=$m^4ZBo%VcS4DOQB-T`_clKR zf~wRG^^_r?-X<&>gSG{aLCSsgIYyW4ywj?L4-nn3D}O4W)NGE+jBJy}GvIcb?8ycx zWloaW+s}0N%>7B|-J-v=lN*K1)}&fMYz4RLcr1?Z3%4%(qTPv$1-NWbaB$8$5oPWz z3HTWB5n3S1k(_(127x+Y0VFf+Dch299>u@72A4&*)u$~@fi#2Y5>LU5G}e;e5E@%f14G&Bx=C~5o7!icVwn&8*pqGT!Lug_^znE4_g9GQ8VDneUu*Gz$xhd2ne=0~* zM1xvbt%=D&qmYo}B3}EYR0Zq=2`{PVR9ywq+dFRULg=Duoa>##r#i#YDaV_WH4oa* z$+gPH+-u6ZKr=mcd#ld~xG0|hl6RDOhlACu6gbY#cc?1O{=r{o+dSEetmiIkHv#g!U?k>)s8x&R)yOtO3DJJ zYwycOg1mNr#s<})6Uo)aLw>qgq!FD{jY5uMfa zQnlmt;Gf0|pp}UMxYpL&Os&|bUmGAv0NQi016~(S&q14C2&Y{0l_6t^t$Y9|+YoQy-Jh*GJ zj00VODz-bdvO`;)iH*U+$Cj12si}&gp48~1snFgDq z>MjA;6akJmSV*)<0spMTPxOvP?amU6y-)bAD@_^R)X+HK-apmyo2>&gl7 z(xil2x6nJCXvLqig~E<8FGLC-_rKDM)oe}#-l1}#`hM^2TX4!jyO<4WpMJ|bT10(B zA=U9%=IL1-6j2ra(+YZ=lv-o-3nDXj>wZ;2;-V%ixOZRw$iA!NCBx)SDx)Nvp zW`)`pJXY*mJs#pLN)^}i@OS&AXcxH23T4zyk8Dih9Sx*!vOV3!>nE$%CwYUw$e&^S zS6Jj4pbU%_mrJ`aVic5Ct$y>>#=+vvXHBLz=Iz#Wnl+TyJLVq#4x4*XjD{C~$e;R58WB=9ro-62|sXA!ZP=k zv6qd@RCq#f96mR)(P@Y2$a)X4V{~qY$RP{~iyR1#xq5?@jhqZ-+H4eHkY9o{F`0W6ba z#ZQZ>u76Ov^$b*DJInZ)n$2@Sl(dBakHUKf6lG5>)bh5nmA=xmz+2HVjYPdiq`9{N zCRMZ2;PZVdcwCJEr@+b-a7q=mO6m=Z!6+@JMa*idqpU)k~5KS{(MMzIi+@NXF!s zs*|}o#t1mTv}P`nVt`fvV?;m3^cg7q1qV|u#31wZwc}SvVW8G==(R&uzmfA(Q%#7Qz<#>i)2j#GiQQIM}(smi_42YRf6WB8QKr z9iYcY11Wg~<{BK56QyNV7*}p#j0vx->Lrv8SJrR?sLE@_2+ah4a11{BizWKWBNiWP z$mcD;(Dw-}jFQz%z*GnI4_J!zm=vjC&I}4%xx~!N^{6$1P&G)?OWTR6aPM@xmxj>#vy&c5t8voUVd5JY%P zGHs_W@8;Unq{8od-D1l4Tzd#-fb{?6Wr^2gSjcPsjEhX9nR!YJ^TKJn^!_nG-q~z;Zphwc`I4}L zUpQa#<9bt6m|JfS$l{<;r%Lz|pVs2hXg6VBzrl<~0QRoG&RfgGCiG8wYe9;F)$Dfm zSp56gon&?C6>fy>rz(4A?|28rgd2F}ODwV4vYGRHYBwqJHd#tRazob(TKulHsd2X5 zEkCX`D}uHDD{dFg89awzb~>|zq_=gqI?7+dykd#6$Ln<>)=ufcNI(xVO(DSl4&tT- 
zVam=`^;aoj7Y%ANr{q(48#(x!+Y>kO=NV-z)MVCjn+8CHXKrEL@t136{;O~KhH4ooSRPd9b!#9W!l|gwkIj1cgpw5xlBh>0%S?}GAOcgte>jWz zxaETwOaiuzd2tC~bsr!)KN!%c0$0nL9Yzva?lh0a&kyVdBmG%6RDg2r^N#IAPE?ik zR`Q=a<&kQ7j?@5NOf8eI{HYqcAdkL9`4;SFPJdVOWZifL+Eqs60FKVwfnTzG3M~wN zCf8c$E9b0&Y4x^sRrSHPlRoo!*II7Fo-FCQyxv;-~9?p-KzgAg0n!0=odXUSHGC8i~keT z3sAS?5|4XPr926Xe(jT~6TT5oMy+=t(WIHK2Q5IlUBiwBY2p-D&H1nlEQ(=-A01&*C^}XZ>g7Z!51lv@;uN)*!t~e3)L?g6XJB}n{tz%ZbElII zHmVGz{#9(#Dp*weW_P)Al(qKsqZ}_9rNuwE>4H2kr$;VnV}=bZr9VrwCDmV`NojgI zI-t?)a`@`_4(HdnrEL-{e(&fKj<-%ZUb?mj`gYsPYE)E^~#x zsf9c};)=@9HpSDswrqRNVg0;cq=KYtHp~iCD%pc0PdAADGIqWoM5w>W;W|ew)D~qE zs;61G8(kE*cp3FiGV3P>^VJ04M5!;%`qKMHV!AHH&wwF0Fub+;vidbNyd}SXOM<4K zoa&KxrAM^ilO7m=COr9w%>CHG2EKX*+@XO%i`sCHZ~0bU>SafoUaM4!)lfj--;}wRqXWJ(Fl4rLksMXdXvH!}!l}FRUh>cXaQa&Gm^gOl^&7{$yp( zG>Xrb=e1B|(;QBjNJWPG=@*6bh|7KtOGanXT_$04h`pLL_)QbfpFWpgkME>ktnev~ z*4WgY+S`19$TPl0vV$)OCo{T-nh$cPpzi@#b90E(V74;fjY0t!U3}%freeLxu2{00 z7CXGS1A8HTvp`7AKog;)8tSF;&cViDdzl4aH*tHL{Z069E(+f#8v1l`V{M|2$G7hn z>eoO`!H+5>Hc|E%Yt8i6SG+@_V z4N31X>9B6>Lb1evbjLRIMNsxZiVc{`c-a_sE;R->9EipADrG)6Y_^i9%T1ug^`29x z@^ZU{{@h6V>0-=5h%f@|9ml>Jsz1Lda=$K9E4wnrTCy%a6e^ci^lD`t9ky1H1- z7c--`H*ca`yV@TV-S$XaXkU%Cf9Zq9?vXwNUR9Mv1K7U7(hy)yRr~OOgiE6Nmjs@_ zMW7n(X!*-K-W7wPqy*i;(fh4dZnuwpnm-#g)S!e7ADZ&<-8sZlk{$A&0aLPk22p!8 zeKr&CLc_?eE6x$^p!R_nGiQl#D6Jlo<~*ZxLoihf**%^y5i2y6^&&Q{v?gX|KLeJ$ zchuW!P&u9hvo7)S^F5*s5WXaSSJ)MmsS#wN+Q;FOl;1B7#bvD;R4lZS*nMPbzH#M} zA8oQm7w6aqSu9Yi2P)p1&5BcR4vt=Sp8Xu8^(l-iuJChUc)PUM#;MK(T9xNhA#-HP z9z--{G5!}%=zv{UxVsZ6O!Gsi9;7osLFjOk)-Pa&I>RjOx5gbWZGKhG6pvTX#5 ztF)7R2`=Op9JcTZlQcRFu3oCf0ifGm&fYH6(@%`o?yBS%B@`Z&>Uy+ zJ0DsLlZV>C)KT@d_qA)W#FroFoqnZ`oJq*PTZC3fc*%k49h(x|GwwoO9)P(nKD=G@ z+yF0^u%G_KOUay$Iy++m+P|Wb6DHk?6jMZy>s9H9AK`W?2!&^yB6FxZoRbt)fzC%V zyLRMPh$Z*3@O1NnVY=#ars<_1lsr@TYEZfJDlC}@OkI%UxlRh46mX0tS99jR*4&Oc zwmr>X#cAQuYX8l;1)&Q`hez^3a93);<}Pz`YQxhJkDRxvox;gci;!+JoQv@!30 zLrw=%0%g0qcz1vdc-t`V^KoChwMWfhm#47JbfllM9`ta^=*#R7-}+i3@65f`%r}{b zOiPOtOkaF~w+6L3GP0srMX>6@z|HusjR3*~MU$C;JA 
z#-_?4B#e(G1urT?d}yPJM|$gcTo9susKUDETe!7T4iUCMGvz;SiGFKiXMmM3jgiG-sk&mp zEi2!*Kw|H?_1FAp6Ea6Zpjj1BZQ#(n8&v4eFud_%i|h7em*#=P&q3PD!`>E2(MO!S zceTj5{4O)wDn9I`-d)@_kQpX$Z_fq;j{q|laiMG`khO-QB@>(bxca6AU4+M)Fy+$Y zIjoJaC;Pu1U6lU9jh|DF1(Q{#I3oFy15mnOf}UgSGQm}(?$&|4h_JMyagIEv{FaxT zf48oc-(z0kVi>&p3_#u$m)KB4_Kr}F_*kbLp@TFQxk5RxI^Q)ZVLy4Tg6JT>C3_b0 zDOemlpbhXZ$ruJoi9E8I>Ow{j4c3YTe87K%;-XGZwfGZe9xzvXi}*LnKrt{m}8x3JXSSgr5_(Q#rH@*czqnfKu1hsXrfVZ zKORwj`)ux8tVg|n+haU5nN1sTD#jV1{51)mpON*h%o{2^kiNHj1` zQIi)53Ptcd_B-bId-L|+LI~r2+zq}z<>4=2VF2xUY-+`s`PG{JceeHOk)HG3dAtOF z)3v*_53}mIT0pK0u;5ck@x4DE^i)_TUyUepkdXcMAH+GsA&_a{N^8lpj@w51ipr+c zC8U>y1Y{6AZq2fd<9lnX#L+*fIP8)c6IQa_ zO|-!b>A><;#Hwpn2GS&z!=lfVtU1k7XWpRTY}`M+V}yo>3ExN&7XY<7parmOmiVAS zq{;u(x0G4mB`PtADdXr&q{KK^&w3{2D}wCA3Gp!5!kN6E=534^off7nepDFHc*vuy zop(H_)bLI2G(8dYvcpC{{JcK|Z8;u#-5Dcf82EqfLNh(fTaLDUK6&6aMA(Pthpb!J zhH-s(TX!+i8u%9^e+`pc-iN=Pu-mYMJx^E3dPcn7C-@T_Q~GPB)(TbnrARxQSGjO% zX7c=I4{L)pE1uq>fooI*zZ>JLzDb(DGtN&EJfTB@kk__CX~wc&()dXirXw$(U$u^& zs!hEizo5%loplDtvw`=}M+mw4;WJzEPmRKl4tNYq$61*P+J(L>IhnpV<*$waoAi-jeH;St(uEs!L8yeJ6qDd9xW?;P@SdqJ^Wfpv;R$QVt~!n63^-2kDU7FULl zxY2{}Q%tOCN@B^ZGJon%OZ_exzGN-TO*`pHI5wq<`SM=y8|M#B9bW ztsp-6KdSO`P#)u19=q=%PyL0TJErBe`s?54sNHr$XyueIFALItMAy!gdhY-i66>-< z1daAWg!=j*m6j=&Z$;%AzJHGr0g~g&6!j928$DLHOY`tguvx=Vh9Tutmh#fKI#Cxm zB#E5uBfI4kNlyH2M?`1?(wJqWV5o>T`9i%p^olN?v9m1+LE$~|!(7wMeDFX;WZm06 z-C@}w%E{`QQLLh|n%{-sctw!r)9%n=3H|2g1zO9x?Kwy?mRf;*{`Y+d!5bxSJu_A9 zr{sEmuE?6l|F|GXz*22xq{eG*pKr<^x@XcHKNgrXd@V#l7#y`^isnL#5hCq;H77Yn zFvRY|0xI^Yc1dCHr{@74btpL6gm^&mhY^`BFIS!I`X@8CpdrT&{P;bDLV4*cBt`14 zOE{n)4k;KX4?|V^DNKAe)a~LszvJn%^I>4|DFW>>XM0XgpQKr0dimx@p8t(iRSu$r(8H@nvyM~4?2 zi&Pc~Wfv(3ndso>ezs5ICum!;6CuqVaq?%MJVFnR@v-uL5JvyU&vi-=XMg=D8wo!A z=-K*q6Mh-uTkPFc0Ovq0<{hiA>c7Zb@9KXLPj7D#L)Qa~Zx6goTX*A+FuXhZr5wvz z6@UGqhSOQS=%B3>;ByjU9#-8wBerePTSEM@Q#?>nUlAnSHqcIAzGYN=BRs&_J;c1f zKS9262jmPstlAtG8+*J0^PYM@M`LR-?CAv7<>0ZR2WUk|6c z<1ZK6pBjm4Ajl6KOquNB$m&mxd#~eXn(iGjJHeOvyL6i6?j;=#e#aW@Jianc5$BCH 
zQ6o25dP7(%!SktihhFD>@Ej}>qciDH?U~q?#xskm>F1j}p?NjxZykL##3eBP6*tFMx809cjG*S!WO*Ly{bWvQJR&FHkJa)$u6C<`APSuKyJ#fsl7%Pi1{2)1_ z#y8KRO$h|nXS=f#0%1h>6pC`CZ>@F6sGiN!3Gd8|uEY@EvdVgzR8R9U* zQw>XFbgm+Io@cq{_vwYprd>B?73{My=yEwFx@1S#Lt6N%%+ua1bp8vfkJ(9_*|D5x zweYn}4WO|b)1dz@QknSTsOv2B3!2(liAf;ePnSW25plmuf2x_DHtfUR2iB&9O@r|u z#mP?u_j$+-TRlP2UZ<$3_WQq9yd2JaTDZH7?(6jOCg}^?1h`uCePB9RBf-?8b{1u; zV&B$H9PIH7mQs|dQ)S&RLX45^%Qw4SOo4jBD7mZ^yxa@QJS$h)iJO%9{U7b?iNvEbj)s0Z z5_4-8VrTQ+< z6i1MePHgD)D^pMJA)}^oDJae8jx{ZYZ8);SWi#gl)7Hw{&C%_BBFj<$D6??>zdrh? z*40Qm$Xf=POjDO~TMiGum@#k@OQrZS;ncYIWb3CIqdh3;j$Yw#_vG}tmNYTQ5h|rE z+A&Y+noT@A|I&9}T~ou$$8oMXL2N$AkHG2bddi+9Q*6|2r>tz67j-dlZm8Fk3-_oL z#z=di4QB^9n+oI#`|78^3O%!69_Nos&OV#9f{Lr=D(7HI#1o>PJ^DbZ%yus1*30uRV+ z`~4-hDvd5F-E8%d8?lZl8jZJ04Ku&5&EZ%)<;1bw`&|q4)khwo4#MysxP#eEiSz=> z;@V*OtEiYW{!o!(jMB27s7@qZQ=JMO`SJ(fUzbd=?ZuaWx{|TPhe;dgXOl@vDRz{= z8JCdVD3dw#O&Dj%9VYYm$ZHgP-^f1fMq+>Rdg+-;kdz(EvR9Aku{Uh&!87d{yO5UX z3^kgXW?O0vw8ERGcYJ4omH?3Bu`Le>t#}XcG!D%~uVcs;?g~UN>Sr#b-`JsZU778N zI=yMK(C=^dQF&DSrPn1Fgb!~bKv~#s-sitNxCl$@+(cHg8Bz9ThlNRB=kYd~HKmca z>%Uil>X+uAW;wf>prokBXZaJu1{)8p9dA1uf#miIi;QAa)5+6b`&Ld=%arm+~;wG z*X050kqk*nw+mMv41^i_cb`$ zBA{D3Oitob{BX6`dd-5sN6U?2?k>+0swPhv-#8pdLzoiSsUE*#LQ7I*?0yeQ1uqu194;2#xgMOG2Bd^O>L zipF#NPN!gjRY^rF+}6(YMSNW&eQO4A*nW*V^#VcIuj&x^PYZmiPF#tTEq2Wz2E;=H)({}}cx#fIsD1CauzwAl#O6`$=;v&3=o1N1kH z%Vgu^`%>K)5Iy_E6Jf$I&-?8&=&-LX&6BWS6E)(Pt6c0-ADeVlvT2ulBqJ^9qd|## zTf+6BMF^MTmw|ahDGc9wdw|~VRa%2gxHX?DL3r!5l)sk7XYZ(u$4He^$d_Xi$m8>+ zN!PeOV!PyL@>?A=!}(bqk{Z1<;VTG=x-h$P3%J!LEv!NhEansJen*9jd~I9TDp6J(h0D ze-S|aXFXso_y>q3Xn!aG1NGD&vJ-OnGWfe8Ue1}#;5X?um7l{ck%&n8?Z>s#nV$`PCXc;a=|V6E8mO z<-GWlq$QP|NOzG#mK?Exk}@|hw5hO`i0C_Se!-DO&7zu$U*KK4=f8f#aB2cZ}oXe=nVuy#r{us((QVPJn&JR0Y_jjZ1$h z=?uknVI4PIj$$ojj<^9l4`T7#t*HyKL*@p`tSM9a2beVN@aa?E-+cm2%0*u+9+;uS zV-mm2ZrO&R)*IfQq8OSMPYM<~i;`?}QK{I5fS%-V=`$j*hh;~^V< z87?{S%8_MkH6Z&`O;#aExz5k=a(8swELvQaVjy=WdHyNE(v1kdFed( z5}r*_Emm8vj#YEh}pYAmqHeo=UWpc`|w0O0%OEK4zr)y|w1)D@IGE08y0E 
z4I5#W%YduQz2>7R2@(m!$7)0cTO#{R!4%`~f1Ki?$Ou68Nn7HB+{a>{*|@B-U+$c2 zCtqvzoHtJn|CXTJ5aoLynEyUVohjn#SGy7A*z!S5`K5<^1YlE=ICn5Q`M_Lq7S92a zn1%CW4W*<%d4}$IBBD0ldQ{lpg&Y(=_MOz}`6lgjItynJJ$xNj?D?mIP0zd##9E3mwtCz)?G!`f#|0XY$|!&w&suL&Rq z6*xEk1$9Kzwj(H;Nekh^7icf9jRomx@{NG)HhojDPmD{Ji(%8@ukMk{+4eUEy+1`D zh`EJsF;+Q@&Z+Uxh9rz*6?&RI18#fP;@f1I4!!{TeS1j_;oA&Z$?iRF?{guaKr_Wp ziI^*}e%zC(#f5ne^}dy8Dz2{pO5LKPS5sv@T8p;P!o?jc5z}nxJ3qs0;kh?2c4f(o zerMkCP?NDo*}su`_9njN%jwXCqRY@4smX3OO%3$^;y+xM!<9bNAASG7APu+MiTVRY z&()irSXpQ8$?(I{#rAKg)?x!0+t}V|6J2>ZZ*FZ(hj(w>aiUr^3M z3Q*RDVXSV*W1w520NA+n2p~onj{F5dFwlO)+b|3^3=8PCdNbW_Awa6Ph-V_FgO06FwbMJjFz#+ zXJUw z`2Sld2+PY&Y;c+p78iV!Hv?Tqgk6P=FaHHWfeo05IOc*YI2-*t670~?&@G{GD+V9B zs#J~m`4>cfLSE8=MacgJ0afza>)MF5)_;XE!EB?>Xw@h(S-^=d>`mTo&Yg)<-&L8Y zY>EV?y7)cy+E<2(rm=lsI1dHK#ltT)(38PM>=krXfo0(i!rVma`$A7ecIm;@0RA0re%j#vDU z8HQn%Ac~)}t{|`=anvCRml_Godb(Ci=#p0Av*gF>^Om&SSyuqksLT;hAJimf@Utyo zZf}mVPbOd6U@_PYURGjm_x#y?>uy#&VH3%O%1}p;xV8|LqT(~OK8c$I%a=5nXj`+d z9z6`86n!}1Lj2O__>_lD7Q8#DqUkY=gV#O~VSYrJ=JW5|7JP<1Z zQcXkqM3@p}&&30oPpvNji-l90C4V;HG&h%bSTzv$ZBbKo6{1<)f3P0ZY_~hbHkNj< z5^oX0A^+Kpq-dOByQkIeS9s*8p6;p53`AFN0Io3mRbWgGFolYq*}it0UyUkgR)>(k z$;sX}3^DX5pGH`1>xG2M-n^uePn>m*zqQ>ws%^ISSeBWH<-2s{@gMM78L&8*+ksUo zFRwe2Sdo$8Y$p(1gCIe5N7QNP&fQlQ&!6`pScVay(rvN;MamZ5Zb(#d^s0HXuI^-d zN5Vx|5FUabGWXi+f*oAc!ruudp7ng*^PIt_NwB&6M2s=@_#^ma+u)ZSndpYMdh}_F zP$vV;K-AbQ@2RTcTq?ChFCiXUe)!gAs>8t9z^6?+QD49((%Ej!Wg=pgaJL%ex7+)n znV(vQ20>@5_wn1{aVl@_4_--2Hu25mp(PGMYgPe!Q#{H&zLx}KH1H1G z-5^kVZnLRPh1s-IpoS%Tb0zKKV3{y-k%o07X`JN?1~WycVmJ?DODJ zzRn<1+rrGK9>+FaN#NkyEjS86CHgI=H@&jhVT4=5h=t2(P0uui=2+&hVM8skL{d>=&VB)%bOu-eqK-r>(lE_pdMB zd&-A^3R7zpL+E1Y!sHLRoOWe6uo(7bWrWyA-|T+f#ZHZ5^1hisMB8OZf(=rA)mEfm zGuxjXYCm41k|+0Na#Y6tk^3{*BIcWe(=B0=-c|2|*5(L(J7ESgN3CD(#0&9ayi?PX$ z(%_i#?cyidVvt|#LEO24a9=&Bf{6EZ%MQ31P3w8w;_U+@0l*^1{4%)B z8yqg2GUo7Ms#S!J;(a=vQ6klO9k?L90=x$=wSzSr5Nqj$TD2MezFqVrHTUUx1%Uv6m71Kg230OH+fH8%FOelt)_ zX`NK5jQB*1*S&_qcz!z1VHO=7J<9|x9n>O>1?5=> 
zHiSdvgbqyHgA=oqdYi#=vJB&Zy8z9Q=O63~KO=0*&r4IdCQYWK=!;&5-IoU}$rSxhopbK}h?$GJB|_b} zyw_?}S=>l!|ik?C&yc%Bs+3(MME>caTIJIk^^S@EL{~GUUHs6j;mR zKEzLpDHZuczKF1}gEP|QXq?%XI&>+J{Y;g0E!%I8n9rWM^k!RLLsQV`_p>cPx+(=@s_cyR!WTdOy+)A0hts3~fc_fRn4sEw zG;24X6^Py@T&^v%U3rDyo96SAwYiE%$9u#CMH;F5RGrCTRHB?hFxEGxz zrLY2}-^rq#FVkbShH>2cwH7izxZXDC)w6E+XSseYvi!k2!8Ph@D7mU0B=A}0$nS|$bIYGW(X&Y+?#!=u7D##|$pnI&p?YY(_IlG;RaCwi z!A;Ia3j?AhB6$GXyj*hAAoj<_MR|NoJ)rwAd0{rSZKH0#1_Cc6X8PiS;dyEs1A}eZ zoTqdb5#B70R)unwX7fZQv!sWtxBMv&Mp7q{ z9l9L4zF*+et(8jsK>ph?&JtEc6vKLbaZA=1i^XjN*s|J=m+QB#k`+dPstJt!JXjhR zmy;liS0eHY(aw)X4DS{L6#)+zqhM+4PcrqVy62;3?#89PwyoNiycqES@yGoUM1%!4 zj?6zLl)uIYKRF;%5Zu8IcA1KLy!pYjuP)MlAQLPy=*j=QafJ8JKv(qxjk%!>1ycvk z(>l7gwjVoQOWHAYMT-W8+jY<2>eHE#B`LqyY1mL>yRwXC0fOq(y&ExF^Gp$|sCOF; zL>Ot5QnA#SKC-Jxbuccy@BS}BcLLc>`k_^9Z=uR^rh&}(p$hRUClZb7XcyM(q~BRt z{ZIWGHR)#v@Z*SjclXuwUmDU*u;U^9_h4Q^>=lLgqQ1Ge70x1Ga`kW{t^h-G;CpkK%OQV28rp|d)g zMs@7@WMKs3_CYGYzTA^&6gx4aSJdcNjW$M|$Jg#TWIrQL7B7{WRVnTA(_MaTK6P#R z%#NRL{MsooldlZRaoO>;Ulv725EJ8K}=s@;2wNRK z0*-L-Tns`YIbtCH)SPp&xF@^gQ0Aqeu;uVEU4#BN0*=AL?yAu!3-Ki)0oEMt2C<5z zZ`v6Qy`DG0_3WNwddkdg_Z0ZJyXOF38wl-wTK-5;HXVhGj9vW&GFm8d*WqFu#`1*E zN~TWYaNxL)If+`spy*L{>8>^dH(UgRrfc{jN2F}{{Fp*n|FQIE?7MHf&uAqCcP1}z zZ8?MYiP=?kDow@j5vZJZyY%Q4tQe64Aj@;%Y9XT4%I|7*_~3;*+6Jlkrjm^}uH~e# zz>K5 z{+VI^tLD$dVxAJ)`5yplxvvJKrBzsJza{`5!T$%k;Q7a&aa{A<;VOF>D}TKP-R%Dd z;80%kpI?o=#s%<@2Z;8R8+icg!23VHuh!UGJb`v#f_LKpumV^?0IXn3zzzdI1L(9B zLkwM3Iu8RN7sCG&l~vw!>jr>T5KcY%3z9MZ4_2YXar+g(TyVDkGe+z_`9F{g`TvfP zW^zQS3egW>702W#d8{wMy!;2Nh*A4@B;XmAe?S)&0CWko#6Xa_u@-+psaWtDSU(JN z^&dQ_k16Hfu|TV(fW$*q8h~AJhc01o1(pH8DG3AZm1!WqIw8lT1L(0h|G%Q4Mp$`n zVgtO5uy?W#QMD1$=~&B4M4X>9_+;Zh2ok%q(Z3?$j;R4h*B}{oUnMG!nW}yKZD1eH zc8WTPc%3wn0Zg=lJj+)OvC-(sR$E{+F^5a}Gw{_mQ26+P4|AK+9-mV{ZeWz_7XUO5 zXR-R9*a*F4j}V7__=Nf%C8VAk{i8oyzB+|ddPMYwi|tq>RVLVt<)wl-%N%&J(FQof z^Z-Kz&lPT^l1?>}Z?)5=X3%|-w-Qs1Tc{=j`6vG|b>_{O<}TzslkZvAuWD2A<=Kki z<6lvN;O|>?Df3cFH);7a`KTe~?6WlA;514Xaj~2YtLNx*zT~SUmv1ZYL$cR7EVRG5 
zPD1)uXr|RDr``~~TA+HApx78EU{Hac*Q)_G!6RbC^;YX(+4#Ay5j^|as&%k_Q+K>L zT_R8AZLo69ii$n58Mf^8af;GfJTZ5@&e%lXEeJw|g{v9Kj;C(li5^mdXn~e1I(9d01FlsQ?-+UyEi}nqv;WN{#Ph z4+4pcZ~w`NB{lF&?vSgQ3-TdRjeV)VrRaJ zrkY>qa=1;zeXWVdTa@@HfsLc=r%vYF9XG8u`m2Cm5iMV1w9(vlT>@Oio_EdCGdDs@ zS8gmSU*cHjZp_F*4w>!f^0%70=7y#S8nwvG?=CUuYGapE8j}Di!2RVrXY<(l=jlHC zY_sv>NIPz!W}8qFirJ(ulBR?SMJ$sCQno9Sp7t0dtKg&X2-*lmNNB(#)259D_PD6` zn9M`lk>VT@1c{ADT5SZM!-_IWq4VhJi^uuwsE8j_%5$a2@iSz^;~d=ozn-{&e~z*W zfX0+36Kg4>p0Zw7>U8Z|3jM7q@5B_ZiYi@->mWo@YYTRaCALJH>N$;W>;s#e{@+cg M|C#9!`0w=p0-le-PXGV_ literal 0 HcmV?d00001 diff --git a/exhibits/time-integrated-stdp/harvest_latents.sh b/exhibits/time-integrated-stdp/harvest_latents.sh new file mode 100755 index 0000000..b2e4b8f --- /dev/null +++ b/exhibits/time-integrated-stdp/harvest_latents.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp stdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "trstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'evstdp', 'trstdp', 'tistdp' models supported!" 
+ exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +SEEDS=(1234) # 77 811) + +PARAM_SUBDIR="/custom" +DISABLE_ADAPT_AT_EVAL=False ## set to true to turn off eval-time adaptive thresholds + +N_SAMPLES=50000 +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" + +for seed in "${SEEDS[@]}" +do + EXP_DIR="final_case1_results/exp_$MODEL""_$seed/" + echo " > Running Simulation/Model: $EXP_DIR" + + CODEBOOK=$EXP_DIR"training_codes.npy" + + CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DATA_X \ + --n_samples=$N_SAMPLES \ + --codebook_fname=$CODEBOOK \ + --model_type=$MODEL \ + --model_fname=$EXP_DIR$MODEL \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR +done diff --git a/exhibits/time-integrated-stdp/json_files/config.json b/exhibits/time-integrated-stdp/json_files/config.json new file mode 100644 index 0000000..9cc4515 --- /dev/null +++ b/exhibits/time-integrated-stdp/json_files/config.json @@ -0,0 +1,3 @@ +{ + "logging" : {"logging_level" : "ERROR"} +} diff --git a/exhibits/time-integrated-stdp/json_files/modules.json b/exhibits/time-integrated-stdp/json_files/modules.json new file mode 100644 index 0000000..1d16af0 --- /dev/null +++ b/exhibits/time-integrated-stdp/json_files/modules.json @@ -0,0 +1,16 @@ +[ + {"absolute_path": "ngclearn.components", + "attributes": [ + {"name": "VarTrace"}, + {"name": "PoissonCell"}, + {"name": "LIFCell"}, + {"name": "TraceSTDPSynapse"}, + {"name": "StaticSynapse"}] + }, + {"absolute_path": "ngcsimlib.operations", + "attributes": [ + {"name": "overwrite"}, + {"name": "summation"}] + } + +] diff --git a/exhibits/time-integrated-stdp/patch_model/assemble_patterns.py b/exhibits/time-integrated-stdp/patch_model/assemble_patterns.py new file mode 100755 index 0000000..bf71422 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/assemble_patterns.py @@ -0,0 +1,159 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time 
+from ngclearn.utils.io_utils import makedir +from custom.patch_utils import Create_Patches +from ngclearn.utils.viz.synapse_plot import visualize + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +################################################################################ +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", + "verbosity=", "exp_dir=", "model_dir=", + "model_type=", "param_subdir=", "seed="]) + +seed = 1234 +exp_dir = "exp/" +model_type = "tistdp" +model_dir = "" +param_subdir = "/custom" +n_samples = -1 +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ("--verbosity"): + verbosity = int(arg.strip()) + elif opt in ("--n_samples"): + n_samples = int(arg.strip()) + elif opt in ("--exp_dir"): + exp_dir = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ("--model_type"): + model_type = arg.strip() + elif opt in ("--model_dir"): + model_dir = arg.strip() + elif opt in ("--seed"): + seed = int(arg.strip()) + +if model_type == "tistdp": + print(" >> Setting up TI-STDP Patch-Model builder!") + from patch_tistdp_snn import load_from_disk, get_nodes +else: + print(" >> Model type ", model_type, " not supported!") + +print(">> X: {} Y: {}".format(dataX, dataY)) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if n_samples > 0: + _X = _X[0:n_samples, :] + _Y = _Y[0:n_samples, :] + print("-> Fitting model to only {} samples".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 100 # 1000 #10000 
+mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (28, 28) ## same as image_shape +py = px = 28 +img_shape = (py, px) +in_dim = patch_shape[0] * patch_shape[1] + +x = jnp.ones(patch_shape) +pty = ptx = 7 +in_patchShape = (pty, ptx) #(14, 14) #(7, 7) +in_patchSize = pty * ptx +stride_shape = (2, 2) #(0, 0) +patcher = Create_Patches(x, in_patchShape, stride_shape) +x_pat = patcher.create_patches(add_frame=False, center=False) +n_in_patches = patcher.nw_patches * patcher.nh_patches + +z1_patchSize = 4 * 4 +z1_patchCnt = 16 + +T = 250 #300 ## num time steps to simulate (stimulus presentation window length) +dt = 1. ## integration time constant + +################################################################################ +print("--- Building Model ---") +## Create model +makedir(exp_dir) +## get model from imported header file +disable_adaptation = False # True +model = load_from_disk(model_dir, param_dir=param_subdir, + disable_adaptation=disable_adaptation) +nodes, node_map = get_nodes(model) +################################################################################ +print("--- Starting Simulation ---") + +sim_start_time = time.time() ## start time profiling + +print("------------------------------------") +model.showStats(-1) + +## sampling concept +K = 300 #100 #400 #200 #100 +W2 = node_map.get("W2").weights.value +W1 = node_map.get("W1").weights.value + +print(">> Visualizing top-level filters!") +n_neurons = int(jnp.sqrt(W2.shape[0])) +visualize([W2], [(n_neurons, n_neurons)], "{}_toplevel_filters".format(exp_dir)) + +print(" >> Building Level 1 Block Filter Tensor...") +W1_images = [] +for i_ in range(W1.shape[1]): + img_i = [] ## combine along row axis (in cache) + strip_i = [] ## combine along column axis + for j_ in range(z1_patchCnt): + _filter = W1[in_patchSize * j_:(j_ + 1) * in_patchSize, i_:i_ + 1] + _filter = jnp.reshape(_filter, (ptx, pty)) # reshape to patch grid + strip_i.append(_filter) + if len(strip_i) >= int(py/pty): + 
img_i.append(jnp.concatenate(strip_i, axis=1)) + strip_i = [] + img_i = jnp.concatenate(img_i, axis=0) + # plt.imshow(img_i) + # plt.savefig("exp_evstdp_1234/test/filter{}.jpg".format(i_)) + print("\r {} filters built...".format(i_),end="") + W1_images.append(img_i) + img_i = [] ## clear cache +print() + +print(" >> Building super-imposed sampled images!") +samples = [] +for i in range(W2.shape[1]): + W2_i = W2[:, i] + indices = jnp.argsort(W2_i, descending=True)[0:K] + coefficients = W2_i[indices] + #indices = jnp.where(W2_i > thr) + #Z = jnp.amax(W2_i) + #indices = jnp.flip(jnp.argsort(W2_i))[0:K] + #coefficients = W2_i[indices]#/Z + + xSample = 0. + ptr = 0 + for idx in indices: # j in range(indices.shape[0]): + coeff_i = coefficients[ptr] + Ki = W1_images[idx] * coeff_i + xSample += Ki + ptr += 1 + xSample = xSample.reshape(1, -1).T + samples.append(xSample) + print("\r Crafted {} samples...".format(len(samples)), end="") +print() +samples = jnp.concatenate(samples, axis=1) + +visualize([samples], [(28, 28)], "{}{}".format(exp_dir, "samples")) + + diff --git a/exhibits/time-integrated-stdp/patch_model/custom/LCNSynapse.py b/exhibits/time-integrated-stdp/patch_model/custom/LCNSynapse.py new file mode 100755 index 0000000..6e7f332 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/LCNSynapse.py @@ -0,0 +1,114 @@ +from jax import random, numpy as jnp, jit +from ngclearn import resolver, Component, Compartment +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils import tensorstats +from ngclearn.utils.weight_distribution import initialize_params +from ngcsimlib.logger import info + +## init function for enforcing locally-connected structure of synaptic cable +def init_block_synapse(weight_init, num_models, patch_shape, key): + weight_shape = (num_models * patch_shape[0], num_models * patch_shape[1]) + weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, + weight_shape, use_numpy=True) + for i in 
range(num_models): + weights[patch_shape[0] * i:patch_shape[0] * (i + 1), + patch_shape[1] * i:patch_shape[1] * (i + 1)] = ( + initialize_params(key[1], init_kernel=weight_init, + shape=patch_shape, use_numpy=True)) + return weights + + +class LCNSynapse(JaxComponent): ## locally-connected synaptic cable + def __init__(self, name, shape, model_shape=(1,1), weight_init=None, bias_init=None, + resist_scale=1., p_conn=1., batch_size=1, **kwargs): + super().__init__(name, **kwargs) + + self.batch_size = batch_size + self.weight_init = weight_init + self.bias_init = bias_init + self.n_models = model_shape[0] + self.model_patches = model_shape[1] + + ## Synapse meta-parameters + self.sub_shape = (shape[0], shape[1]*self.model_patches) ## shape of synaptic efficacy matrix # = (in_dim, hid_dim) = (d3, d2) + self.shape = (self.sub_shape[0] * self.n_models, self.sub_shape[1] * self.n_models) + self.Rscale = resist_scale ## post-transformation scale factor + + ## Set up synaptic weight values + tmp_key, *subkeys = random.split(self.key.value, 4) + if self.weight_init is None: + info(self.name, "is using default weight initializer!") + self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8} + + weights = init_block_synapse(self.weight_init, self.n_models, self.sub_shape, subkeys) + + if 0. 
< p_conn < 1.: ## only non-zero and <1 probs allowed + mask = random.bernoulli(subkeys[1], p=p_conn, shape=self.shape) + weights = weights * mask ## sparsify matrix + + self.batch_size = 1 + ## Compartment setup + preVals = jnp.zeros((self.batch_size, self.shape[0])) + postVals = jnp.zeros((self.batch_size, self.shape[1])) + self.inputs = Compartment(preVals) + self.outputs = Compartment(postVals) + self.weights = Compartment(weights) + ## Set up (optional) bias values + if self.bias_init is None: + info(self.name, "is using default bias value of zero (no bias " + "kernel provided)!") + self.biases = Compartment( + initialize_params(subkeys[2], bias_init, (1, self.shape[1])) + if bias_init else 0.0 + ) + + @staticmethod + def _advance_state(Rscale, inputs, weights, biases): + outputs = (jnp.matmul(inputs, weights)) * Rscale + biases + return outputs + + @resolver(_advance_state) + def advance_state(self, outputs): + self.outputs.set(outputs) + + @staticmethod + def _reset(batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + inputs = preVals + outputs = postVals + return inputs, outputs + + @resolver(_reset) + def reset(self, inputs, outputs): + self.inputs.set(inputs) + self.outputs.set(outputs) + + def save(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + if self.bias_init != None: + jnp.savez(file_name, weights=self.weights.value, + biases=self.biases.value) + else: + jnp.savez(file_name, weights=self.weights.value) + + def load(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + self.weights.set(data['weights']) + if "biases" in data.keys(): + self.biases.set(data['biases']) + + def __repr__(self): + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in 
comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines diff --git a/exhibits/time-integrated-stdp/patch_model/custom/__init__.py b/exhibits/time-integrated-stdp/patch_model/custom/__init__.py new file mode 100755 index 0000000..45a5b57 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/__init__.py @@ -0,0 +1,3 @@ +from .LCNSynapse import LCNSynapse +from .ti_STDP_Synapse import TI_STDP_Synapse +from .ti_STDP_LCNSynapse import TI_STDP_LCNSynapse diff --git a/exhibits/time-integrated-stdp/patch_model/custom/patch_utils.py b/exhibits/time-integrated-stdp/patch_model/custom/patch_utils.py new file mode 100755 index 0000000..4e79f9b --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/patch_utils.py @@ -0,0 +1,95 @@ +""" +Image/tensor patching utility routines. +""" +import numpy as np +from jax import numpy as jnp + +class Create_Patches: + """ + This function will create small patches out of the image based on the provided attributes. + + Args: + img: jax array of size (H, W) + patched: (height_patch, width_patch) + overlap: (height_overlap, width_overlap) + + add_frame: increases the img size by (height_patch - height_overlap, width_patch - width_overlap) + create_patches: creates small patches out of the image based on the provided attributes. 
+ + Returns: + jnp.array: Array containing the patches + shape: (num_patches, patch_height, patch_width) + + """ + + def __init__(self, img, patch_shape, overlap_shape): + self.img = img + self.height_patch, self.width_patch = patch_shape + self.height_over, self.width_over = overlap_shape + + self.height, self.width = self.img.shape + + if self.height_over > 0: + self.nw_patches = (self.width + self.width_over) // (self.width_patch - self.width_over) - 2 + else: + self.nw_patches = self.width // self.width_patch + if self.width_over > 0: + self.nh_patches = (self.height + self.height_over) // (self.height_patch - self.height_over) - 2 + else: + self.nh_patches = self.height // self.height_patch + # print('nw_patches', 'nh_patches') + # print(self.nw_patches, self.nh_patches) + # print("...........") + + def _add_frame(self): + """ + This function will add zero frames (increase the dimension) to the image + + Returns: + image with increased size (x.shape[0], x.shape[1]) -> (x.shape[0] + (height_patch - height_overlap), + x.shape[1] + (width_patch - width_overlap)) + """ + self.img = np.hstack((jnp.zeros((self.img.shape[0], (self.height_patch - self.height_over))), + self.img, + jnp.zeros((self.img.shape[0], (self.height_patch - self.height_over))))) + self.img = np.vstack((jnp.zeros(((self.width_patch - self.width_over), self.img.shape[1])), + self.img, + jnp.zeros(((self.width_patch - self.width_over), self.img.shape[1])))) + + self.height, self.width = self.img.shape + + self.nw_patches = (self.width + self.width_over) // (self.width_patch - self.width_over) - 2 + self.nh_patches = (self.height + self.height_over) // (self.height_patch - self.height_over) - 2 + + + + def create_patches(self, add_frame=False, center=True): + """ + This function will create small patches out of the image based on the provided attributes. 
+ + Keyword Args: + add_frame: If true the function will add zero frames (increase the dimension) to the image + + Returns: + jnp.array: Array containing the patches + shape: (num_patches, patch_height, patch_width) + """ + + if add_frame == True: + self._add_frame() + + if center == True: + mu = np.mean(self.img, axis=0, keepdims=True) + self.img = self.img - mu + + result = [] + for nh_ in range(self.nh_patches): + for nw_ in range(self.nw_patches): + img_ = self.img[(self.height_patch - self.height_over) * nh_: nh_ * ( + self.height_patch - self.height_over) + self.height_patch + , (self.width_patch - self.width_over) * nw_: nw_ * ( + self.width_patch - self.width_over) + self.width_patch] + + if img_.shape == (self.height_patch, self.width_patch): + result.append(img_) + return jnp.array(result) \ No newline at end of file diff --git a/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_LCNSynapse.py b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_LCNSynapse.py new file mode 100755 index 0000000..d089696 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_LCNSynapse.py @@ -0,0 +1,111 @@ +from ngclearn import Component, Compartment, resolver +from .LCNSynapse import LCNSynapse +from ngclearn.utils.model_utils import normalize_matrix + +from jax import numpy as jnp, random +import os.path + + +class TI_STDP_LCNSynapse(LCNSynapse): + def __init__(self, name, shape, model_shape=(1, 1), alpha=0.0075, + beta=0.5, pre_decay=0.5, resist_scale=1., p_conn=1., + weight_init=None, Aplus=1, Aminus=1, **kwargs): + super().__init__(name, shape=shape, model_shape=model_shape, + weight_init=weight_init, bias_init=None, + resist_scale=resist_scale, p_conn=p_conn, **kwargs) + + # Params + self.batch_size = 1 + self.alpha = alpha + self.beta = beta + self.pre_decay = pre_decay + self.Aplus = Aplus + self.Aminus = Aminus + + # Compartments + self.pre = Compartment(None) + self.post = Compartment(None) + + self.reset() + + @staticmethod + def 
_norm(weights, norm_scale): + return normalize_matrix(weights, wnorm=100, scale=norm_scale) + + @resolver(_norm) + def norm(self, weights): + self.weights.set(weights) + + @staticmethod + def _reset(batch_size, shape): + return jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])), \ + jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])) + + @resolver(_reset) + def reset(self, inputs, outputs, pre, post): + self.inputs.set(inputs) + self.outputs.set(outputs) + self.pre.set(pre) + self.post.set(post) + + @staticmethod + def _evolve(pre, post, weights, + shape, alpha, beta, pre_decay, Aplus, Aminus, + t, dt): + mask = jnp.where(0 != jnp.abs(weights), 1., 0.) + + pre_size = shape[0] + post_size = shape[1] + + pre_synaptic_binary_mask = jnp.where(pre > 0, 1., 0.) + post_synaptic_binary_mask = jnp.where(post > 0, 1., 0.) + + broadcast_pre_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape( + pre_synaptic_binary_mask, (pre_size, 1)) + broadcast_post_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape( + post_synaptic_binary_mask, (1, post_size)) + + broadcast_pre_synaptic_time = jnp.zeros(shape) + jnp.reshape(pre, ( + pre_size, 1)) + broadcast_post_synaptic_time = jnp.zeros(shape) + jnp.reshape(post, ( + 1, post_size)) + + no_pre_synaptic_spike_mask = broadcast_post_synaptic_binary_mask * ( + 1. 
- broadcast_pre_synaptic_binary_mask) + no_pre_synaptic_weight_update = (-alpha) * (pre_decay / jnp.exp( + (1 / dt) * (t - broadcast_post_synaptic_time))) + # no_pre_synaptic_weight_update = no_pre_synaptic_weight_update + + # Both have spiked + both_spike_mask = (broadcast_post_synaptic_binary_mask * + broadcast_pre_synaptic_binary_mask) + both_spike_update = (-alpha / ( + broadcast_pre_synaptic_time - broadcast_post_synaptic_time - ( + 0.5 * dt))) * \ + (beta / jnp.exp( + (1 / dt) * (t - broadcast_post_synaptic_time))) + + masked_no_pre_synaptic_weight_update = (no_pre_synaptic_spike_mask * + no_pre_synaptic_weight_update) + masked_both_spike_update = both_spike_mask * both_spike_update + + plasticity = jnp.where(masked_both_spike_update > 0, + Aplus * masked_both_spike_update, + Aminus * masked_both_spike_update) + + decay = masked_no_pre_synaptic_weight_update + + plasticity = plasticity * (1 - weights) + + decay = decay * weights + + update = plasticity + decay + _W = weights + update + ## return masked synaptic weight matrix (enforced structure) + return jnp.clip(_W, 0., 1.) 
* mask + + @resolver(_evolve) + def evolve(self, weights): + self.weights.set(weights) diff --git a/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_Synapse.py b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_Synapse.py new file mode 100644 index 0000000..d8bc695 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_Synapse.py @@ -0,0 +1,97 @@ +from ngclearn import Component, Compartment, resolver +from ngclearn.components import DenseSynapse +from ngclearn.utils.model_utils import normalize_matrix +from jax import numpy as jnp + +class TI_STDP_Synapse(DenseSynapse): + def __init__(self, name, shape, alpha=0.0075, beta=0.5, pre_decay=0.5, + resist_scale=1., p_conn=1., weight_init=None, **kwargs): + super().__init__(name=name, shape=shape, resist_scale=resist_scale, + p_conn=p_conn, weight_init=weight_init, **kwargs) + + #Params + self.batch_size = 1 + self.alpha = alpha + self.beta = beta + self.pre_decay = pre_decay + self.Aplus = 1 + self.Aminus = 1 + + #Compartments + self.pre = Compartment(None) + self.post = Compartment(None) + + self.reset() + + @staticmethod + def _norm(weights, norm_scale): + return normalize_matrix(weights, wnorm=100, scale=norm_scale) + + @resolver(_norm) + def norm(self, weights): + self.weights.set(weights) + + + @staticmethod + def _reset(batch_size, shape): + return jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])), \ + jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])) + + @resolver(_reset) + def reset(self, inputs, outputs, pre, post): + self.inputs.set(inputs) + self.outputs.set(outputs) + self.pre.set(pre) + self.post.set(post) + + @staticmethod + def _evolve(pre, post, weights, + shape, alpha, beta, pre_decay, Aplus, Aminus, + t, dt): + pre_size = shape[0] + post_size = shape[1] + + pre_synaptic_binary_mask = jnp.where(pre > 0, 1., 0.) + post_synaptic_binary_mask = jnp.where(post > 0, 1., 0.) 
+ + broadcast_pre_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(pre_synaptic_binary_mask, (pre_size, 1)) + broadcast_post_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(post_synaptic_binary_mask, (1, post_size)) + + broadcast_pre_synaptic_time = jnp.zeros(shape) + jnp.reshape(pre, (pre_size, 1)) + broadcast_post_synaptic_time = jnp.zeros(shape) + jnp.reshape(post, (1, post_size)) + + no_pre_synaptic_spike_mask = broadcast_post_synaptic_binary_mask * (1. - broadcast_pre_synaptic_binary_mask) + no_pre_synaptic_weight_update = (-alpha) * (pre_decay / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + # no_pre_synaptic_weight_update = no_pre_synaptic_weight_update + + # Both have spiked + both_spike_mask = broadcast_post_synaptic_binary_mask * broadcast_pre_synaptic_binary_mask + both_spike_update = (-alpha / (broadcast_pre_synaptic_time - broadcast_post_synaptic_time - (0.5 * dt))) * \ + (beta / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + + masked_no_pre_synaptic_weight_update = no_pre_synaptic_spike_mask * no_pre_synaptic_weight_update + masked_both_spike_update = both_spike_mask * both_spike_update + + plasticity = jnp.where(masked_both_spike_update > 0, + Aplus * masked_both_spike_update, + Aminus * masked_both_spike_update) + + decay = masked_no_pre_synaptic_weight_update + + plasticity = plasticity * (1 - weights) + + decay = decay * weights + + update = plasticity + decay + _W = weights + update + + return jnp.clip(_W, 0., 1.) 
+ + @resolver(_evolve) + def evolve(self, weights): + self.weights.set(weights) + + + diff --git a/exhibits/time-integrated-stdp/patch_model/patch_tistdp_snn.py b/exhibits/time-integrated-stdp/patch_model/patch_tistdp_snn.py new file mode 100755 index 0000000..4b891b8 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/patch_tistdp_snn.py @@ -0,0 +1,235 @@ +from ngclearn import Context, numpy as jnp +from ngclearn.components import (LIFCell, PoissonCell, BernoulliCell, StaticSynapse, + VarTrace, Monitor) +from custom.ti_STDP_Synapse import TI_STDP_Synapse +from custom.ti_STDP_LCNSynapse import TI_STDP_LCNSynapse +from ngclearn.operations import summation +from jax import jit, random +from ngclearn.utils.viz.synapse_plot import viz_block +from ngclearn.utils.model_utils import scanner +import ngclearn.utils.weight_distribution as dists +from matplotlib import pyplot as plt +import ngclearn.utils as utils + +def get_nodes(model): + nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i", "M") + map = {} ## node-name look-up hash table + for node in nodes: + map[node.name] = node + return nodes, map + +def build_model(seed=1234, in_dim=1, in_patchShape=None, n_in_patches=1): + window_length = 250 #300 + + ## try no striding and normalize coefficients for filters in recon? + dt = 1 + X_size = int(jnp.sqrt(in_dim)) ## get input square dim + + z0_patchShape = in_patchShape #(7, 7) #(2, 2) #(7, 7) + z0_patchSize = z0_patchShape[0] * z0_patchShape[0] + z1_patchSize = 8 * 8 #6 * 6 # 5 * 5 # 4 * 4 #2 * 2 #4 * 4 + z1_patchCnt = n_in_patches #int(X_size/z0_patchShape[0]) * int(X_size/z0_patchShape[0]) # 16 + z1_RfieldCnt = 1 + + hidden_size = z1_patchSize * z1_patchCnt + out_size = 15 * 15 #10 * 10 # 8 * 8 #6 * 6 #4 * 4 + + ## INIT inhibitory/excitatory matrix with block diagonalization + + R1 = 1. #12. #6. #1. #12. #6. #1. #6. + R2 = 6. #12. + exc = 22.5 + inh = 120. #10. #120. #60. #15. #10. + tau_m_e = 100. 
# ms (excitatory membrane time constant) + tau_m_i = 100. # ms (inhibitory membrane time constant) + # tau_theta = 500. + # theta_plus = 0.2 + tau_theta = 1e5 #1e3 #1e5 #500. #1e4 #1e5 + theta_plus = 0.05 #0.1 #0.05 + thr_jitter = 0. + + px = py = X_size + hidx = hidy = int(jnp.sqrt(hidden_size)) + + dkey = random.PRNGKey(seed) + dkey, *subkeys = random.split(dkey, 12) + + with Context("model") as model: + M = Monitor("M", default_window_length=window_length) + ## layer 0 + z0 = PoissonCell("z0", n_units=in_dim, max_freq=63.75, key=subkeys[0]) + ## layer 1 + W1 = TI_STDP_LCNSynapse( + "W1", shape=(z0_patchSize, z1_patchSize), + model_shape=(z1_patchCnt, z1_RfieldCnt), + alpha=0.0075 * 0.5, beta=1.25, pre_decay=0.75, + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R1, key=subkeys[1] + ) + z1e = LIFCell("z1e", n_units=hidden_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[2]) + z1i = LIFCell("z1i", n_units=hidden_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) 
+ W1ie = StaticSynapse("W1ie", shape=(hidden_size, hidden_size), + weight_init=dists.hollow(-inh, block_diag_mask_width=z1_patchSize)) + W1ei = StaticSynapse("W1ei", shape=(hidden_size, hidden_size), + weight_init=dists.eye(exc, block_diag_mask_width=z1_patchSize)) + ## layer 2 + W2 = TI_STDP_Synapse("W2", alpha=0.025 * 2, beta=2, pre_decay=0.25 * 0.5, + shape=(hidden_size, out_size), + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R2, key=subkeys[6]) + z2e = LIFCell("z2e", n_units=out_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[4]) + z2i = LIFCell("z2i", n_units=out_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) + W2ie = StaticSynapse("W2ie", shape=(out_size, out_size), + weight_init=dists.hollow(-inh)) + W2ei = StaticSynapse("W2ei", shape=(out_size, out_size), + weight_init=dists.eye(exc)) + + ## layer 0 to layer 1 + W1.inputs << z0.outputs + W1ie.inputs << z1i.s + z1e.j << summation(W1.outputs, W1ie.outputs) + W1ei.inputs << z1e.s + z1i.j << W1ei.outputs + ## layer 1 to layer 2 + W2.inputs << z1e.s_raw + W2ie.inputs << z2i.s + z2e.j << summation(W2.outputs, W2ie.outputs) + W2ei.inputs << z2e.s + z2i.j << W2ei.outputs + + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre << z0.tols + W1.post << z1e.tols + W2.pre << z1e.tols + W2.post << z2e.tols + + ## wire statistics into global monitor + ## layer 1 stats + M << z1e.s + M << z1e.j + M << z1e.v + ## layer 2 stats + M << z2e.s + M << z2e.j + M << z2e.v + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + 
compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True, raster_name="exp/raster_plot"): + viz_block([W1.weights.value, W2.weights.value], + [(px, py), (hidx, hidy)], name + "_block", padding=2, + low_rez=low_rez) + + fig, ax = plt.subplots(3, 2, sharex=True, figsize=(15, 8)) + + for k in range(out_size): + ax[1][0].plot([i for i in range(window_length)], + M.view(z1e.v)[:, :, k]) + ax[0][0].plot([i for i in range(window_length)], + M.view(z1e.j)[:, :, k]) + + ax[1][1].plot([i for i in range(window_length)], + M.view(z2e.v)[:, :, k]) + ax[0][1].plot([i for i in range(window_length)], + M.view(z2e.j)[:, :, k]) + # print("----") + # data = M.view(z2e.v) + # print(jnp.amax(data, axis=0)) + # print("----") + + utils.viz.raster.create_raster_plot(M.view(z1e.s), ax=ax[2][0]) + utils.viz.raster.create_raster_plot(M.view(z2e.s), ax=ax[2][1]) + # plt.show() + plt.savefig(raster_name) + plt.close() + + @scanner + def observe(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + current_state = model.evolve(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.v.path], + current_state[z2e.v.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + #M.halt_all() + return model + +def load_from_disk(model_directory, param_dir="/custom", disable_adaptation=True): + with Context("model") as model: + model.load_from_dir(model_directory, custom_folder=param_dir) + nodes = 
model.get_components("W1", "W1ie", "W1ei", "W2", "W2ie", "W2ei", + "z0", "z1e", "z1i", "z2e", "z2i") + (W1, W1ie, W1ei, W2, W2ie, W2ei,z0, z1e, z1i, z2e, z2i) = nodes + if disable_adaptation: + z1e.tau_theta = 0. ## disable homeostatic adaptation + z2e.tau_theta = 0. ## disable homeostatic adaptation + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True): + viz_block([W1.weights.value, W2.weights.value], + [(28, 28), (10, 10)], name + "_block", padding=2, + low_rez=low_rez) + + @scanner + def infer(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), + jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), + jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + return model + diff --git a/exhibits/time-integrated-stdp/patch_model/patched_train.py b/exhibits/time-integrated-stdp/patch_model/patched_train.py new file mode 100755 index 0000000..54726b8 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/patched_train.py @@ -0,0 +1,192 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time +from 
ngclearn.utils.io_utils import makedir +from custom.patch_utils import Create_Patches +from ngclearn.utils.viz.synapse_plot import visualize + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +def viz_fields(name, W1, n_in_patches): + z0_patchShape = (7, 7) + z0_patchSize = z0_patchShape[0] * z0_patchShape[0] + _W1 = W1.weights.value + B = [] + for i_ in range(_W1.shape[1]): + for j_ in range(n_in_patches): + _filter = _W1[z0_patchSize * j_:(j_ + 1) * z0_patchSize, i_:i_ + 1] + if jnp.sum(_filter) > 0.: + B.append(_filter) + B = jnp.concatenate(B, axis=1) + print("Viz.shape: ", B.shape) + visualize([B], [z0_patchShape], name + "_filters") + +def save_parameters(model_dir, nodes): ## model context saving routine + makedir(model_dir) + for node in nodes: + node.save(model_dir) ## call node's local save function + +################################################################################ +# read in general program arguments +options, remainder = gopt.getopt( + sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", "n_iter=", "verbosity=", + "bind_target=", "exp_dir=", "model_type=", "seed="] +) + +seed = 1234 +exp_dir = "exp/" +model_type = "tistdp" +n_iter = 1 # 10 ## total number passes through dataset +n_samples = -1 +bind_target = 40000 +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ("--verbosity"): + verbosity = int(arg.strip()) + elif opt in ("--n_samples"): + n_samples = int(arg.strip()) + elif opt in ("--n_iter"): + n_iter = int(arg.strip()) + elif opt in ("--bind_target"): + bind_target = int(arg.strip()) + elif opt in ("--exp_dir"): + exp_dir = arg.strip() + elif opt in ("--model_type"): + model_type = arg.strip() + elif opt in ("--seed"): 
+ seed = int(arg.strip()) + +if model_type == "tistdp": + print(" >> Setting up TI-STDP Patch-Model builder!") + from patch_tistdp_snn import build_model, get_nodes +else: + print(" >> Model type ", model_type, " not supported!") + +print(">> X: {} Y: {}".format(dataX, dataY)) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if n_samples > 0: + _X = _X[0:n_samples, :] + _Y = _Y[0:n_samples, :] + print("-> Fitting model to only {} samples".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 1000 #10000 +mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (28, 28) ## same as image_shape +num_patches = 10 +in_dim = patch_shape[0] * patch_shape[1] + +x = jnp.ones(patch_shape) +in_patchShape = (7, 7) +stride_shape = (0, 0) +patcher = Create_Patches(x, in_patchShape, stride_shape) +x_pat = patcher.create_patches(add_frame=False, center=True) +n_in_patches = patcher.nw_patches * patcher.nh_patches + +T = 250 ## num time steps to simulate (stimulus presentation window length) +dt = 1. 
## integration time constant + +################################################################################ +print("--- Building Model ---") +dkey = random.PRNGKey(seed) +dkey, *subkeys = random.split(dkey, 3) +## Create model +makedir(exp_dir) +## get model from imported header file +model = build_model(seed, in_dim=in_dim, n_in_patches=n_in_patches, + in_patchShape=in_patchShape) +nodes, node_map = get_nodes(model) +model.save_to_json(exp_dir, model_type) +################################################################################ +print("--- Starting Simulation ---") + +sim_start_time = time.time() ## start time profiling + +model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + +print("------------------------------------") +model.showStats(-1) + +model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, 0) +save_parameters(model_dir, nodes) +viz_fields(name="{}{}".format(exp_dir, "init_W1"), W1=node_map.get("W1"), n_in_patches=n_in_patches) + +## enter main adaptation loop over data patterns +class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units)) +num_bound = 0 +n_total_samp_seen = 0 +for i in range(n_iter): + dkey, *subkeys = random.split(dkey, 2) + ptrs = random.permutation(subkeys[0], _X.shape[0]) + X = _X[ptrs, :] + Y = _Y[ptrs, :] + + tstart = time.time() + n_samps_seen = 0 + for j in range(n_batches): + idx = j + Xb = X[idx: idx + mb_size, :] + Yb = Y[idx: idx + mb_size, :] + + model.reset() + patcher.img = jnp.reshape(Xb, (28, 28)) + xs = patcher.create_patches(add_frame=False, center=True) + xs = xs.reshape(1, -1) + model.clamp(xs) + spikes1, spikes2 = model.observe(jnp.array([[dt * k, dt] for k in range(T)])) + + if n_total_samp_seen >= bind_target: + responses = Yb.T * jnp.sum(spikes2, axis=0) + class_responses = class_responses + responses + num_bound += 1 + + n_samps_seen += Xb.shape[0] + n_total_samp_seen += Xb.shape[0] + + print("\r Seen {} images 
(Binding {})...".format(n_samps_seen, num_bound), end="") + if (j+1) % viz_mod == 0: ## save intermediate receptive fields + tend = time.time() + print() + print(" -> Time = {} s".format(tend - tstart)) + tstart = tend + 0. + model.showStats(i) + viz_fields(name="{}{}".format(exp_dir, "W1"), W1=node_map.get("W1"), n_in_patches=n_in_patches) + model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + + ## save a running/current overridden copy of NPZ parameters + model_dir = "{}{}/custom".format(exp_dir, model_type) + save_parameters(model_dir, nodes) + + ## end of iteration/epoch + ## save a snapshot of the NPZ parameters at this particular epoch + model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, i) + save_parameters(model_dir, nodes) +print() + +class_responses = jnp.argmax(class_responses, axis=0, keepdims=True) +print("---- Max Class Responses ----") +print(class_responses) +print(class_responses.shape) +jnp.save("{}binded_labels.npy".format(exp_dir), class_responses) + +## stop time profiling +sim_end_time = time.time() +sim_time = sim_end_time - sim_start_time +sim_time_hr = (sim_time/3600.0) # convert time to hours + +print("------------------------------------") +print(" Trial.sim_time = {} h ({} sec)".format(sim_time_hr, sim_time)) diff --git a/exhibits/time-integrated-stdp/patch_model/sample_model.sh b/exhibits/time-integrated-stdp/patch_model/sample_model.sh new file mode 100755 index 0000000..66b948e --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/sample_model.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp stdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'tistdp', 'evstdp' models supported!" 
+ exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +PARAM_SUBDIR="/custom" +SEED=1234 #(1234 77 811) +N_SAMPLES=5000 #1000 #50000 +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" +#DEV_X="../../data/mnist/testX.npy" # validX.npy +#DEV_Y="../../data/mnist/testY.npy" # validY.npy + +EXP_DIR="exp_$MODEL""_$SEED/" +echo " > Running Simulation/Model: $EXP_DIR" + +## train model +CUDA_VISIBLE_DEVICES=$GPU_ID python assemble_patterns.py --dataX=$DATA_X --dataY=$DATA_Y \ + --n_samples=$N_SAMPLES --exp_dir=$EXP_DIR \ + --model_dir=$EXP_DIR$MODEL \ + --param_subdir=$PARAM_SUBDIR \ + --model_type=$MODEL --seed=$SEED diff --git a/exhibits/time-integrated-stdp/patch_model/train_by_patch.py b/exhibits/time-integrated-stdp/patch_model/train_by_patch.py new file mode 100755 index 0000000..3849001 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/train_by_patch.py @@ -0,0 +1,167 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time +from ngclearn import Context +from ngclearn.utils.io_utils import makedir +from ngclearn.utils.patch_utils import generate_patch_set + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +def save_parameters(model_dir, nodes): ## model context saving routine + makedir(model_dir) + for node in nodes: + node.save(model_dir) ## call node's local save function + +################################################################################ +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", + "n_iter=", "verbosity=", + "bind_target=", "exp_dir=", + "model_type=", "seed="]) + +seed = 1234 +exp_dir = "exp/" +model_type = "tistdp" +n_iter = 1 # 10 ## total number passes through dataset +n_samples = -1 +bind_target = 40000 +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +verbosity = 0 ## verbosity level (0 - fairly 
minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ("--verbosity"): + verbosity = int(arg.strip()) + elif opt in ("--n_samples"): + n_samples = int(arg.strip()) + elif opt in ("--n_iter"): + n_iter = int(arg.strip()) + elif opt in ("--bind_target"): + bind_target = int(arg.strip()) + elif opt in ("--exp_dir"): + exp_dir = arg.strip() + elif opt in ("--model_type"): + model_type = arg.strip() + elif opt in ("--seed"): + seed = int(arg.strip()) +if model_type == "tistdp": + print(" >> Setting up TI-STDP builder!") + from tistdp_snn import build_model, get_nodes +elif model_type == "trstdp": + print(" >> Setting up Trace-based STDP builder!") + from trstdp_snn import build_model, get_nodes +elif model_type == "evstdp": + print(" >> Setting up Event-Driven STDP builder!") + from evstdp_snn import build_model, get_nodes +elif model_type == "stdp": + print(" >> Setting up classical STDP builder!") + from stdp_snn import build_model, get_nodes +print(">> X: {} Y: {}".format(dataX, dataY)) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if n_samples > 0: + _X = _X[0:n_samples, :] + _Y = _Y[0:n_samples, :] + print("-> Fitting model to only {} samples".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 1000 #10000 +mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (10, 10) #(28, 28) +in_dim = patch_shape[0] * patch_shape[1] +num_patches = 10 + +T = 250 #300 ## num time steps to simulate (stimulus presentation window length) +dt = 1. 
## integration time constant + +################################################################################ +print("--- Building Model ---") +dkey = random.PRNGKey(seed) +dkey, *subkeys = random.split(dkey, 3) +## Create model +makedir(exp_dir) +model = build_model(seed, in_dim=in_dim) #Context("model") ## get model from imported header file +nodes, node_map = get_nodes(model) +model.save_to_json(exp_dir, model_type) +################################################################################ +print("--- Starting Simulation ---") + +sim_start_time = time.time() ## start time profiling + +model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + +print("------------------------------------") +model.showStats(-1) + +model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, 0) +save_parameters(model_dir, nodes) + +## enter main adaptation loop over data patterns +class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units)) +num_bound = 0 +n_total_samp_seen = 0 +for i in range(n_iter): + dkey, *subkeys = random.split(dkey, 2) + ptrs = random.permutation(subkeys[0], _X.shape[0]) + X = _X[ptrs, :] + Y = _Y[ptrs, :] + + tstart = time.time() + n_samps_seen = 0 + for j in range(n_batches): + idx = j + Xb = X[idx: idx + mb_size, :] + Yb = Y[idx: idx + mb_size, :] + + ## generate a set of patches from current pattern + Xb = generate_patch_set(Xb, patch_shape, num_patches, center=False) + for p in range(Xb.shape[0]): # within a batch of patches, adapt SNN + xs = jnp.expand_dims(Xb[p, :], axis=0) + flag = jnp.sum(xs) + if flag > 0.: + model.reset() + model.clamp(xs) + spikes1, spikes2 = model.observe(jnp.array([[dt * k, dt] for k in range(T)])) + + n_samps_seen += Xb.shape[0] + n_total_samp_seen += Xb.shape[0] + print("\r Seen {} images (Binding {})...".format(n_samps_seen, num_bound), end="") + if (j+1) % viz_mod == 0: ## save intermediate receptive fields + tend = time.time() + print() 
+ print(" -> Time = {} s".format(tend - tstart)) + tstart = tend + 0. + model.showStats(i) + model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + + ## save a running/current overridden copy of NPZ parameters + model_dir = "{}{}/custom".format(exp_dir, model_type) + save_parameters(model_dir, nodes) + + ## end of iteration/epoch + ## save a snapshot of the NPZ parameters at this particular epoch + model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, i) + save_parameters(model_dir, nodes) +print() + +## stop time profiling +sim_end_time = time.time() +sim_time = sim_end_time - sim_start_time +sim_time_hr = (sim_time/3600.0) # convert time to hours + +#plot_clusters(W1, W2) + +print("------------------------------------") +print(" Trial.sim_time = {} h ({} sec)".format(sim_time_hr, sim_time)) + + diff --git a/exhibits/time-integrated-stdp/patch_model/train_patch_models.sh b/exhibits/time-integrated-stdp/patch_model/train_patch_models.sh new file mode 100755 index 0000000..db0c7fb --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/train_patch_models.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp stdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'tistdp', 'evstdp' models supported!" + exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +SEEDS=(1234) +N_ITER=10 +N_SAMPLES=50000 #10000 #5000 #1000 #50000 +BIND_COUNT=10000 +BIND_TARGET=$((($N_ITER - 1) * $N_SAMPLES + ($N_SAMPLES - $BIND_COUNT))) + +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" +#DEV_X="../../data/mnist/testX.npy" # validX.npy +#DEV_Y="../../data/mnist/testY.npy" # validY.npy + +if (( N_ITER * N_SAMPLES < BIND_COUNT )) ; then + echo "Not enough samples to reach bind target!" 
+ exit 1 +fi + +for seed in "${SEEDS[@]}" +do + EXP_DIR="exp_$MODEL""_$seed/" + echo " > Running Simulation/Model: $EXP_DIR" + + rm -r $EXP_DIR* + ## train model + CUDA_VISIBLE_DEVICES=$GPU_ID python patched_train.py --dataX=$DATA_X --dataY=$DATA_Y \ + --n_iter=$N_ITER --bind_target=$BIND_TARGET \ + --n_samples=$N_SAMPLES --exp_dir=$EXP_DIR \ + --model_type=$MODEL --seed=$seed +done diff --git a/exhibits/time-integrated-stdp/snn_case1.py b/exhibits/time-integrated-stdp/snn_case1.py new file mode 100755 index 0000000..5883c5c --- /dev/null +++ b/exhibits/time-integrated-stdp/snn_case1.py @@ -0,0 +1,321 @@ +from ngclearn import Context, numpy as jnp +import math +from ngclearn.components import (LIFCell, PoissonCell, StaticSynapse, + VarTrace, Monitor, EventSTDPSynapse, TraceSTDPSynapse) +from custom.ti_stdp_synapse import TI_STDP_Synapse +from ngclearn.operations import summation +from jax import jit, random +from ngclearn.utils.viz.raster import create_raster_plot +from ngclearn.utils.viz.synapse_plot import viz_block +from ngclearn.utils.model_utils import scanner +import ngclearn.utils.weight_distribution as dists +from ngclearn.utils.model_utils import normalize_matrix +from matplotlib import pyplot as plt +import ngclearn.utils as utils + +def get_nodes(model): + nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i", "M") + map = {} ## node-name look-up hash table + for node in nodes: + map[node.name] = node + return nodes, map + +def build_model(seed=1234, in_dim=1, is_patch_model=False, algo="tistdp"): + window_length = 250 #300 + + dt = 1 + X_size = int(jnp.sqrt(in_dim)) + # X_size = 28 + # in_dim = (X_size * X_size) + if is_patch_model: + hidden_size = 10 * 10 + out_size = 6 * 6 + else: + hidden_size = 25 * 25 + out_size = 15 * 15 + + exc = 22.5 + inh = 120. #60. #15. #10. + tau_m_e = 100. # ms (excitatory membrane time constant) + tau_m_i = 100. 
# ms (inhibitory membrane time constant) + tau_theta = 1e5 + theta_plus = 0.05 + thr_jitter = 0. + + R1 = R2 = 1. + if algo == "tistdp": + R1 = 1. + R2 = 6. #1. + elif algo == "evstdp": + R1 = 1. # 6. + R2 = 6. # 12. + elif algo == "trstdp": + R1 = 1. # 6. + R2 = 6. # 12. + tau_theta = 1e4 #1e5 + theta_plus = 0.2 #0.05 + + px = py = X_size + hidx = hidy = int(jnp.sqrt(hidden_size)) + + dkey = random.PRNGKey(seed) + dkey, *subkeys = random.split(dkey, 12) + + with Context("model") as model: + M = Monitor("M", default_window_length=window_length) + ## layer 0 + z0 = PoissonCell("z0", n_units=in_dim, max_freq=63.75, key=subkeys[0]) + ## layer 1 + z1e = LIFCell("z1e", n_units=hidden_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[2]) + z1i = LIFCell("z1i", n_units=hidden_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) + W1ie = StaticSynapse("W1ie", shape=(hidden_size, hidden_size), + weight_init=dists.hollow(-inh)) + W1ei = StaticSynapse("W1ei", shape=(hidden_size, hidden_size), + weight_init=dists.eye(exc)) + ## layer 2 + z2e = LIFCell("z2e", n_units=out_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[4]) + z2i = LIFCell("z2i", n_units=out_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) 
+ W2ie = StaticSynapse("W2ie", shape=(out_size, out_size), + weight_init=dists.hollow(-inh)) + W2ei = StaticSynapse("W2ei", shape=(out_size, out_size), + weight_init=dists.eye(exc)) + + tr0 = tr1 = tr2 = None + if algo == "tistdp": + print(" >> Equipping SNN with TI-STDP Adaptation") + W1 = TI_STDP_Synapse("W1", alpha=0.0075 * 0.5, beta=1.25, + pre_decay=0.75, + shape=((X_size ** 2), hidden_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TI_STDP_Synapse("W2", alpha=0.025 * 2, beta=2, + pre_decay=0.25 * 0.5, + shape=(hidden_size, out_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre << z0.tols + W1.post << z1e.tols + W2.pre << z1e.tols + W2.post << z2e.tols + elif algo == "trstdp": + print(" >> Equipping SNN with Trace/TR-STDP Adaptation") + tau_tr = 20. # 10. + trace_delta = 0. + x_tar1 = 0.3 + x_tar2 = 0.025 + Aplus = 1e-2 ## LTP learning rate (STDP); nu1 + Aminus = 1e-3 ## LTD learning rate (STDP); nu0 + mu = 1. # 0. 
+ W1 = TraceSTDPSynapse("W1", shape=(in_dim, hidden_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar1, + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TraceSTDPSynapse("W2", shape=(hidden_size, out_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar2, resist_scale=R2, + weight_init=dists.uniform(amin=0.025, amax=0.8), + key=subkeys[3]) + ## traces + tr0 = VarTrace("tr0", n_units=in_dim, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr1 = VarTrace("tr1", n_units=hidden_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr2 = VarTrace("tr2", n_units=out_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + # wire relevant compartment statistics to synaptic cables W1 and W2 + W1.preTrace << tr0.trace + W1.preSpike << z0.outputs + W1.postTrace << tr1.trace + W1.postSpike << z1e.s + + W2.preTrace << tr1.trace + W2.preSpike << z1e.s + W2.postTrace << tr2.trace + W2.postSpike << z2e.s + elif algo == "evstdp": + print(" >> Equipping SNN with EV-STDP Adaptation") + ## EV-STDP meta-parameters + eta_w = 1. + Aplus1 = 0.0055 + Aminus1 = 0.25 * 0.0055 + Aplus2 = 0.0055 + Aminus2 = 0.05 * 0.0055 + lmbda = 0. 
# 0.01 + W1 = EventSTDPSynapse("W1", shape=(in_dim, hidden_size), eta=eta_w, + A_plus=Aplus1, A_minus=Aminus1, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(0.025, 0.8), + resist_scale=R1, key=subkeys[1]) + W2 = EventSTDPSynapse("W2", shape=(hidden_size, out_size), + eta=eta_w, + A_plus=Aplus2, A_minus=Aminus2, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre_tols << z0.tols + W1.postSpike << z1e.s + W2.pre_tols << z1e.tols + W2.postSpike << z2e.s + else: + print("ERROR: algorithm ", algo, " provided is not supported!") + exit() + + ## layer 0 to layer 1 + W1.inputs << z0.outputs + W1ie.inputs << z1i.s + z1e.j << summation(W1.outputs, W1ie.outputs) + W1ei.inputs << z1e.s + z1i.j << W1ei.outputs + ## layer 1 to layer 2 + W2.inputs << z1e.s_raw + W2ie.inputs << z2i.s + z2e.j << summation(W2.outputs, W2ie.outputs) + W2ei.inputs << z2e.s + z2i.j << W2ei.outputs + + ## wire statistics into global monitor + ## layer 1 stats + M << z1e.s + M << z1e.j + M << z1e.v + ## layer 2 stats + M << z2e.s + M << z2e.j + M << z2e.v + + if algo == "trstdp": + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + tr0, tr1, tr2, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, + compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, tr0, + tr1, tr2, + compile_key="reset") + else: + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + 
model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True, raster_name="exp/raster_plot"): + viz_block([W1.weights.value, W2.weights.value], + [(px, py), (hidx, hidy)], name + "_block", padding=2, + low_rez=low_rez) + + fig, ax = plt.subplots(3, 2, sharex=True, figsize=(15, 8)) + + for i in range(out_size): + ax[1][0].plot([i for i in range(window_length)], + M.view(z1e.v)[:, :, i]) + ax[0][0].plot([i for i in range(window_length)], + M.view(z1e.j)[:, :, i]) + + ax[1][1].plot([i for i in range(window_length)], + M.view(z2e.v)[:, :, i]) + ax[0][1].plot([i for i in range(window_length)], + M.view(z2e.j)[:, :, i]) + utils.viz.raster.create_raster_plot(M.view(z1e.s), ax=ax[2][0]) + utils.viz.raster.create_raster_plot(M.view(z2e.s), ax=ax[2][1]) + # plt.show() + plt.savefig(raster_name) + plt.close() + + @scanner + def observe(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + current_state = model.evolve(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + #M.halt_all() + return model + +def load_from_disk(model_directory, param_dir="/custom", disable_adaptation=True): + with Context("model") as model: + model.load_from_dir(model_directory, custom_folder=param_dir) + nodes = model.get_components("W1", "W1ie", "W1ei", "W2", "W2ie", "W2ei", + "z0", "z1e", "z1i", "z2e", "z2i") + (W1, W1ie, 
W1ei, W2, W2ie, W2ei,z0, z1e, z1i, z2e, z2i) = nodes + if disable_adaptation: + z1e.tau_theta = 0. ## disable homeostatic adaptation + z2e.tau_theta = 0. ## disable homeostatic adaptation + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True): + viz_block([W1.weights.value, W2.weights.value], + [(28, 28), (10, 10)], name + "_block", padding=2, + low_rez=low_rez) + + @scanner + def infer(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), + jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), + jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + return model + diff --git a/exhibits/time-integrated-stdp/snn_case2.py b/exhibits/time-integrated-stdp/snn_case2.py new file mode 100755 index 0000000..489d3f2 --- /dev/null +++ b/exhibits/time-integrated-stdp/snn_case2.py @@ -0,0 +1,324 @@ +from ngclearn import Context, numpy as jnp +import math +from ngclearn.components import (LIFCell, PoissonCell, StaticSynapse, + VarTrace, Monitor, EventSTDPSynapse, TraceSTDPSynapse) +from custom.ti_stdp_synapse import TI_STDP_Synapse +from 
ngclearn.operations import summation +from jax import jit, random +from ngclearn.utils.viz.raster import create_raster_plot +from ngclearn.utils.viz.synapse_plot import viz_block +from ngclearn.utils.model_utils import scanner +import ngclearn.utils.weight_distribution as dists +from ngclearn.utils.model_utils import normalize_matrix +from matplotlib import pyplot as plt +import ngclearn.utils as utils + +def get_nodes(model): + nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i", "M") + map = {} ## node-name look-up hash table + for node in nodes: + map[node.name] = node + return nodes, map + +def build_model(seed=1234, in_dim=1, is_patch_model=False, algo="tistdp"): + window_length = 250 #300 + + dt = 1 + X_size = int(jnp.sqrt(in_dim)) + if is_patch_model: + hidden_size = 10 * 10 + out_size = 6 * 6 + else: + hidden_size = 25 * 25 + out_size = 15 * 15 + + # exc = 22.5 + # inh = 120. #60. #15. #10. + tau_m_e = 100. # ms (excitatory membrane time constant) + tau_m_i = 100. # ms (inhibitory membrane time constant) + # tau_theta = 1e5 + # theta_plus = 0.05 + # thr_jitter = 0. + + R1 = R2 = 1. + if algo == "tistdp": + R1 = 1. + R2 = 1. + elif algo == "evstdp": + R1 = 1. #12. #1. + R2 = 6. # 12. + elif algo == "trstdp": + R1 = 1. # 6. + R2 = 6. # 12. 
+ + px = py = X_size + hidx = hidy = int(jnp.sqrt(hidden_size)) + + dkey = random.PRNGKey(seed) + dkey, *subkeys = random.split(dkey, 12) + + with Context("model") as model: + M = Monitor("M", default_window_length=window_length) + ## layer 0 + z0 = PoissonCell("z0", n_units=in_dim, max_freq=63.75, key=subkeys[0]) + ## layer 1 + z1e = LIFCell("z1e", n_units=hidden_size, tau_m=tau_m_e, + resist_m=25, tau_theta=1e4, theta_plus=0.2, refract_time=5., + one_spike=True, key=subkeys[2]) + z1i = LIFCell("z1i", n_units=hidden_size, tau_m=tau_m_i, + resist_m=tau_m_i / dt, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0., refract_time=5., ) + W1ie = StaticSynapse("W1ie", shape=(hidden_size, hidden_size), + weight_init=dists.hollow(-10)) + W1ei = StaticSynapse("W1ei", shape=(hidden_size, hidden_size), + weight_init=dists.eye(22.5)) + ## layer 2 + z2e = LIFCell("z2e", n_units=out_size, tau_m=tau_m_e, + resist_m=500., thr_jitter=20., tau_theta=1e4, refract_time=5., + theta_plus=0.2, one_spike=True, key=subkeys[4]) + z2i = LIFCell("z2i", n_units=out_size, tau_m=tau_m_i, + resist_m=tau_m_i / dt, refract_time=5., + thr=-40., v_rest=-60., v_reset=-45., tau_theta=0.,) + ## if make z2e bigger, then make the inhibition weaker possibly + W2ie = StaticSynapse("W2ie", shape=(out_size, out_size), + weight_init=dists.hollow(-10)) + W2ei = StaticSynapse("W2ei", shape=(out_size, out_size), + weight_init=dists.eye(22.5)) + + tr0 = tr1 = tr2 = None + if algo == "tistdp": + # Working numbers 1.5, 0.2 + # Single pass -> alpha=0.025, beta=2, pre_decay=0.25 * 0.5, acc: 60% + print(" >> Equipping SNN with TI-STDP Adaptation") + W1 = TI_STDP_Synapse("W1", alpha=0.0075 * 0.5, beta=1.25, + pre_decay=0.75, + shape=((X_size ** 2), hidden_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TI_STDP_Synapse("W2", alpha=0.025, beta=2, + pre_decay=0.25 * 0.5, + shape=(hidden_size, out_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + 
resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre << z0.tols + W1.post << z1e.tols + W2.pre << z1e.tols + W2.post << z2e.tols + elif algo == "trstdp": + print(" >> Equipping SNN with Trace/TR-STDP Adaptation") + tau_tr = 20. # 40. # 10. + trace_delta = 0. + x_tar1 = 0.3 + x_tar2 = 0.025 + Aplus = 1e-2 ## LTP learning rate (STDP); nu1 + Aminus = 1e-3 ## LTD learning rate (STDP); nu0 + mu = 1. # 0. + W1 = TraceSTDPSynapse("W1", shape=(in_dim, hidden_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar1, + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TraceSTDPSynapse("W2", shape=(hidden_size, out_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar2, resist_scale=R2, + weight_init=dists.uniform(amin=0.025, amax=0.8), + key=subkeys[3]) + ## traces + tr0 = VarTrace("tr0", n_units=in_dim, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr1 = VarTrace("tr1", n_units=hidden_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr2 = VarTrace("tr2", n_units=out_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + # wire cells z0 and z1e to their respective traces + tr0.inputs << z0.outputs + tr1.inputs << z1e.s + tr2.inputs << z2e.s + # wire relevant compartment statistics to synaptic cables W1 and W2 + W1.preTrace << tr0.trace + W1.preSpike << z0.outputs + W1.postTrace << tr1.trace + W1.postSpike << z1e.s + + W2.preTrace << tr1.trace + W2.preSpike << z1e.s + W2.postTrace << tr2.trace + W2.postSpike << z2e.s + elif algo == "evstdp": + print(" >> Equipping SNN with EV-STDP Adaptation") + ## EV-STDP meta-parameters + eta_w = 0.01 # 0.002 #1. + Aplus1 = 1. #0.0055 * 2 + Aminus1 = 0.3 #0.25 * 0.0055 + Aplus2 = 1. #0.0055 * 2 + Aminus2 = 0.075 # 0.06 # 0.3 #0.05 * 0.0055 + lmbda = 0. 
# 0.01 + W1 = EventSTDPSynapse("W1", shape=(in_dim, hidden_size), eta=eta_w, + A_plus=Aplus1, A_minus=Aminus1, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(0.025, 0.8), + resist_scale=R1, key=subkeys[1]) + W2 = EventSTDPSynapse("W2", shape=(hidden_size, out_size), + eta=eta_w, A_plus=Aplus2, A_minus=Aminus2, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre_tols << z0.tols + W1.postSpike << z1e.s + W2.pre_tols << z1e.tols + W2.postSpike << z2e.s + else: + print("ERROR: algorithm ", algo, " provided is not supported!") + exit() + + ## layer 0 to layer 1 + W1.inputs << z0.outputs + W1ie.inputs << z1i.s + z1e.j << summation(W1.outputs, W1ie.outputs) + W1ei.inputs << z1e.s + z1i.j << W1ei.outputs + ## layer 1 to layer 2 + W2.inputs << z1e.s_raw + W2ie.inputs << z2i.s + z2e.j << summation(W2.outputs, W2ie.outputs) + W2ei.inputs << z2e.s + z2i.j << W2ei.outputs + + ## wire statistics into global monitor + ## layer 1 stats + M << z1e.s + M << z1e.j + M << z1e.v + ## layer 2 stats + M << z2e.s + M << z2e.j + M << z2e.v + + if algo == "trstdp": + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + tr0, tr1, tr2, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, + compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, tr0, + tr1, tr2, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + else: + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, 
+ compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True, raster_name="exp/raster_plot"): + viz_block([W1.weights.value, W2.weights.value], + [(px, py), (hidx, hidy)], name + "_block", padding=2, + low_rez=low_rez) + + fig, ax = plt.subplots(3, 2, sharex=True, figsize=(15, 8)) + + for i in range(out_size): + ax[1][0].plot([i for i in range(window_length)], + M.view(z1e.v)[:, :, i]) + ax[0][0].plot([i for i in range(window_length)], + M.view(z1e.j)[:, :, i]) + + ax[1][1].plot([i for i in range(window_length)], + M.view(z2e.v)[:, :, i]) + ax[0][1].plot([i for i in range(window_length)], + M.view(z2e.j)[:, :, i]) + utils.viz.raster.create_raster_plot(M.view(z1e.s), ax=ax[2][0]) + utils.viz.raster.create_raster_plot(M.view(z2e.s), ax=ax[2][1]) + # plt.show() + plt.savefig(raster_name) + plt.close() + + @scanner + def observe(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + current_state = model.evolve(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + #M.halt_all() + return model + +def load_from_disk(model_directory, param_dir="/custom", disable_adaptation=True): + with Context("model") as model: + model.load_from_dir(model_directory, custom_folder=param_dir) + nodes = model.get_components("W1", "W1ie", "W1ei", "W2", "W2ie", "W2ei", + "z0", "z1e", "z1i", "z2e", 
"z2i") + (W1, W1ie, W1ei, W2, W2ie, W2ei,z0, z1e, z1i, z2e, z2i) = nodes + if disable_adaptation: + z1e.tau_theta = 0. ## disable homeostatic adaptation + z2e.tau_theta = 0. ## disable homeostatic adaptation + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True): + viz_block([W1.weights.value, W2.weights.value], + [(28, 28), (10, 10)], name + "_block", padding=2, + low_rez=low_rez) + + @scanner + def infer(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), + jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), + jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + return model + diff --git a/exhibits/time-integrated-stdp/train.py b/exhibits/time-integrated-stdp/train.py new file mode 100755 index 0000000..1b88c7e --- /dev/null +++ b/exhibits/time-integrated-stdp/train.py @@ -0,0 +1,206 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time +from ngclearn.utils.io_utils import makedir +from ngclearn.utils.patch_utils import generate_patch_set + +import matplotlib #.pyplot as plt 
+matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +def save_parameters(model_dir, nodes): ## model context saving routine + makedir(model_dir) + for node in nodes: + node.save(model_dir) ## call node's local save function + +################################################################################ +# read in general program arguments +options, remainder = gopt.getopt( + sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", "n_iter=", "verbosity=", + "bind_target=", "exp_dir=", "model_type=", "seed=", + "use_patches=", "model_case="] +) + +model_case = "snn_case1" +seed = 1234 +exp_dir = "exp/" +model_type = "tistdp" +n_iter = 1 # 10 ## total number passes through dataset +n_samples = -1 +bind_target = 40000 +use_patches = False +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ("--verbosity"): + verbosity = int(arg.strip()) + elif opt in ("--n_samples"): + n_samples = int(arg.strip()) + elif opt in ("--n_iter"): + n_iter = int(arg.strip()) + elif opt in ("--bind_target"): + bind_target = int(arg.strip()) + elif opt in ("--exp_dir"): + exp_dir = arg.strip() + elif opt in ("--model_type"): + model_type = arg.strip() + elif opt in ("--model_case"): + model_case = arg.strip() + elif opt in ("--seed"): + seed = int(arg.strip()) + elif opt in ("--use_patches"): + use_patches = int(arg.strip()) #(arg.strip().lower() == "true") + use_patches = (use_patches == 1) + +if model_case == "snn_case1": + print(" >> Setting up Case 1 model!") + from snn_case1 import build_model, get_nodes +elif model_case == "snn_case2": + print(" >> Setting up Case 2 model!") + from snn_case2 import build_model, get_nodes +else: + print("Error: No other model case studies supported! 
(", model_case, " invalid)") + exit() + +print(">> X: {} Y: {}".format(dataX, dataY)) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if n_samples > 0: + _X = _X[0:n_samples, :] + _Y = _Y[0:n_samples, :] + print("-> Fitting model to only {} samples".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 1000 #10000 +mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (28, 28) ## same as image_shape +num_patches = 10 +if use_patches: + patch_shape = (10, 10) +in_dim = patch_shape[0] * patch_shape[1] + +T = 250 #300 ## num time steps to simulate (stimulus presentation window length) +dt = 1. ## integration time constant + +################################################################################ +print("--- Building Model ---") +dkey = random.PRNGKey(seed) +dkey, *subkeys = random.split(dkey, 3) +## Create model +makedir(exp_dir) +## get model from imported header file +model = build_model(seed, in_dim=in_dim, is_patch_model=use_patches, algo=model_type) +nodes, node_map = get_nodes(model) +model.save_to_json(exp_dir, model_type) +################################################################################ +print("--- Starting Simulation ---") + +sim_start_time = time.time() ## start time profiling + +model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + +print("------------------------------------") +model.showStats(-1) + +model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, 0) +save_parameters(model_dir, nodes) + +## enter main adaptation loop over data patterns +class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units)) +num_bound = 0 +n_total_samp_seen = 0 +for i in range(n_iter): + dkey, *subkeys = random.split(dkey, 2) + ptrs = random.permutation(subkeys[0], _X.shape[0]) + X = _X[ptrs, :] + Y = _Y[ptrs, :] + + tstart = 
time.time() + n_samps_seen = 0 + for j in range(n_batches): + idx = j + Xb = X[idx: idx + mb_size, :] + Yb = Y[idx: idx + mb_size, :] + + if use_patches: + ## generate a set of patches from current pattern + X_patches = generate_patch_set(Xb, patch_shape, num_patches, center=False) + for p in range(X_patches.shape[0]): # within a batch of patches, adapt SNN + xs = jnp.expand_dims(X_patches[p, :], axis=0) + flag = jnp.sum(xs) + if flag > 0.: + model.reset() + model.clamp(xs) + spikes1, spikes2 = model.observe( + jnp.array([[dt * k, dt] for k in range(T)])) + + if n_total_samp_seen >= bind_target: + responses = Yb.T * jnp.sum(spikes2, axis=0) + class_responses = class_responses + responses + num_bound += 1 + else: + model.reset() + model.clamp(Xb) + spikes1, spikes2 = model.observe(jnp.array([[dt * k, dt] for k in range(T)])) + # print(tr1) + # print("...") + # print(jnp.sum(spikes1, axis=0)) + # print(jnp.sum(spikes2, axis=0)) + # exit() + + if n_total_samp_seen >= bind_target: + responses = Yb.T * jnp.sum(spikes2, axis=0) + class_responses = class_responses + responses + num_bound += 1 + + n_samps_seen += Xb.shape[0] + n_total_samp_seen += Xb.shape[0] + + print("\r Seen {} images (Binding {})...".format(n_samps_seen, num_bound), end="") + if (j+1) % viz_mod == 0: ## save intermediate receptive fields + tend = time.time() + print() + print(" -> Time = {} s".format(tend - tstart)) + tstart = tend + 0. 
+ model.showStats(i) + model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + + ## save a running/current overridden copy of NPZ parameters + model_dir = "{}{}/custom".format(exp_dir, model_type) + save_parameters(model_dir, nodes) + + ## end of iteration/epoch + ## save a snapshot of the NPZ parameters at this particular epoch + model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, i) + save_parameters(model_dir, nodes) +print() + +# print(" >> Producing final rec-fields / raster plots!") +# model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=False, +# raster_name="{}{}".format(exp_dir, "raster_plot")) + +class_responses = jnp.argmax(class_responses, axis=0, keepdims=True) +print("---- Max Class Responses ----") +print(class_responses) +print(class_responses.shape) +jnp.save("{}binded_labels.npy".format(exp_dir), class_responses) + +## stop time profiling +sim_end_time = time.time() +sim_time = sim_end_time - sim_start_time +sim_time_hr = (sim_time/3600.0) # convert time to hours + +print("------------------------------------") +print(" Trial.sim_time = {} h ({} sec)".format(sim_time_hr, sim_time)) + + diff --git a/exhibits/time-integrated-stdp/train_models.sh b/exhibits/time-integrated-stdp/train_models.sh new file mode 100755 index 0000000..f2fddbc --- /dev/null +++ b/exhibits/time-integrated-stdp/train_models.sh @@ -0,0 +1,56 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp stdp +CASE_STUDY=$3 ## snn_case1 snn_case2 +USE_PATCHES=0 ## set to false for this simulations study + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "trstdp" && "$MODEL" != "tistdp" && "$MODEL" != "stdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'evstdp', 'trstdp', 'tistdp', 'stdp', 'patch_evstdp' models supported!" 
+ exit 1 +fi +if [[ "$CASE_STUDY" != "snn_case1" && "$CASE_STUDY" != "snn_case2" ]]; then + echo "Invalid Case provided: $CASE_STUDY -- only 'snn_case1', 'snn_case2' supported!" + exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +SEEDS=(1234 77 811) +N_ITER=20 +if [[ "$CASE_STUDY" == "snn_case2" ]]; then + N_ITER=1 ## Case 2 models focus on online learning; so only 1 pass allowed +fi + +if (( USE_PATCHES == 1 )) ; then + echo ">> Patch-level modeling configured!" + N_ITER=1 +fi +N_SAMPLES=50000 +BIND_COUNT=10000 +BIND_TARGET=$((($N_ITER - 1) * $N_SAMPLES + ($N_SAMPLES - $BIND_COUNT))) + +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" +#DEV_X="../../data/mnist/testX.npy" # validX.npy +#DEV_Y="../../data/mnist/testY.npy" # validY.npy + +if (( N_ITER * N_SAMPLES < BIND_COUNT )) ; then + echo "Not enough samples to reach bind target!" + exit 1 +fi + +for seed in "${SEEDS[@]}" +do + EXP_DIR="exp_$MODEL""_$seed/" + echo " > Running Simulation/Model: $EXP_DIR" + + rm -r $EXP_DIR* + ## train model + CUDA_VISIBLE_DEVICES=$GPU_ID python train.py --dataX=$DATA_X --dataY=$DATA_Y \ + --n_iter=$N_ITER --bind_target=$BIND_TARGET \ + --n_samples=$N_SAMPLES --exp_dir=$EXP_DIR \ + --model_type=$MODEL --seed=$seed \ + --use_patches=$USE_PATCHES \ + --model_case=$CASE_STUDY +done diff --git a/exhibits/time-integrated-stdp/viz_codes.py b/exhibits/time-integrated-stdp/viz_codes.py new file mode 100755 index 0000000..4c22b43 --- /dev/null +++ b/exhibits/time-integrated-stdp/viz_codes.py @@ -0,0 +1,29 @@ +from jax import numpy as jnp +import sys, getopt as gopt, optparse +from ngclearn.utils.viz.dim_reduce import extract_tsne_latents, plot_latents + + +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["labels_fname=", "codes_fname=", + "plot_fname="]) + +plot_fname = "codes.jpg" +labels_fname = "../../data/mnist/testY.npy" +codes_fname = "exp/test_codes.npy" +for opt, arg in options: + if opt in 
("--labels_fname"): + labels_fname = arg.strip() + elif opt in ("--codes_fname"): + codes_fname = arg.strip() + elif opt in ("--plot_fname"): + plot_fname = arg.strip() + +labels = jnp.load(labels_fname) +print("Lab.shape: ", labels.shape) +codes = jnp.load(codes_fname) +print("Codes.shape: ", codes.shape) + +## visualize the above data via the t-SNE algorithm +tsne_codes = extract_tsne_latents(codes) +print("tSNE-codes.shape = ", tsne_codes.shape) +plot_latents(tsne_codes, labels, plot_fname=plot_fname) \ No newline at end of file From 04234d7e55d3c86bbe9c11e0776767c1da9bec00 Mon Sep 17 00:00:00 2001 From: ago109 Date: Wed, 24 Jul 2024 15:47:50 -0400 Subject: [PATCH 2/3] mod to dcsnn args --- exhibits/diehl_cook_snn/dcsnn_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/exhibits/diehl_cook_snn/dcsnn_model.py b/exhibits/diehl_cook_snn/dcsnn_model.py index a33949d..ea4d185 100755 --- a/exhibits/diehl_cook_snn/dcsnn_model.py +++ b/exhibits/diehl_cook_snn/dcsnn_model.py @@ -82,11 +82,13 @@ def __init__(self, dkey, in_dim=1, hid_dim=100, T=200, dt=1., exp_dir="exp", self.z1e = LIFCell("z1e", n_units=hid_dim, tau_m=tau_m_e, resist_m=tau_m_e/dt, thr=-52., v_rest=-65., v_reset=-60., tau_theta=1e7, theta_plus=0.05, - refract_time=5., one_spike=True, key=subkeys[2]) + refract_time=5., one_spike=True, + lower_clamp_voltage=False, key=subkeys[2]) self.z1i = LIFCell("z1i", n_units=hid_dim, tau_m=tau_m_i, resist_m=tau_m_i/dt, thr=-40., v_rest=-60., v_reset=-45., tau_theta=0., refract_time=5., - one_spike=False, key=subkeys[3]) + lower_clamp_voltage=False, one_spike=False, + key=subkeys[3]) # ie -> inhibitory to excitatory; ei -> excitatory to inhibitory # (eta = 0 means no learning) From 48c01e71b60ed6b5b180fa5113b59c15edc08ac3 Mon Sep 17 00:00:00 2001 From: ago109 Date: Mon, 5 Aug 2024 23:57:14 -0400 Subject: [PATCH 3/3] fixed minor bug in snn case1 --- exhibits/time-integrated-stdp/snn_case1.py | 5 +++++ 1 file changed, 5 insertions(+) diff 
--git a/exhibits/time-integrated-stdp/snn_case1.py b/exhibits/time-integrated-stdp/snn_case1.py index 5883c5c..b3ba73a 100755 --- a/exhibits/time-integrated-stdp/snn_case1.py +++ b/exhibits/time-integrated-stdp/snn_case1.py @@ -138,6 +138,11 @@ def build_model(seed=1234, in_dim=1, is_patch_model=False, algo="tistdp"): tr2 = VarTrace("tr2", n_units=out_size, tau_tr=tau_tr, decay_type="exp", a_delta=trace_delta) + # wire cells z0 and z1e to their respective traces + tr0.inputs << z0.outputs + tr1.inputs << z1e.s + tr2.inputs << z2e.s + # wire relevant compartment statistics to synaptic cables W1 and W2 W1.preTrace << tr0.trace W1.preSpike << z0.outputs