diff --git a/exhibits/diehl_cook_snn/dcsnn_model.py b/exhibits/diehl_cook_snn/dcsnn_model.py index a33949d..ea4d185 100755 --- a/exhibits/diehl_cook_snn/dcsnn_model.py +++ b/exhibits/diehl_cook_snn/dcsnn_model.py @@ -82,11 +82,13 @@ def __init__(self, dkey, in_dim=1, hid_dim=100, T=200, dt=1., exp_dir="exp", self.z1e = LIFCell("z1e", n_units=hid_dim, tau_m=tau_m_e, resist_m=tau_m_e/dt, thr=-52., v_rest=-65., v_reset=-60., tau_theta=1e7, theta_plus=0.05, - refract_time=5., one_spike=True, key=subkeys[2]) + refract_time=5., one_spike=True, + lower_clamp_voltage=False, key=subkeys[2]) self.z1i = LIFCell("z1i", n_units=hid_dim, tau_m=tau_m_i, resist_m=tau_m_i/dt, thr=-40., v_rest=-60., v_reset=-45., tau_theta=0., refract_time=5., - one_spike=False, key=subkeys[3]) + lower_clamp_voltage=False, one_spike=False, + key=subkeys[3]) # ie -> inhibitory to excitatory; ei -> excitatory to inhibitory # (eta = 0 means no learning) diff --git a/exhibits/time-integrated-stdp/README.md b/exhibits/time-integrated-stdp/README.md new file mode 100755 index 0000000..36475e8 --- /dev/null +++ b/exhibits/time-integrated-stdp/README.md @@ -0,0 +1,75 @@ +# Time-Integrated Spike-Timing-Dependent Plasticity + +Version: ngclearn>=1.2.beta1, ngcsimlib==0.3.beta4 + +This exhibit contains an implementation of the spiking neuronal model +and credit assignment process proposed and studied in: + +Gebhardt, William, and Alexander G. Ororbia. "Time-Integrated Spike-Timing- +Dependent Plasticity." arXiv preprint arXiv:2407.10028 (2024). + +

+
+ Visual depiction of the TI-STDP-adapted SNN architecture. +

+ + + +## Running and Analyzing the Model Simulations + +### Unsupervised Digit-Level Biophysical Model + +To run the main TI-STDP SNN experiments of the paper, simply execute: + +```console +$ ./train_models.sh 0 tistdp snn_case1 +``` + +which will trigger three experimental trials for adaptation of the +`Case 1` model described in the paper on MNIST. If you want to train +the online, `Case 2` model described in the paper on MNIST, you simply +need to change the third argument to the Bash script like so: + +```console +$ ./train_models.sh 0 tistdp snn_case2 +``` + +Independent of whichever case-study you select above, you can analyze the +trained models, in accordance with what was done in the paper, by executing +the analysis bash script as follows: + +```console +$ ./analyze_models.sh 0 tistdp ## run on GPU 0 the "tistdp" config +``` + +Task: Models under this section engage in unsupervised representation +learning and jointly learn, through spike-timing driven credit assignment, +a low-level and higher-level abstract distributed, discrete representations +of sensory input data. In this exhibit, this is particularly focused on +using patterns in the MNIST database. + +### Part-Whole Assembly SNN Model + +To run the patch-model SNN adapted with TI-STDP, enter the `patch_model/` +sub-directory and then execute: + +```console +$ ./train_patch_models.sh +``` + +which will run a single trial to produce the SNN generative assembly +(or part-whole hierarchical) model constructed in the paper. + +Task: This biophysical model engages in a form unsupervised +representation learning that is focused on learning a simple bi-level +part-whole hierarchy of sensory input data (in this exhibit, the focus is on +using data from the MNIST database). + +## Model Descriptions, Hyperparameters, and Configuration Details + +Model explanations, meta-parameters settings and experimental details are +provided in the above reference paper. 
diff --git a/exhibits/time-integrated-stdp/analyze_models.sh b/exhibits/time-integrated-stdp/analyze_models.sh new file mode 100755 index 0000000..5d4db3b --- /dev/null +++ b/exhibits/time-integrated-stdp/analyze_models.sh @@ -0,0 +1,74 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "trstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'evstdp', 'trstdp', 'tistdp' models supported!" + exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +SEEDS=(1234 77 811) + +PARAM_SUBDIR="/custom" +DISABLE_ADAPT_AT_EVAL=False ## set to true to turn off eval-time adaptive thresholds +MAKE_CLUSTER_PLOT=False #True +REBIND_LABELS=0 ## rebind labels to train model? + +N_SAMPLES=50000 +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" +DEV_X="../../data/mnist/testX.npy" # validX.npy +DEV_Y="../../data/mnist/testY.npy" # validY.npy +EXTRACT_TRAINING_SPIKES=0 # set to 1 if you want to extract training set codes + +for seed in "${SEEDS[@]}" +do + EXP_DIR="exp_$MODEL""_$seed/" + echo " > Running Simulation/Model: $EXP_DIR" + + CODEBOOK=$EXP_DIR"training_codes.npy" + TEST_CODEBOOK=$EXP_DIR"test_codes.npy" + PLOT_FNAME=$EXP_DIR"codes.jpg" + + if [[ $REBIND_LABELS == 1 ]]; then + CUDA_VISIBLE_DEVICES=$GPU_ID python bind_labels.py --dataX=$DATA_X --dataY=$DATA_Y \ + --model_type=$MODEL \ + --model_dir=$EXP_DIR$MODEL \ + --n_samples=$N_SAMPLES \ + --exp_dir=$EXP_DIR \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR + fi + + ## eval model +# CUDA_VISIBLE_DEVICES=$GPU_ID python eval.py --dataX=$DEV_X --dataY=$DEV_Y \ +# --model_type=$MODEL --model_dir=$EXP_DIR$MODEL \ +# --label_fname=$EXP_DIR"binded_labels.npy" \ +# --exp_dir=$EXP_DIR \ +# --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ +# --param_subdir=$PARAM_SUBDIR \ +# --make_cluster_plot=$MAKE_CLUSTER_PLOT + ## call codebook extraction processes + if [[ 
$EXTRACT_TRAINING_SPIKES == 1 ]]; then + CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DATA_X \ + --n_samples=$N_SAMPLES \ + --codebook_fname=$CODEBOOK \ + --model_type=$MODEL \ + --model_fname=$EXP_DIR$MODEL \ + --disable_adaptation=False \ + --param_subdir=$PARAM_SUBDIR + fi + CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DEV_X \ + --codebook_fname=$TEST_CODEBOOK \ + --model_type=$MODEL \ + --model_fname=$EXP_DIR$MODEL \ + --disable_adaptation=False \ + --param_subdir=$PARAM_SUBDIR + ## visualize latent codes + CUDA_VISIBLE_DEVICES=$GPU_ID python viz_codes.py --plot_fname=$PLOT_FNAME \ + --codes_fname=$TEST_CODEBOOK \ + --labels_fname=$DEV_Y +done diff --git a/exhibits/time-integrated-stdp/bind.sh b/exhibits/time-integrated-stdp/bind.sh new file mode 100755 index 0000000..d08f116 --- /dev/null +++ b/exhibits/time-integrated-stdp/bind.sh @@ -0,0 +1,23 @@ +#!/bin/bash +GPU_ID=1 #0 + +N_SAMPLES=10000 +DISABLE_ADAPT_AT_EVAL=False + +EXP_DIR="exp_trstdp/" +MODEL="trstdp" +#EXP_DIR="exp_evstdp/" +#MODEL="evstdp" +DEV_X="../../data/mnist/trainX.npy" # validX.npy +DEV_Y="../../data/mnist/trainY.npy" # validY.npy +PARAM_SUBDIR="/custom_snapshot2" +#PARAM_SUBDIR="/custom" + +## eval model +CUDA_VISIBLE_DEVICES=$GPU_ID python bind_labels.py --dataX=$DEV_X --dataY=$DEV_Y \ + --model_type=$MODEL \ + --model_dir=$EXP_DIR$MODEL \ + --n_samples=$N_SAMPLES \ + --exp_dir=$EXP_DIR \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR diff --git a/exhibits/time-integrated-stdp/bind_labels.py b/exhibits/time-integrated-stdp/bind_labels.py new file mode 100755 index 0000000..2e7a339 --- /dev/null +++ b/exhibits/time-integrated-stdp/bind_labels.py @@ -0,0 +1,137 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + 
+################################################################################ +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", + "model_dir=", "model_type=", + "exp_dir=", "n_samples=", + "disable_adaptation=", + "param_subdir="]) + +model_case = "snn_case1" +disable_adaptation = True +exp_dir = "exp/" +param_subdir = "/custom" +model_type = "tistdp" +model_dir = "exp/tistdp" +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +n_samples = 10000 +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ('--model_dir'): + model_dir = arg.strip() + elif opt in ('--model_type'): + model_type = arg.strip() + elif opt in ('--exp_dir'): + exp_dir = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ('--n_samples'): + n_samples = int(arg.strip()) + elif opt in ('--disable_adaptation'): + disable_adaptation = (arg.strip().lower() == "true") + print(" > Disable short-term adaptation? ", disable_adaptation) + +if model_case == "snn_case1": + print(" >> Setting up Case 1 model!") + from snn_case1 import load_from_disk, get_nodes +elif model_case == "snn_case2": + print(" >> Setting up Case 2 model!") + from snn_case2 import load_from_disk, get_nodes +else: + print("Error: No other model case studies supported! 
(", model_case, " invalid)") + exit() + +print(">> X: {} Y: {}".format(dataX, dataY)) + +dkey = random.PRNGKey(1234) +dkey, *subkeys = random.split(dkey, 3) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if 0 < n_samples < _X.shape[0]: + ptrs = random.permutation(subkeys[0], _X.shape[0])[0:n_samples] + _X = _X[ptrs, :] + _Y = _Y[ptrs, :] + # _X = _X[0:n_samples, :] + # _Y = _Y[0:n_samples, :] + print("-> Binding {} first randomly selected samples to model".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 1000 #10000 +mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (28, 28) +in_dim = patch_shape[0] * patch_shape[1] + +T = 250 #300 ## num time steps to simulate (stimulus presentation window length) +dt = 1. ## integration time constant + +################################################################################ +print("--- Loading Model ---") + +## Load in model +model = load_from_disk(model_dir, param_dir=param_subdir, + disable_adaptation=disable_adaptation) +nodes, node_map = get_nodes(model) + +################################################################################ +print("--- Starting Binding Process ---") + +print("------------------------------------") +model.showStats(-1) + +## enter main adaptation loop over data patterns +class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units)) +num_bound = 0 +n_total_samp_seen = 0 +tstart = time.time() +n_samps_seen = 0 +for j in range(n_batches): + idx = j + Xb = _X[idx: idx + mb_size, :] + Yb = _Y[idx: idx + mb_size, :] + + model.reset() + model.clamp(Xb) + spikes1, spikes2 = model.infer( + jnp.array([[dt * k, dt] for k in range(T)])) + ## bind output spike train(s) + responses = Yb.T * jnp.sum(spikes2, axis=0) + class_responses = class_responses + responses + num_bound += 1 + + n_samps_seen += Xb.shape[0] + n_total_samp_seen += Xb.shape[0] 
+ print("\r Binding {} images...".format(n_samps_seen), end="") +tend = time.time() +print() +sim_time = tend - tstart +sim_time_hr = (sim_time/3600.0) # convert time to hours +print(" -> Binding.Time = {} s".format(sim_time_hr)) +print("------------------------------------") + +## compute max-frequency (~firing rate) spike responses +class_responses = jnp.argmax(class_responses, axis=0, keepdims=True) +print("---- Max Class Responses ----") +print(class_responses) +print(class_responses.shape) +bind_fname = "{}binded_labels.npy".format(exp_dir) +print(" >> Saving label bindings to: ", bind_fname) +jnp.save(bind_fname, class_responses) + + + + diff --git a/exhibits/time-integrated-stdp/custom/__init__.py b/exhibits/time-integrated-stdp/custom/__init__.py new file mode 100755 index 0000000..0129e54 --- /dev/null +++ b/exhibits/time-integrated-stdp/custom/__init__.py @@ -0,0 +1 @@ +from .ti_stdp_synapse import TI_STDP_Synapse diff --git a/exhibits/time-integrated-stdp/custom/ti_stdp_synapse.py b/exhibits/time-integrated-stdp/custom/ti_stdp_synapse.py new file mode 100755 index 0000000..439db30 --- /dev/null +++ b/exhibits/time-integrated-stdp/custom/ti_stdp_synapse.py @@ -0,0 +1,98 @@ +import time + +from ngclearn import Component, Compartment, resolver +from ngclearn.components import DenseSynapse +from ngclearn.utils.model_utils import normalize_matrix + +from jax import numpy as jnp, random +import os.path + +class TI_STDP_Synapse(DenseSynapse): + def __init__(self, name, shape, alpha=0.0075, beta=0.5, pre_decay=0.5, + resist_scale=1., p_conn=1., weight_init=None, **kwargs): + super().__init__(name=name, shape=shape, resist_scale=resist_scale, + p_conn=p_conn, weight_init=weight_init, **kwargs) + + #Params + self.batch_size = 1 + self.alpha = alpha + self.beta = beta + self.pre_decay = pre_decay + self.Aplus = 1 + self.Aminus = 1 + + #Compartments + self.pre = Compartment(None) + self.post = Compartment(None) + + self.reset() + + @staticmethod + def _norm(weights, 
norm_scale): + return normalize_matrix(weights, wnorm=100, scale=norm_scale) + + @resolver(_norm) + def norm(self, weights): + self.weights.set(weights) + + + @staticmethod + def _reset(batch_size, shape): + return jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])), \ + jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])) + + @resolver(_reset) + def reset(self, inputs, outputs, pre, post): + self.inputs.set(inputs) + self.outputs.set(outputs) + self.pre.set(pre) + self.post.set(post) + + @staticmethod + def _evolve(pre, post, weights, + shape, alpha, beta, pre_decay, Aplus, Aminus, + t, dt): + pre_size = shape[0] + post_size = shape[1] + + pre_synaptic_binary_mask = jnp.where(pre > 0, 1., 0.) + post_synaptic_binary_mask = jnp.where(post > 0, 1., 0.) + + broadcast_pre_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(pre_synaptic_binary_mask, (pre_size, 1)) + broadcast_post_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(post_synaptic_binary_mask, (1, post_size)) + + broadcast_pre_synaptic_time = jnp.zeros(shape) + jnp.reshape(pre, (pre_size, 1)) + broadcast_post_synaptic_time = jnp.zeros(shape) + jnp.reshape(post, (1, post_size)) + + no_pre_synaptic_spike_mask = broadcast_post_synaptic_binary_mask * (1. 
- broadcast_pre_synaptic_binary_mask) + no_pre_synaptic_weight_update = (-alpha) * (pre_decay / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + # no_pre_synaptic_weight_update = no_pre_synaptic_weight_update + + # Both have spiked + both_spike_mask = broadcast_post_synaptic_binary_mask * broadcast_pre_synaptic_binary_mask + both_spike_update = (-alpha / (broadcast_pre_synaptic_time - broadcast_post_synaptic_time - (0.5 * dt))) * \ + (beta / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + + masked_no_pre_synaptic_weight_update = no_pre_synaptic_spike_mask * no_pre_synaptic_weight_update + masked_both_spike_update = both_spike_mask * both_spike_update + + plasticity = jnp.where(masked_both_spike_update > 0, + Aplus * masked_both_spike_update, + Aminus * masked_both_spike_update) + + decay = masked_no_pre_synaptic_weight_update + + plasticity = plasticity * (1 - weights) + + decay = decay * weights + + update = plasticity + decay + _W = weights + update + + return jnp.clip(_W, 0., 1.) 
+ + @resolver(_evolve) + def evolve(self, weights): + self.weights.set(weights) diff --git a/exhibits/time-integrated-stdp/eval.py b/exhibits/time-integrated-stdp/eval.py new file mode 100755 index 0000000..e4abb5c --- /dev/null +++ b/exhibits/time-integrated-stdp/eval.py @@ -0,0 +1,158 @@ +from jax import numpy as jnp, random, nn, jit +import numpy as np, time +import sys, getopt as gopt, optparse +## bring in ngc-learn analysis tools +from ngclearn.utils.viz.raster import create_raster_plot +import ngclearn.utils.metric_utils as metrics + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +def plot_clusters(W1, W2, figdim=50): ## generalized dual-layer plotting code + n_input = W1.weights.value.shape[0] ## input layer + n_hid1 = W1.weights.value.shape[1] ## layer size 1 + n_hid2 = W2.weights.value.shape[1] ## layer size 2 + ndim0 = int(jnp.sqrt(n_input)) ## sqrt(layer size 0) + ndim1 = int(jnp.sqrt(n_hid1)) ## sqrt(layer size 1) + ndim2 = int(jnp.sqrt(n_hid2)) ## sqrt(layer size 2) + plt.figure(figsize=(figdim, figdim)) + plt.subplots_adjust(hspace=0.1, wspace=0.1) + + for q in range(W2.weights.value.shape[1]): + masked = W1.weights.value * W2.weights.value[:, q] + + dim = ((ndim0 * ndim1) + (ndim1 - 1)) #(28 * 10) + (10 - 1) + + full = jnp.ones((dim, dim)) * jnp.amax(masked) + + for k in range(n_hid1): + r = k // ndim1 #k // 10 #sqrt(hidden layer size) + c = k % ndim1 # k % 10 + + full = full.at[(r * (ndim0 + 1)):(r + 1) * ndim0 + r, + (c * (ndim0 + 1)):(c + 1) * ndim0 + c].set( + jnp.reshape(masked[:, k], (ndim0, ndim0))) + + plt.subplot(ndim2, ndim2, q + 1) # 5 = sqrt(output layer size) + plt.imshow(full, cmap=plt.cm.bone, interpolation='nearest') + plt.axis("off") + + plt.subplots_adjust(top=0.9) + plt.savefig("{}clusters.jpg".format(exp_dir), bbox_inches='tight') + plt.clf() + plt.close() + +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", 
+ "model_dir=", "model_type=", + "label_fname=", "exp_dir=", + "param_subdir=", + "disable_adaptation=", + "make_cluster_plot="]) + +model_case = "snn_case1" +exp_dir = "exp/" +label_fname = "exp/labs.npy" +model_type = "tistdp" +model_dir = "exp/tistdp" +param_subdir = "/custom" +disable_adaptation = True +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +make_cluster_plot = True +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ('--model_dir'): + model_dir = arg.strip() + elif opt in ('--model_type'): + model_type = arg.strip() + elif opt in ('--label_fname'): + label_fname = arg.strip() + elif opt in ('--exp_dir'): + exp_dir = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ('--disable_adaptation'): + disable_adaptation = (arg.strip().lower() == "true") + print(" > Disable short-term adaptation? ", disable_adaptation) + elif opt in ('--make_cluster_plot'): + make_cluster_plot = (arg.strip().lower() == "true") + print(" > Make cluster plot? ", make_cluster_plot) + +if model_case == "snn_case1": + print(" >> Setting up Case 1 model!") + from snn_case1 import load_from_disk, get_nodes +elif model_case == "snn_case2": + print(" >> Setting up Case 2 model!") + from snn_case2 import load_from_disk, get_nodes +else: + print("Error: No other model case studies supported! (", model_case, " invalid)") + exit() + +print(">> X: {} Y: {}".format(dataX, dataY)) + +T = 250 # 300 +dt = 1. 
+ +## load dataset +batch_size = 1 #100 +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +n_batches = int(_X.shape[0]/batch_size) +patch_shape = (28, 28) + +dkey = random.PRNGKey(time.time_ns()) +model = load_from_disk(model_dir, param_dir=param_subdir, + disable_adaptation=disable_adaptation) +nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i") +W1, W1ie, W1ei, z0, z1e, z1i, W2, W2ie, W2ei, z2e, z2i = nodes + +if make_cluster_plot: + ## plot clusters formed by 2nd layer of model's spikes + print(" >> Creating model cluster plot...") + plot_clusters(W1, W2) + +## extract label bindings +bindings = jnp.load(label_fname) + +acc = 0. +Ns = 0. +yMu = [] +for i in range(n_batches): + Xb = _X[i * batch_size:(i+1) * batch_size, :] + Yb = _Y[i * batch_size:(i+1) * batch_size, :] + Ns += Xb.shape[0] + model.reset() + model.clamp(Xb) + spikes1, spikes2 = model.infer( + jnp.array([[dt * k, dt] for k in range(T)])) + winner = jnp.argmax(jnp.sum(spikes2, axis=0)) + yHat = nn.one_hot(bindings[:, winner], num_classes=Yb.shape[1]) + yMu.append(yHat) + acc = metrics.measure_ACC(yHat, Yb) + acc + print("\r Acc = {} (over {} samples)".format(acc/Ns, i+1), end="") +print() +print("===============================================") +yMu = jnp.concatenate(yMu, axis=0) +conf_matrix, precision, recall, misses, acc, adj_acc = metrics.analyze_scores(yMu, _Y) +print(conf_matrix) +print("---") +print(" >> Number of Misses = {}".format(misses)) +print(" >> Acc = {} Precision = {} Recall = {}".format(acc, precision, recall)) +msg = "{}".format(conf_matrix) +jnp.save("{}confusion.npy".format(exp_dir), conf_matrix) +msg = ("{}\n---\n" + "Number of Misses = {}\n" + "Acc = {} Adjusted-Acc = {}\n" + "Precision = {}\nRecall = {}\n").format(msg, misses, acc, adj_acc, precision, recall) +fd = open("{}scores.txt".format(exp_dir), "w") +fd.write(msg) +fd.close() + diff --git a/exhibits/time-integrated-stdp/eval.sh b/exhibits/time-integrated-stdp/eval.sh 
new file mode 100755 index 0000000..0fd21f7 --- /dev/null +++ b/exhibits/time-integrated-stdp/eval.sh @@ -0,0 +1,23 @@ +#!/bin/bash +GPU_ID=1 #0 + +EXP_DIR="exp_trstdp/" +MODEL="trstdp" +#EXP_DIR="exp_evstdp/" +#MODEL="evstdp" +PARAM_SUBDIR="/custom_snapshot2" +#PARAM_SUBDIR="/custom" +DISABLE_ADAPT_AT_EVAL=False ## set to true to turn off eval-time adaptive thresholds +MAKE_CLUSTER_PLOT=False + +DEV_X="../../data/mnist/testX.npy" # validX.npy +DEV_Y="../../data/mnist/testY.npy" # validY.npy + +# eval model +CUDA_VISIBLE_DEVICES=$GPU_ID python eval.py --dataX=$DEV_X --dataY=$DEV_Y \ + --model_type=$MODEL --model_dir=$EXP_DIR$MODEL \ + --label_fname=$EXP_DIR"binded_labels.npy" \ + --exp_dir=$EXP_DIR \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR \ + --make_cluster_plot=$MAKE_CLUSTER_PLOT diff --git a/exhibits/time-integrated-stdp/extract_codes.py b/exhibits/time-integrated-stdp/extract_codes.py new file mode 100755 index 0000000..43f3e42 --- /dev/null +++ b/exhibits/time-integrated-stdp/extract_codes.py @@ -0,0 +1,86 @@ +from jax import numpy as jnp, random, nn, jit +import numpy as np, time +import sys, getopt as gopt, optparse +## bring in ngc-learn analysis tools +from ngclearn.utils.viz.raster import create_raster_plot +from ngclearn.utils.metric_utils import measure_ACC + +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", + "codebook_fname=", + "model_fname=", + "n_samples=", "model_type=", + "param_subdir=", + "disable_adaptation="]) + +model_case = "snn_case1" +n_samples = -1 +model_type = "tistdp" +model_fname = "exp/tistdp" +param_subdir = "/custom" +disable_adaptation = True +codebook_fname = "codes.npy" +dataX = "../../data/mnist/trainX.npy" +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--codebook_fname"): + codebook_fname = arg.strip() + elif opt in ("--model_fname"): + model_fname = arg.strip() + elif opt in ("--n_samples"): + 
n_samples = int(arg.strip()) + elif opt in ('--model_type'): + model_type = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ('--disable_adaptation'): + disable_adaptation = (arg.strip().lower() == "true") + print(" > Disable short-term adaptation? ", disable_adaptation) + +if model_case == "snn_case1": + print(" >> Setting up Case 1 model!") + from snn_case1 import load_from_disk, get_nodes +elif model_case == "snn_case2": + print(" >> Setting up Case 2 model!") + from snn_case2 import load_from_disk, get_nodes +else: + print("Error: No other model case studies supported! (", model_case, " invalid)") + exit() + +print(">> X: {}".format(dataX)) + +## load dataset +batch_size = 1 #100 +_X = jnp.load(dataX) +if 0 < n_samples < _X.shape[0]: + _X = _X[0:n_samples, :] +n_batches = int(_X.shape[0]/batch_size) +patch_shape = (28, 28) + +dkey = random.PRNGKey(time.time_ns()) +model = load_from_disk(model_directory=model_fname, param_dir=param_subdir, + disable_adaptation=disable_adaptation) + +T = 250 # 300 +dt = 1. + +codes = [] ## get latent (spike) codes +acc = 0. +Ns = 0. 
+for i in range(n_batches): + Xb = _X[i * batch_size:(i + 1) * batch_size, :] + Ns += Xb.shape[0] + model.reset() + model.clamp(Xb) + spikes1, spikes2 = model.infer( + jnp.array([[dt * k, dt] for k in range(T)])) + counts = jnp.sum(spikes2, axis=0) ## get counts + codes.append(counts) + print("\r > Processed ({} samples)".format(Ns), end="") +print() + +print(" >> Saving code-book to disk: {}".format(codebook_fname)) +codes = jnp.concatenate(codes, axis=0) +print(" >> Code.shape = ", codes.shape) +jnp.save(codebook_fname, codes) diff --git a/exhibits/time-integrated-stdp/fig/tistdp_snn.jpg b/exhibits/time-integrated-stdp/fig/tistdp_snn.jpg new file mode 100755 index 0000000..d0db7b2 Binary files /dev/null and b/exhibits/time-integrated-stdp/fig/tistdp_snn.jpg differ diff --git a/exhibits/time-integrated-stdp/harvest_latents.sh b/exhibits/time-integrated-stdp/harvest_latents.sh new file mode 100755 index 0000000..b2e4b8f --- /dev/null +++ b/exhibits/time-integrated-stdp/harvest_latents.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp stdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "trstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'evstdp', 'trstdp', 'tistdp' models supported!" 
+ exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +SEEDS=(1234) # 77 811) + +PARAM_SUBDIR="/custom" +DISABLE_ADAPT_AT_EVAL=False ## set to true to turn off eval-time adaptive thresholds + +N_SAMPLES=50000 +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" + +for seed in "${SEEDS[@]}" +do + EXP_DIR="final_case1_results/exp_$MODEL""_$seed/" + echo " > Running Simulation/Model: $EXP_DIR" + + CODEBOOK=$EXP_DIR"training_codes.npy" + + CUDA_VISIBLE_DEVICES=$GPU_ID python extract_codes.py --dataX=$DATA_X \ + --n_samples=$N_SAMPLES \ + --codebook_fname=$CODEBOOK \ + --model_type=$MODEL \ + --model_fname=$EXP_DIR$MODEL \ + --disable_adaptation=$DISABLE_ADAPT_AT_EVAL \ + --param_subdir=$PARAM_SUBDIR +done diff --git a/exhibits/time-integrated-stdp/json_files/config.json b/exhibits/time-integrated-stdp/json_files/config.json new file mode 100644 index 0000000..9cc4515 --- /dev/null +++ b/exhibits/time-integrated-stdp/json_files/config.json @@ -0,0 +1,3 @@ +{ + "logging" : {"logging_level" : "ERROR"} +} diff --git a/exhibits/time-integrated-stdp/json_files/modules.json b/exhibits/time-integrated-stdp/json_files/modules.json new file mode 100644 index 0000000..1d16af0 --- /dev/null +++ b/exhibits/time-integrated-stdp/json_files/modules.json @@ -0,0 +1,16 @@ +[ + {"absolute_path": "ngclearn.components", + "attributes": [ + {"name": "VarTrace"}, + {"name": "PoissonCell"}, + {"name": "LIFCell"}, + {"name": "TraceSTDPSynapse"}, + {"name": "StaticSynapse"}] + }, + {"absolute_path": "ngcsimlib.operations", + "attributes": [ + {"name": "overwrite"}, + {"name": "summation"}] + } + +] diff --git a/exhibits/time-integrated-stdp/patch_model/assemble_patterns.py b/exhibits/time-integrated-stdp/patch_model/assemble_patterns.py new file mode 100755 index 0000000..bf71422 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/assemble_patterns.py @@ -0,0 +1,159 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time 
+from ngclearn.utils.io_utils import makedir +from custom.patch_utils import Create_Patches +from ngclearn.utils.viz.synapse_plot import visualize + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +################################################################################ +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", + "verbosity=", "exp_dir=", "model_dir=", + "model_type=", "param_subdir=", "seed="]) + +seed = 1234 +exp_dir = "exp/" +model_type = "tistdp" +model_dir = "" +param_subdir = "/custom" +n_samples = -1 +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ("--verbosity"): + verbosity = int(arg.strip()) + elif opt in ("--n_samples"): + n_samples = int(arg.strip()) + elif opt in ("--exp_dir"): + exp_dir = arg.strip() + elif opt in ('--param_subdir'): + param_subdir = arg.strip() + elif opt in ("--model_type"): + model_type = arg.strip() + elif opt in ("--model_dir"): + model_dir = arg.strip() + elif opt in ("--seed"): + seed = int(arg.strip()) + +if model_type == "tistdp": + print(" >> Setting up TI-STDP Patch-Model builder!") + from patch_tistdp_snn import load_from_disk, get_nodes +else: + print(" >> Model type ", model_type, " not supported!") + +print(">> X: {} Y: {}".format(dataX, dataY)) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if n_samples > 0: + _X = _X[0:n_samples, :] + _Y = _Y[0:n_samples, :] + print("-> Fitting model to only {} samples".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 100 # 1000 #10000 
+mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (28, 28) ## same as image_shape +py = px = 28 +img_shape = (py, px) +in_dim = patch_shape[0] * patch_shape[1] + +x = jnp.ones(patch_shape) +pty = ptx = 7 +in_patchShape = (pty, ptx) #(14, 14) #(7, 7) +in_patchSize = pty * ptx +stride_shape = (2, 2) #(0, 0) +patcher = Create_Patches(x, in_patchShape, stride_shape) +x_pat = patcher.create_patches(add_frame=False, center=False) +n_in_patches = patcher.nw_patches * patcher.nh_patches + +z1_patchSize = 4 * 4 +z1_patchCnt = 16 + +T = 250 #300 ## num time steps to simulate (stimulus presentation window length) +dt = 1. ## integration time constant + +################################################################################ +print("--- Building Model ---") +## Create model +makedir(exp_dir) +## get model from imported header file +disable_adaptation = False # True +model = load_from_disk(model_dir, param_dir=param_subdir, + disable_adaptation=disable_adaptation) +nodes, node_map = get_nodes(model) +################################################################################ +print("--- Starting Simulation ---") + +sim_start_time = time.time() ## start time profiling + +print("------------------------------------") +model.showStats(-1) + +## sampling concept +K = 300 #100 #400 #200 #100 +W2 = node_map.get("W2").weights.value +W1 = node_map.get("W1").weights.value + +print(">> Visualizing top-level filters!") +n_neurons = int(jnp.sqrt(W2.shape[0])) +visualize([W2], [(n_neurons, n_neurons)], "{}_toplevel_filters".format(exp_dir)) + +print(" >> Building Level 1 Block Filter Tensor...") +W1_images = [] +for i_ in range(W1.shape[1]): + img_i = [] ## combine along row axis (in cache) + strip_i = [] ## combine along column axis + for j_ in range(z1_patchCnt): + _filter = W1[in_patchSize * j_:(j_ + 1) * in_patchSize, i_:i_ + 1] + _filter = jnp.reshape(_filter, (ptx, pty)) # reshape to patch grid + strip_i.append(_filter) + if len(strip_i) >= int(py/pty): + 
img_i.append(jnp.concatenate(strip_i, axis=1)) + strip_i = [] + img_i = jnp.concatenate(img_i, axis=0) + # plt.imshow(img_i) + # plt.savefig("exp_evstdp_1234/test/filter{}.jpg".format(i_)) + print("\r {} filters built...".format(i_),end="") + W1_images.append(img_i) + img_i = [] ## clear cache +print() + +print(" >> Building super-imposed sampled images!") +samples = [] +for i in range(W2.shape[1]): + W2_i = W2[:, i] + indices = jnp.argsort(W2_i, descending=True)[0:K] + coefficients = W2_i[indices] + #indices = jnp.where(W2_i > thr) + #Z = jnp.amax(W2_i) + #indices = jnp.flip(jnp.argsort(W2_i))[0:K] + #coefficients = W2_i[indices]#/Z + + xSample = 0. + ptr = 0 + for idx in indices: # j in range(indices.shape[0]): + coeff_i = coefficients[ptr] + Ki = W1_images[idx] * coeff_i + xSample += Ki + ptr += 1 + xSample = xSample.reshape(1, -1).T + samples.append(xSample) + print("\r Crafted {} samples...".format(len(samples)), end="") +print() +samples = jnp.concatenate(samples, axis=1) + +visualize([samples], [(28, 28)], "{}{}".format(exp_dir, "samples")) + + diff --git a/exhibits/time-integrated-stdp/patch_model/custom/LCNSynapse.py b/exhibits/time-integrated-stdp/patch_model/custom/LCNSynapse.py new file mode 100755 index 0000000..6e7f332 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/LCNSynapse.py @@ -0,0 +1,114 @@ +from jax import random, numpy as jnp, jit +from ngclearn import resolver, Component, Compartment +from ngclearn.components.jaxComponent import JaxComponent +from ngclearn.utils import tensorstats +from ngclearn.utils.weight_distribution import initialize_params +from ngcsimlib.logger import info + +## init function for enforcing locally-connected structure of synaptic cable +def init_block_synapse(weight_init, num_models, patch_shape, key): + weight_shape = (num_models * patch_shape[0], num_models * patch_shape[1]) + weights = initialize_params(key[2], {"dist": "constant", "value": 0.}, + weight_shape, use_numpy=True) + for i in 
range(num_models): + weights[patch_shape[0] * i:patch_shape[0] * (i + 1), + patch_shape[1] * i:patch_shape[1] * (i + 1)] = ( + initialize_params(key[1], init_kernel=weight_init, + shape=patch_shape, use_numpy=True)) + return weights + + +class LCNSynapse(JaxComponent): ## locally-connected synaptic cable + def __init__(self, name, shape, model_shape=(1,1), weight_init=None, bias_init=None, + resist_scale=1., p_conn=1., batch_size=1, **kwargs): + super().__init__(name, **kwargs) + + self.batch_size = batch_size + self.weight_init = weight_init + self.bias_init = bias_init + self.n_models = model_shape[0] + self.model_patches = model_shape[1] + + ## Synapse meta-parameters + self.sub_shape = (shape[0], shape[1]*self.model_patches) ## shape of synaptic efficacy matrix # = (in_dim, hid_dim) = (d3, d2) + self.shape = (self.sub_shape[0] * self.n_models, self.sub_shape[1] * self.n_models) + self.Rscale = resist_scale ## post-transformation scale factor + + ## Set up synaptic weight values + tmp_key, *subkeys = random.split(self.key.value, 4) + if self.weight_init is None: + info(self.name, "is using default weight initializer!") + self.weight_init = {"dist": "uniform", "amin": 0.025, "amax": 0.8} + + weights = init_block_synapse(self.weight_init, self.n_models, self.sub_shape, subkeys) + + if 0. 
< p_conn < 1.: ## only non-zero and <1 probs allowed + mask = random.bernoulli(subkeys[1], p=p_conn, shape=self.shape) + weights = weights * mask ## sparsify matrix + + self.batch_size = 1 + ## Compartment setup + preVals = jnp.zeros((self.batch_size, self.shape[0])) + postVals = jnp.zeros((self.batch_size, self.shape[1])) + self.inputs = Compartment(preVals) + self.outputs = Compartment(postVals) + self.weights = Compartment(weights) + ## Set up (optional) bias values + if self.bias_init is None: + info(self.name, "is using default bias value of zero (no bias " + "kernel provided)!") + self.biases = Compartment( + initialize_params(subkeys[2], bias_init, (1, self.shape[1])) + if bias_init else 0.0 + ) + + @staticmethod + def _advance_state(Rscale, inputs, weights, biases): + outputs = (jnp.matmul(inputs, weights)) * Rscale + biases + return outputs + + @resolver(_advance_state) + def advance_state(self, outputs): + self.outputs.set(outputs) + + @staticmethod + def _reset(batch_size, shape): + preVals = jnp.zeros((batch_size, shape[0])) + postVals = jnp.zeros((batch_size, shape[1])) + inputs = preVals + outputs = postVals + return inputs, outputs + + @resolver(_reset) + def reset(self, inputs, outputs): + self.inputs.set(inputs) + self.outputs.set(outputs) + + def save(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + if self.bias_init != None: + jnp.savez(file_name, weights=self.weights.value, + biases=self.biases.value) + else: + jnp.savez(file_name, weights=self.weights.value) + + def load(self, directory, **kwargs): + file_name = directory + "/" + self.name + ".npz" + data = jnp.load(file_name) + self.weights.set(data['weights']) + if "biases" in data.keys(): + self.biases.set(data['biases']) + + def __repr__(self): + comps = [varname for varname in dir(self) if Compartment.is_compartment(getattr(self, varname))] + maxlen = max(len(c) for c in comps) + 5 + lines = f"[{self.__class__.__name__}] PATH: {self.name}\n" + for c in 
comps: + stats = tensorstats(getattr(self, c).value) + if stats is not None: + line = [f"{k}: {v}" for k, v in stats.items()] + line = ", ".join(line) + else: + line = "None" + lines += f" {f'({c})'.ljust(maxlen)}{line}\n" + return lines diff --git a/exhibits/time-integrated-stdp/patch_model/custom/__init__.py b/exhibits/time-integrated-stdp/patch_model/custom/__init__.py new file mode 100755 index 0000000..45a5b57 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/__init__.py @@ -0,0 +1,3 @@ +from .LCNSynapse import LCNSynapse +from .ti_STDP_Synapse import TI_STDP_Synapse +from .ti_STDP_LCNSynapse import TI_STDP_LCNSynapse diff --git a/exhibits/time-integrated-stdp/patch_model/custom/patch_utils.py b/exhibits/time-integrated-stdp/patch_model/custom/patch_utils.py new file mode 100755 index 0000000..4e79f9b --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/patch_utils.py @@ -0,0 +1,95 @@ +""" +Image/tensor patching utility routines. +""" +import numpy as np +from jax import numpy as jnp + +class Create_Patches: + """ + This function will create small patches out of the image based on the provided attributes. + + Args: + img: jax array of size (H, W) + patched: (height_patch, width_patch) + overlap: (height_overlap, width_overlap) + + add_frame: increases the img size by (height_patch - height_overlap, width_patch - width_overlap) + create_patches: creates small patches out of the image based on the provided attributes. 
+ + Returns: + jnp.array: Array containing the patches + shape: (num_patches, patch_height, patch_width) + + """ + + def __init__(self, img, patch_shape, overlap_shape): + self.img = img + self.height_patch, self.width_patch = patch_shape + self.height_over, self.width_over = overlap_shape + + self.height, self.width = self.img.shape + + if self.height_over > 0: + self.nw_patches = (self.width + self.width_over) // (self.width_patch - self.width_over) - 2 + else: + self.nw_patches = self.width // self.width_patch + if self.width_over > 0: + self.nh_patches = (self.height + self.height_over) // (self.height_patch - self.height_over) - 2 + else: + self.nh_patches = self.height // self.height_patch + # print('nw_patches', 'nh_patches') + # print(self.nw_patches, self.nh_patches) + # print("...........") + + def _add_frame(self): + """ + This function will add zero frames (increase the dimension) to the image + + Returns: + image with increased size (x.shape[0], x.shape[1]) -> (x.shape[0] + (height_patch - height_overlap), + x.shape[1] + (width_patch - width_overlap)) + """ + self.img = np.hstack((jnp.zeros((self.img.shape[0], (self.height_patch - self.height_over))), + self.img, + jnp.zeros((self.img.shape[0], (self.height_patch - self.height_over))))) + self.img = np.vstack((jnp.zeros(((self.width_patch - self.width_over), self.img.shape[1])), + self.img, + jnp.zeros(((self.width_patch - self.width_over), self.img.shape[1])))) + + self.height, self.width = self.img.shape + + self.nw_patches = (self.width + self.width_over) // (self.width_patch - self.width_over) - 2 + self.nh_patches = (self.height + self.height_over) // (self.height_patch - self.height_over) - 2 + + + + def create_patches(self, add_frame=False, center=True): + """ + This function will create small patches out of the image based on the provided attributes. 
+ + Keyword Args: + add_frame: If true the function will add zero frames (increase the dimension) to the image + + Returns: + jnp.array: Array containing the patches + shape: (num_patches, patch_height, patch_width) + """ + + if add_frame == True: + self._add_frame() + + if center == True: + mu = np.mean(self.img, axis=0, keepdims=True) + self.img = self.img - mu + + result = [] + for nh_ in range(self.nh_patches): + for nw_ in range(self.nw_patches): + img_ = self.img[(self.height_patch - self.height_over) * nh_: nh_ * ( + self.height_patch - self.height_over) + self.height_patch + , (self.width_patch - self.width_over) * nw_: nw_ * ( + self.width_patch - self.width_over) + self.width_patch] + + if img_.shape == (self.height_patch, self.width_patch): + result.append(img_) + return jnp.array(result) \ No newline at end of file diff --git a/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_LCNSynapse.py b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_LCNSynapse.py new file mode 100755 index 0000000..d089696 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_LCNSynapse.py @@ -0,0 +1,111 @@ +from ngclearn import Component, Compartment, resolver +from .LCNSynapse import LCNSynapse +from ngclearn.utils.model_utils import normalize_matrix + +from jax import numpy as jnp, random +import os.path + + +class TI_STDP_LCNSynapse(LCNSynapse): + def __init__(self, name, shape, model_shape=(1, 1), alpha=0.0075, + beta=0.5, pre_decay=0.5, resist_scale=1., p_conn=1., + weight_init=None, Aplus=1, Aminus=1, **kwargs): + super().__init__(name, shape=shape, model_shape=model_shape, + weight_init=weight_init, bias_init=None, + resist_scale=resist_scale, p_conn=p_conn, **kwargs) + + # Params + self.batch_size = 1 + self.alpha = alpha + self.beta = beta + self.pre_decay = pre_decay + self.Aplus = Aplus + self.Aminus = Aminus + + # Compartments + self.pre = Compartment(None) + self.post = Compartment(None) + + self.reset() + + @staticmethod + def 
_norm(weights, norm_scale): + return normalize_matrix(weights, wnorm=100, scale=norm_scale) + + @resolver(_norm) + def norm(self, weights): + self.weights.set(weights) + + @staticmethod + def _reset(batch_size, shape): + return jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])), \ + jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])) + + @resolver(_reset) + def reset(self, inputs, outputs, pre, post): + self.inputs.set(inputs) + self.outputs.set(outputs) + self.pre.set(pre) + self.post.set(post) + + @staticmethod + def _evolve(pre, post, weights, + shape, alpha, beta, pre_decay, Aplus, Aminus, + t, dt): + mask = jnp.where(0 != jnp.abs(weights), 1., 0.) + + pre_size = shape[0] + post_size = shape[1] + + pre_synaptic_binary_mask = jnp.where(pre > 0, 1., 0.) + post_synaptic_binary_mask = jnp.where(post > 0, 1., 0.) + + broadcast_pre_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape( + pre_synaptic_binary_mask, (pre_size, 1)) + broadcast_post_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape( + post_synaptic_binary_mask, (1, post_size)) + + broadcast_pre_synaptic_time = jnp.zeros(shape) + jnp.reshape(pre, ( + pre_size, 1)) + broadcast_post_synaptic_time = jnp.zeros(shape) + jnp.reshape(post, ( + 1, post_size)) + + no_pre_synaptic_spike_mask = broadcast_post_synaptic_binary_mask * ( + 1. 
- broadcast_pre_synaptic_binary_mask) + no_pre_synaptic_weight_update = (-alpha) * (pre_decay / jnp.exp( + (1 / dt) * (t - broadcast_post_synaptic_time))) + # no_pre_synaptic_weight_update = no_pre_synaptic_weight_update + + # Both have spiked + both_spike_mask = (broadcast_post_synaptic_binary_mask * + broadcast_pre_synaptic_binary_mask) + both_spike_update = (-alpha / ( + broadcast_pre_synaptic_time - broadcast_post_synaptic_time - ( + 0.5 * dt))) * \ + (beta / jnp.exp( + (1 / dt) * (t - broadcast_post_synaptic_time))) + + masked_no_pre_synaptic_weight_update = (no_pre_synaptic_spike_mask * + no_pre_synaptic_weight_update) + masked_both_spike_update = both_spike_mask * both_spike_update + + plasticity = jnp.where(masked_both_spike_update > 0, + Aplus * masked_both_spike_update, + Aminus * masked_both_spike_update) + + decay = masked_no_pre_synaptic_weight_update + + plasticity = plasticity * (1 - weights) + + decay = decay * weights + + update = plasticity + decay + _W = weights + update + ## return masked synaptic weight matrix (enforced structure) + return jnp.clip(_W, 0., 1.) 
* mask + + @resolver(_evolve) + def evolve(self, weights): + self.weights.set(weights) diff --git a/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_Synapse.py b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_Synapse.py new file mode 100644 index 0000000..d8bc695 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/custom/ti_STDP_Synapse.py @@ -0,0 +1,97 @@ +from ngclearn import Component, Compartment, resolver +from ngclearn.components import DenseSynapse +from ngclearn.utils.model_utils import normalize_matrix +from jax import numpy as jnp + +class TI_STDP_Synapse(DenseSynapse): + def __init__(self, name, shape, alpha=0.0075, beta=0.5, pre_decay=0.5, + resist_scale=1., p_conn=1., weight_init=None, **kwargs): + super().__init__(name=name, shape=shape, resist_scale=resist_scale, + p_conn=p_conn, weight_init=weight_init, **kwargs) + + #Params + self.batch_size = 1 + self.alpha = alpha + self.beta = beta + self.pre_decay = pre_decay + self.Aplus = 1 + self.Aminus = 1 + + #Compartments + self.pre = Compartment(None) + self.post = Compartment(None) + + self.reset() + + @staticmethod + def _norm(weights, norm_scale): + return normalize_matrix(weights, wnorm=100, scale=norm_scale) + + @resolver(_norm) + def norm(self, weights): + self.weights.set(weights) + + + @staticmethod + def _reset(batch_size, shape): + return jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])), \ + jnp.zeros((batch_size, shape[0])), \ + jnp.zeros((batch_size, shape[1])) + + @resolver(_reset) + def reset(self, inputs, outputs, pre, post): + self.inputs.set(inputs) + self.outputs.set(outputs) + self.pre.set(pre) + self.post.set(post) + + @staticmethod + def _evolve(pre, post, weights, + shape, alpha, beta, pre_decay, Aplus, Aminus, + t, dt): + pre_size = shape[0] + post_size = shape[1] + + pre_synaptic_binary_mask = jnp.where(pre > 0, 1., 0.) + post_synaptic_binary_mask = jnp.where(post > 0, 1., 0.) 
+ + broadcast_pre_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(pre_synaptic_binary_mask, (pre_size, 1)) + broadcast_post_synaptic_binary_mask = jnp.zeros(shape) + jnp.reshape(post_synaptic_binary_mask, (1, post_size)) + + broadcast_pre_synaptic_time = jnp.zeros(shape) + jnp.reshape(pre, (pre_size, 1)) + broadcast_post_synaptic_time = jnp.zeros(shape) + jnp.reshape(post, (1, post_size)) + + no_pre_synaptic_spike_mask = broadcast_post_synaptic_binary_mask * (1. - broadcast_pre_synaptic_binary_mask) + no_pre_synaptic_weight_update = (-alpha) * (pre_decay / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + # no_pre_synaptic_weight_update = no_pre_synaptic_weight_update + + # Both have spiked + both_spike_mask = broadcast_post_synaptic_binary_mask * broadcast_pre_synaptic_binary_mask + both_spike_update = (-alpha / (broadcast_pre_synaptic_time - broadcast_post_synaptic_time - (0.5 * dt))) * \ + (beta / jnp.exp((1 / dt) * (t - broadcast_post_synaptic_time))) + + masked_no_pre_synaptic_weight_update = no_pre_synaptic_spike_mask * no_pre_synaptic_weight_update + masked_both_spike_update = both_spike_mask * both_spike_update + + plasticity = jnp.where(masked_both_spike_update > 0, + Aplus * masked_both_spike_update, + Aminus * masked_both_spike_update) + + decay = masked_no_pre_synaptic_weight_update + + plasticity = plasticity * (1 - weights) + + decay = decay * weights + + update = plasticity + decay + _W = weights + update + + return jnp.clip(_W, 0., 1.) 
+ + @resolver(_evolve) + def evolve(self, weights): + self.weights.set(weights) + + + diff --git a/exhibits/time-integrated-stdp/patch_model/patch_tistdp_snn.py b/exhibits/time-integrated-stdp/patch_model/patch_tistdp_snn.py new file mode 100755 index 0000000..4b891b8 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/patch_tistdp_snn.py @@ -0,0 +1,235 @@ +from ngclearn import Context, numpy as jnp +from ngclearn.components import (LIFCell, PoissonCell, BernoulliCell, StaticSynapse, + VarTrace, Monitor) +from custom.ti_STDP_Synapse import TI_STDP_Synapse +from custom.ti_STDP_LCNSynapse import TI_STDP_LCNSynapse +from ngclearn.operations import summation +from jax import jit, random +from ngclearn.utils.viz.synapse_plot import viz_block +from ngclearn.utils.model_utils import scanner +import ngclearn.utils.weight_distribution as dists +from matplotlib import pyplot as plt +import ngclearn.utils as utils + +def get_nodes(model): + nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i", "M") + map = {} ## node-name look-up hash table + for node in nodes: + map[node.name] = node + return nodes, map + +def build_model(seed=1234, in_dim=1, in_patchShape=None, n_in_patches=1): + window_length = 250 #300 + + ## try no striding and normalize coefficients for filters in recon? + dt = 1 + X_size = int(jnp.sqrt(in_dim)) ## get input square dim + + z0_patchShape = in_patchShape #(7, 7) #(2, 2) #(7, 7) + z0_patchSize = z0_patchShape[0] * z0_patchShape[0] + z1_patchSize = 8 * 8 #6 * 6 # 5 * 5 # 4 * 4 #2 * 2 #4 * 4 + z1_patchCnt = n_in_patches #int(X_size/z0_patchShape[0]) * int(X_size/z0_patchShape[0]) # 16 + z1_RfieldCnt = 1 + + hidden_size = z1_patchSize * z1_patchCnt + out_size = 15 * 15 #10 * 10 # 8 * 8 #6 * 6 #4 * 4 + + ## INIT inhibitory/excitatory matrix with block diagonalization + + R1 = 1. #12. #6. #1. #12. #6. #1. #6. + R2 = 6. #12. + exc = 22.5 + inh = 120. #10. #120. #60. #15. #10. + tau_m_e = 100. 
# ms (excitatory membrane time constant) + tau_m_i = 100. # ms (inhibitory membrane time constant) + # tau_theta = 500. + # theta_plus = 0.2 + tau_theta = 1e5 #1e3 #1e5 #500. #1e4 #1e5 + theta_plus = 0.05 #0.1 #0.05 + thr_jitter = 0. + + px = py = X_size + hidx = hidy = int(jnp.sqrt(hidden_size)) + + dkey = random.PRNGKey(seed) + dkey, *subkeys = random.split(dkey, 12) + + with Context("model") as model: + M = Monitor("M", default_window_length=window_length) + ## layer 0 + z0 = PoissonCell("z0", n_units=in_dim, max_freq=63.75, key=subkeys[0]) + ## layer 1 + W1 = TI_STDP_LCNSynapse( + "W1", shape=(z0_patchSize, z1_patchSize), + model_shape=(z1_patchCnt, z1_RfieldCnt), + alpha=0.0075 * 0.5, beta=1.25, pre_decay=0.75, + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R1, key=subkeys[1] + ) + z1e = LIFCell("z1e", n_units=hidden_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[2]) + z1i = LIFCell("z1i", n_units=hidden_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) 
+ W1ie = StaticSynapse("W1ie", shape=(hidden_size, hidden_size), + weight_init=dists.hollow(-inh, block_diag_mask_width=z1_patchSize)) + W1ei = StaticSynapse("W1ei", shape=(hidden_size, hidden_size), + weight_init=dists.eye(exc, block_diag_mask_width=z1_patchSize)) + ## layer 2 + W2 = TI_STDP_Synapse("W2", alpha=0.025 * 2, beta=2, pre_decay=0.25 * 0.5, + shape=(hidden_size, out_size), + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R2, key=subkeys[6]) + z2e = LIFCell("z2e", n_units=out_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[4]) + z2i = LIFCell("z2i", n_units=out_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) + W2ie = StaticSynapse("W2ie", shape=(out_size, out_size), + weight_init=dists.hollow(-inh)) + W2ei = StaticSynapse("W2ei", shape=(out_size, out_size), + weight_init=dists.eye(exc)) + + ## layer 0 to layer 1 + W1.inputs << z0.outputs + W1ie.inputs << z1i.s + z1e.j << summation(W1.outputs, W1ie.outputs) + W1ei.inputs << z1e.s + z1i.j << W1ei.outputs + ## layer 1 to layer 2 + W2.inputs << z1e.s_raw + W2ie.inputs << z2i.s + z2e.j << summation(W2.outputs, W2ie.outputs) + W2ei.inputs << z2e.s + z2i.j << W2ei.outputs + + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre << z0.tols + W1.post << z1e.tols + W2.pre << z1e.tols + W2.post << z2e.tols + + ## wire statistics into global monitor + ## layer 1 stats + M << z1e.s + M << z1e.j + M << z1e.v + ## layer 2 stats + M << z2e.s + M << z2e.j + M << z2e.v + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + 
compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True, raster_name="exp/raster_plot"): + viz_block([W1.weights.value, W2.weights.value], + [(px, py), (hidx, hidy)], name + "_block", padding=2, + low_rez=low_rez) + + fig, ax = plt.subplots(3, 2, sharex=True, figsize=(15, 8)) + + for k in range(out_size): + ax[1][0].plot([i for i in range(window_length)], + M.view(z1e.v)[:, :, k]) + ax[0][0].plot([i for i in range(window_length)], + M.view(z1e.j)[:, :, k]) + + ax[1][1].plot([i for i in range(window_length)], + M.view(z2e.v)[:, :, k]) + ax[0][1].plot([i for i in range(window_length)], + M.view(z2e.j)[:, :, k]) + # print("----") + # data = M.view(z2e.v) + # print(jnp.amax(data, axis=0)) + # print("----") + + utils.viz.raster.create_raster_plot(M.view(z1e.s), ax=ax[2][0]) + utils.viz.raster.create_raster_plot(M.view(z2e.s), ax=ax[2][1]) + # plt.show() + plt.savefig(raster_name) + plt.close() + + @scanner + def observe(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + current_state = model.evolve(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.v.path], + current_state[z2e.v.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + #M.halt_all() + return model + +def load_from_disk(model_directory, param_dir="/custom", disable_adaptation=True): + with Context("model") as model: + model.load_from_dir(model_directory, custom_folder=param_dir) + nodes = 
model.get_components("W1", "W1ie", "W1ei", "W2", "W2ie", "W2ei", + "z0", "z1e", "z1i", "z2e", "z2i") + (W1, W1ie, W1ei, W2, W2ie, W2ei,z0, z1e, z1i, z2e, z2i) = nodes + if disable_adaptation: + z1e.tau_theta = 0. ## disable homeostatic adaptation + z2e.tau_theta = 0. ## disable homeostatic adaptation + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True): + viz_block([W1.weights.value, W2.weights.value], + [(28, 28), (10, 10)], name + "_block", padding=2, + low_rez=low_rez) + + @scanner + def infer(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), + jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), + jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + return model + diff --git a/exhibits/time-integrated-stdp/patch_model/patched_train.py b/exhibits/time-integrated-stdp/patch_model/patched_train.py new file mode 100755 index 0000000..54726b8 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/patched_train.py @@ -0,0 +1,192 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time +from 
ngclearn.utils.io_utils import makedir +from custom.patch_utils import Create_Patches +from ngclearn.utils.viz.synapse_plot import visualize + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +def viz_fields(name, W1, n_in_patches): + z0_patchShape = (7, 7) + z0_patchSize = z0_patchShape[0] * z0_patchShape[0] + _W1 = W1.weights.value + B = [] + for i_ in range(_W1.shape[1]): + for j_ in range(n_in_patches): + _filter = _W1[z0_patchSize * j_:(j_ + 1) * z0_patchSize, i_:i_ + 1] + if jnp.sum(_filter) > 0.: + B.append(_filter) + B = jnp.concatenate(B, axis=1) + print("Viz.shape: ", B.shape) + visualize([B], [z0_patchShape], name + "_filters") + +def save_parameters(model_dir, nodes): ## model context saving routine + makedir(model_dir) + for node in nodes: + node.save(model_dir) ## call node's local save function + +################################################################################ +# read in general program arguments +options, remainder = gopt.getopt( + sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", "n_iter=", "verbosity=", + "bind_target=", "exp_dir=", "model_type=", "seed="] +) + +seed = 1234 +exp_dir = "exp/" +model_type = "tistdp" +n_iter = 1 # 10 ## total number passes through dataset +n_samples = -1 +bind_target = 40000 +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ("--verbosity"): + verbosity = int(arg.strip()) + elif opt in ("--n_samples"): + n_samples = int(arg.strip()) + elif opt in ("--n_iter"): + n_iter = int(arg.strip()) + elif opt in ("--bind_target"): + bind_target = int(arg.strip()) + elif opt in ("--exp_dir"): + exp_dir = arg.strip() + elif opt in ("--model_type"): + model_type = arg.strip() + elif opt in ("--seed"): 
+ seed = int(arg.strip()) + +if model_type == "tistdp": + print(" >> Setting up TI-STDP Patch-Model builder!") + from patch_tistdp_snn import build_model, get_nodes +else: + print(" >> Model type ", model_type, " not supported!") + +print(">> X: {} Y: {}".format(dataX, dataY)) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if n_samples > 0: + _X = _X[0:n_samples, :] + _Y = _Y[0:n_samples, :] + print("-> Fitting model to only {} samples".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 1000 #10000 +mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (28, 28) ## same as image_shape +num_patches = 10 +in_dim = patch_shape[0] * patch_shape[1] + +x = jnp.ones(patch_shape) +in_patchShape = (7, 7) +stride_shape = (0, 0) +patcher = Create_Patches(x, in_patchShape, stride_shape) +x_pat = patcher.create_patches(add_frame=False, center=True) +n_in_patches = patcher.nw_patches * patcher.nh_patches + +T = 250 ## num time steps to simulate (stimulus presentation window length) +dt = 1. 
## integration time constant + +################################################################################ +print("--- Building Model ---") +dkey = random.PRNGKey(seed) +dkey, *subkeys = random.split(dkey, 3) +## Create model +makedir(exp_dir) +## get model from imported header file +model = build_model(seed, in_dim=in_dim, n_in_patches=n_in_patches, + in_patchShape=in_patchShape) +nodes, node_map = get_nodes(model) +model.save_to_json(exp_dir, model_type) +################################################################################ +print("--- Starting Simulation ---") + +sim_start_time = time.time() ## start time profiling + +model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + +print("------------------------------------") +model.showStats(-1) + +model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, 0) +save_parameters(model_dir, nodes) +viz_fields(name="{}{}".format(exp_dir, "init_W1"), W1=node_map.get("W1"), n_in_patches=n_in_patches) + +## enter main adaptation loop over data patterns +class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units)) +num_bound = 0 +n_total_samp_seen = 0 +for i in range(n_iter): + dkey, *subkeys = random.split(dkey, 2) + ptrs = random.permutation(subkeys[0], _X.shape[0]) + X = _X[ptrs, :] + Y = _Y[ptrs, :] + + tstart = time.time() + n_samps_seen = 0 + for j in range(n_batches): + idx = j + Xb = X[idx: idx + mb_size, :] + Yb = Y[idx: idx + mb_size, :] + + model.reset() + patcher.img = jnp.reshape(Xb, (28, 28)) + xs = patcher.create_patches(add_frame=False, center=True) + xs = xs.reshape(1, -1) + model.clamp(xs) + spikes1, spikes2 = model.observe(jnp.array([[dt * k, dt] for k in range(T)])) + + if n_total_samp_seen >= bind_target: + responses = Yb.T * jnp.sum(spikes2, axis=0) + class_responses = class_responses + responses + num_bound += 1 + + n_samps_seen += Xb.shape[0] + n_total_samp_seen += Xb.shape[0] + + print("\r Seen {} images 
(Binding {})...".format(n_samps_seen, num_bound), end="") + if (j+1) % viz_mod == 0: ## save intermediate receptive fields + tend = time.time() + print() + print(" -> Time = {} s".format(tend - tstart)) + tstart = tend + 0. + model.showStats(i) + viz_fields(name="{}{}".format(exp_dir, "W1"), W1=node_map.get("W1"), n_in_patches=n_in_patches) + model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + + ## save a running/current overridden copy of NPZ parameters + model_dir = "{}{}/custom".format(exp_dir, model_type) + save_parameters(model_dir, nodes) + + ## end of iteration/epoch + ## save a snapshot of the NPZ parameters at this particular epoch + model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, i) + save_parameters(model_dir, nodes) +print() + +class_responses = jnp.argmax(class_responses, axis=0, keepdims=True) +print("---- Max Class Responses ----") +print(class_responses) +print(class_responses.shape) +jnp.save("{}binded_labels.npy".format(exp_dir), class_responses) + +## stop time profiling +sim_end_time = time.time() +sim_time = sim_end_time - sim_start_time +sim_time_hr = (sim_time/3600.0) # convert time to hours + +print("------------------------------------") +print(" Trial.sim_time = {} h ({} sec)".format(sim_time_hr, sim_time)) diff --git a/exhibits/time-integrated-stdp/patch_model/sample_model.sh b/exhibits/time-integrated-stdp/patch_model/sample_model.sh new file mode 100755 index 0000000..66b948e --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/sample_model.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp stdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'tistdp', 'evstdp' models supported!" 
+ exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +PARAM_SUBDIR="/custom" +SEED=1234 #(1234 77 811) +N_SAMPLES=5000 #1000 #50000 +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" +#DEV_X="../../data/mnist/testX.npy" # validX.npy +#DEV_Y="../../data/mnist/testY.npy" # validY.npy + +EXP_DIR="exp_$MODEL""_$SEED/" +echo " > Running Simulation/Model: $EXP_DIR" + +## train model +CUDA_VISIBLE_DEVICES=$GPU_ID python assemble_patterns.py --dataX=$DATA_X --dataY=$DATA_Y \ + --n_samples=$N_SAMPLES --exp_dir=$EXP_DIR \ + --model_dir=$EXP_DIR$MODEL \ + --param_subdir=$PARAM_SUBDIR \ + --model_type=$MODEL --seed=$SEED diff --git a/exhibits/time-integrated-stdp/patch_model/train_by_patch.py b/exhibits/time-integrated-stdp/patch_model/train_by_patch.py new file mode 100755 index 0000000..3849001 --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/train_by_patch.py @@ -0,0 +1,167 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time +from ngclearn import Context +from ngclearn.utils.io_utils import makedir +from ngclearn.utils.patch_utils import generate_patch_set + +import matplotlib #.pyplot as plt +matplotlib.use('Agg') +import matplotlib.pyplot as plt +cmap = plt.cm.jet + +def save_parameters(model_dir, nodes): ## model context saving routine + makedir(model_dir) + for node in nodes: + node.save(model_dir) ## call node's local save function + +################################################################################ +# read in general program arguments +options, remainder = gopt.getopt(sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", + "n_iter=", "verbosity=", + "bind_target=", "exp_dir=", + "model_type=", "seed="]) + +seed = 1234 +exp_dir = "exp/" +model_type = "tistdp" +n_iter = 1 # 10 ## total number passes through dataset +n_samples = -1 +bind_target = 40000 +dataX = "../../data/mnist/trainX.npy" +dataY = "../../data/mnist/trainY.npy" +verbosity = 0 ## verbosity level (0 - fairly 
minimal, 1 - prints multiple lines on I/O) +for opt, arg in options: + if opt in ("--dataX"): + dataX = arg.strip() + elif opt in ("--dataY"): + dataY = arg.strip() + elif opt in ("--verbosity"): + verbosity = int(arg.strip()) + elif opt in ("--n_samples"): + n_samples = int(arg.strip()) + elif opt in ("--n_iter"): + n_iter = int(arg.strip()) + elif opt in ("--bind_target"): + bind_target = int(arg.strip()) + elif opt in ("--exp_dir"): + exp_dir = arg.strip() + elif opt in ("--model_type"): + model_type = arg.strip() + elif opt in ("--seed"): + seed = int(arg.strip()) +if model_type == "tistdp": + print(" >> Setting up TI-STDP builder!") + from tistdp_snn import build_model, get_nodes +elif model_type == "trstdp": + print(" >> Setting up Trace-based STDP builder!") + from trstdp_snn import build_model, get_nodes +elif model_type == "evstdp": + print(" >> Setting up Event-Driven STDP builder!") + from evstdp_snn import build_model, get_nodes +elif model_type == "stdp": + print(" >> Setting up classical STDP builder!") + from stdp_snn import build_model, get_nodes +print(">> X: {} Y: {}".format(dataX, dataY)) + +## load dataset +_X = jnp.load(dataX) +_Y = jnp.load(dataY) +if n_samples > 0: + _X = _X[0:n_samples, :] + _Y = _Y[0:n_samples, :] + print("-> Fitting model to only {} samples".format(n_samples)) +n_batches = _X.shape[0] ## num batches is = to num samples (online learning) + +## basic simulation hyper-parameter/configuration values go here +viz_mod = 1000 #10000 +mb_size = 1 ## locked to batch sizes of 1 +patch_shape = (10, 10) #(28, 28) +in_dim = patch_shape[0] * patch_shape[1] +num_patches = 10 + +T = 250 #300 ## num time steps to simulate (stimulus presentation window length) +dt = 1. 
## integration time constant + +################################################################################ +print("--- Building Model ---") +dkey = random.PRNGKey(seed) +dkey, *subkeys = random.split(dkey, 3) +## Create model +makedir(exp_dir) +model = build_model(seed, in_dim=in_dim) #Context("model") ## get model from imported header file +nodes, node_map = get_nodes(model) +model.save_to_json(exp_dir, model_type) +################################################################################ +print("--- Starting Simulation ---") + +sim_start_time = time.time() ## start time profiling + +model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + +print("------------------------------------") +model.showStats(-1) + +model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, 0) +save_parameters(model_dir, nodes) + +## enter main adaptation loop over data patterns +class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units)) +num_bound = 0 +n_total_samp_seen = 0 +for i in range(n_iter): + dkey, *subkeys = random.split(dkey, 2) + ptrs = random.permutation(subkeys[0], _X.shape[0]) + X = _X[ptrs, :] + Y = _Y[ptrs, :] + + tstart = time.time() + n_samps_seen = 0 + for j in range(n_batches): + idx = j + Xb = X[idx: idx + mb_size, :] + Yb = Y[idx: idx + mb_size, :] + + ## generate a set of patches from current pattern + Xb = generate_patch_set(Xb, patch_shape, num_patches, center=False) + for p in range(Xb.shape[0]): # within a batch of patches, adapt SNN + xs = jnp.expand_dims(Xb[p, :], axis=0) + flag = jnp.sum(xs) + if flag > 0.: + model.reset() + model.clamp(xs) + spikes1, spikes2 = model.observe(jnp.array([[dt * k, dt] for k in range(T)])) + + n_samps_seen += Xb.shape[0] + n_total_samp_seen += Xb.shape[0] + print("\r Seen {} images (Binding {})...".format(n_samps_seen, num_bound), end="") + if (j+1) % viz_mod == 0: ## save intermediate receptive fields + tend = time.time() + print() 
+ print(" -> Time = {} s".format(tend - tstart)) + tstart = tend + 0. + model.showStats(i) + model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True, + raster_name="{}{}".format(exp_dir, "raster_plot")) + + ## save a running/current overridden copy of NPZ parameters + model_dir = "{}{}/custom".format(exp_dir, model_type) + save_parameters(model_dir, nodes) + + ## end of iteration/epoch + ## save a snapshot of the NPZ parameters at this particular epoch + model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, i) + save_parameters(model_dir, nodes) +print() + +## stop time profiling +sim_end_time = time.time() +sim_time = sim_end_time - sim_start_time +sim_time_hr = (sim_time/3600.0) # convert time to hours + +#plot_clusters(W1, W2) + +print("------------------------------------") +print(" Trial.sim_time = {} h ({} sec)".format(sim_time_hr, sim_time)) + + diff --git a/exhibits/time-integrated-stdp/patch_model/train_patch_models.sh b/exhibits/time-integrated-stdp/patch_model/train_patch_models.sh new file mode 100755 index 0000000..db0c7fb --- /dev/null +++ b/exhibits/time-integrated-stdp/patch_model/train_patch_models.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +## get in user-provided program args +GPU_ID=$1 #1 +MODEL=$2 # evstdp trstdp tistdp stdp + +if [[ "$MODEL" != "evstdp" && "$MODEL" != "tistdp" ]]; then + echo "Invalid Arg: $MODEL -- only 'tistdp', 'evstdp' models supported!" + exit 1 +fi +echo " >>>> Setting up $MODEL on GPU $GPU_ID" + +SEEDS=(1234) +N_ITER=10 +N_SAMPLES=50000 #10000 #5000 #1000 #50000 +BIND_COUNT=10000 +BIND_TARGET=$((($N_ITER - 1) * $N_SAMPLES + ($N_SAMPLES - $BIND_COUNT))) + +DATA_X="../../data/mnist/trainX.npy" +DATA_Y="../../data/mnist/trainY.npy" +#DEV_X="../../data/mnist/testX.npy" # validX.npy +#DEV_Y="../../data/mnist/testY.npy" # validY.npy + +if (( N_ITER * N_SAMPLES < BIND_COUNT )) ; then + echo "Not enough samples to reach bind target!" 
+ exit 1 +fi + +for seed in "${SEEDS[@]}" +do + EXP_DIR="exp_$MODEL""_$seed/" + echo " > Running Simulation/Model: $EXP_DIR" + + rm -r $EXP_DIR* + ## train model + CUDA_VISIBLE_DEVICES=$GPU_ID python train_by_patch.py --dataX=$DATA_X --dataY=$DATA_Y \ + --n_iter=$N_ITER --bind_target=$BIND_TARGET \ + --n_samples=$N_SAMPLES --exp_dir=$EXP_DIR \ + --model_type=$MODEL --seed=$seed +done diff --git a/exhibits/time-integrated-stdp/snn_case1.py b/exhibits/time-integrated-stdp/snn_case1.py new file mode 100755 index 0000000..b3ba73a --- /dev/null +++ b/exhibits/time-integrated-stdp/snn_case1.py @@ -0,0 +1,326 @@ +from ngclearn import Context, numpy as jnp +import math +from ngclearn.components import (LIFCell, PoissonCell, StaticSynapse, + VarTrace, Monitor, EventSTDPSynapse, TraceSTDPSynapse) +from custom.ti_stdp_synapse import TI_STDP_Synapse +from ngclearn.operations import summation +from jax import jit, random +from ngclearn.utils.viz.raster import create_raster_plot +from ngclearn.utils.viz.synapse_plot import viz_block +from ngclearn.utils.model_utils import scanner +import ngclearn.utils.weight_distribution as dists +from ngclearn.utils.model_utils import normalize_matrix +from matplotlib import pyplot as plt +import ngclearn.utils as utils + +def get_nodes(model): + nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i", "M") + map = {} ## node-name look-up hash table + for node in nodes: + map[node.name] = node + return nodes, map + +def build_model(seed=1234, in_dim=1, is_patch_model=False, algo="tistdp"): + window_length = 250 #300 + + dt = 1 + X_size = int(jnp.sqrt(in_dim)) + # X_size = 28 + # in_dim = (X_size * X_size) + if is_patch_model: + hidden_size = 10 * 10 + out_size = 6 * 6 + else: + hidden_size = 25 * 25 + out_size = 15 * 15 + + exc = 22.5 + inh = 120. #60. #15. #10. + tau_m_e = 100. # ms (excitatory membrane time constant) + tau_m_i = 100. 
# ms (inhibitory membrane time constant) + tau_theta = 1e5 + theta_plus = 0.05 + thr_jitter = 0. + + R1 = R2 = 1. + if algo == "tistdp": + R1 = 1. + R2 = 6. #1. + elif algo == "evstdp": + R1 = 1. # 6. + R2 = 6. # 12. + elif algo == "trstdp": + R1 = 1. # 6. + R2 = 6. # 12. + tau_theta = 1e4 #1e5 + theta_plus = 0.2 #0.05 + + px = py = X_size + hidx = hidy = int(jnp.sqrt(hidden_size)) + + dkey = random.PRNGKey(seed) + dkey, *subkeys = random.split(dkey, 12) + + with Context("model") as model: + M = Monitor("M", default_window_length=window_length) + ## layer 0 + z0 = PoissonCell("z0", n_units=in_dim, max_freq=63.75, key=subkeys[0]) + ## layer 1 + z1e = LIFCell("z1e", n_units=hidden_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[2]) + z1i = LIFCell("z1i", n_units=hidden_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) + W1ie = StaticSynapse("W1ie", shape=(hidden_size, hidden_size), + weight_init=dists.hollow(-inh)) + W1ei = StaticSynapse("W1ei", shape=(hidden_size, hidden_size), + weight_init=dists.eye(exc)) + ## layer 2 + z2e = LIFCell("z2e", n_units=out_size, tau_m=tau_m_e, resist_m=tau_m_e/dt, + refract_time=5., thr_jitter=thr_jitter, tau_theta=tau_theta, + theta_plus=theta_plus, one_spike=True, key=subkeys[4]) + z2i = LIFCell("z2i", n_units=out_size, tau_m=tau_m_i, resist_m=tau_m_i/dt, + refract_time=5., thr_jitter=thr_jitter, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0.) 
+ W2ie = StaticSynapse("W2ie", shape=(out_size, out_size), + weight_init=dists.hollow(-inh)) + W2ei = StaticSynapse("W2ei", shape=(out_size, out_size), + weight_init=dists.eye(exc)) + + tr0 = tr1 = tr2 = None + if algo == "tistdp": + print(" >> Equipping SNN with TI-STDP Adaptation") + W1 = TI_STDP_Synapse("W1", alpha=0.0075 * 0.5, beta=1.25, + pre_decay=0.75, + shape=((X_size ** 2), hidden_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TI_STDP_Synapse("W2", alpha=0.025 * 2, beta=2, + pre_decay=0.25 * 0.5, + shape=(hidden_size, out_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre << z0.tols + W1.post << z1e.tols + W2.pre << z1e.tols + W2.post << z2e.tols + elif algo == "trstdp": + print(" >> Equipping SNN with Trace/TR-STDP Adaptation") + tau_tr = 20. # 10. + trace_delta = 0. + x_tar1 = 0.3 + x_tar2 = 0.025 + Aplus = 1e-2 ## LTP learning rate (STDP); nu1 + Aminus = 1e-3 ## LTD learning rate (STDP); nu0 + mu = 1. # 0. 
+ W1 = TraceSTDPSynapse("W1", shape=(in_dim, hidden_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar1, + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TraceSTDPSynapse("W2", shape=(hidden_size, out_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar2, resist_scale=R2, + weight_init=dists.uniform(amin=0.025, amax=0.8), + key=subkeys[3]) + ## traces + tr0 = VarTrace("tr0", n_units=in_dim, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr1 = VarTrace("tr1", n_units=hidden_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr2 = VarTrace("tr2", n_units=out_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + # wire cells z0 and z1e to their respective traces + tr0.inputs << z0.outputs + tr1.inputs << z1e.s + tr2.inputs << z2e.s + + # wire relevant compartment statistics to synaptic cables W1 and W2 + W1.preTrace << tr0.trace + W1.preSpike << z0.outputs + W1.postTrace << tr1.trace + W1.postSpike << z1e.s + + W2.preTrace << tr1.trace + W2.preSpike << z1e.s + W2.postTrace << tr2.trace + W2.postSpike << z2e.s + elif algo == "evstdp": + print(" >> Equipping SNN with EV-STDP Adaptation") + ## EV-STDP meta-parameters + eta_w = 1. + Aplus1 = 0.0055 + Aminus1 = 0.25 * 0.0055 + Aplus2 = 0.0055 + Aminus2 = 0.05 * 0.0055 + lmbda = 0. 
# 0.01 + W1 = EventSTDPSynapse("W1", shape=(in_dim, hidden_size), eta=eta_w, + A_plus=Aplus1, A_minus=Aminus1, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(0.025, 0.8), + resist_scale=R1, key=subkeys[1]) + W2 = EventSTDPSynapse("W2", shape=(hidden_size, out_size), + eta=eta_w, + A_plus=Aplus2, A_minus=Aminus2, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre_tols << z0.tols + W1.postSpike << z1e.s + W2.pre_tols << z1e.tols + W2.postSpike << z2e.s + else: + print("ERROR: algorithm ", algo, " provided is not supported!") + exit() + + ## layer 0 to layer 1 + W1.inputs << z0.outputs + W1ie.inputs << z1i.s + z1e.j << summation(W1.outputs, W1ie.outputs) + W1ei.inputs << z1e.s + z1i.j << W1ei.outputs + ## layer 1 to layer 2 + W2.inputs << z1e.s_raw + W2ie.inputs << z2i.s + z2e.j << summation(W2.outputs, W2ie.outputs) + W2ei.inputs << z2e.s + z2i.j << W2ei.outputs + + ## wire statistics into global monitor + ## layer 1 stats + M << z1e.s + M << z1e.j + M << z1e.v + ## layer 2 stats + M << z2e.s + M << z2e.j + M << z2e.v + + if algo == "trstdp": + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + tr0, tr1, tr2, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, + compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, tr0, + tr1, tr2, + compile_key="reset") + else: + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + 
model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True, raster_name="exp/raster_plot"): + viz_block([W1.weights.value, W2.weights.value], + [(px, py), (hidx, hidy)], name + "_block", padding=2, + low_rez=low_rez) + + fig, ax = plt.subplots(3, 2, sharex=True, figsize=(15, 8)) + + for i in range(out_size): + ax[1][0].plot([i for i in range(window_length)], + M.view(z1e.v)[:, :, i]) + ax[0][0].plot([i for i in range(window_length)], + M.view(z1e.j)[:, :, i]) + + ax[1][1].plot([i for i in range(window_length)], + M.view(z2e.v)[:, :, i]) + ax[0][1].plot([i for i in range(window_length)], + M.view(z2e.j)[:, :, i]) + utils.viz.raster.create_raster_plot(M.view(z1e.s), ax=ax[2][0]) + utils.viz.raster.create_raster_plot(M.view(z2e.s), ax=ax[2][1]) + # plt.show() + plt.savefig(raster_name) + plt.close() + + @scanner + def observe(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + current_state = model.evolve(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + #M.halt_all() + return model + +def load_from_disk(model_directory, param_dir="/custom", disable_adaptation=True): + with Context("model") as model: + model.load_from_dir(model_directory, custom_folder=param_dir) + nodes = model.get_components("W1", "W1ie", "W1ei", "W2", "W2ie", "W2ei", + "z0", "z1e", "z1i", "z2e", "z2i") + (W1, W1ie, 
W1ei, W2, W2ie, W2ei,z0, z1e, z1i, z2e, z2i) = nodes + if disable_adaptation: + z1e.tau_theta = 0. ## disable homeostatic adaptation + z2e.tau_theta = 0. ## disable homeostatic adaptation + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True): + viz_block([W1.weights.value, W2.weights.value], + [(28, 28), (10, 10)], name + "_block", padding=2, + low_rez=low_rez) + + @scanner + def infer(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), + jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), + jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + return model + diff --git a/exhibits/time-integrated-stdp/snn_case2.py b/exhibits/time-integrated-stdp/snn_case2.py new file mode 100755 index 0000000..489d3f2 --- /dev/null +++ b/exhibits/time-integrated-stdp/snn_case2.py @@ -0,0 +1,324 @@ +from ngclearn import Context, numpy as jnp +import math +from ngclearn.components import (LIFCell, PoissonCell, StaticSynapse, + VarTrace, Monitor, EventSTDPSynapse, TraceSTDPSynapse) +from custom.ti_stdp_synapse import TI_STDP_Synapse +from 
ngclearn.operations import summation +from jax import jit, random +from ngclearn.utils.viz.raster import create_raster_plot +from ngclearn.utils.viz.synapse_plot import viz_block +from ngclearn.utils.model_utils import scanner +import ngclearn.utils.weight_distribution as dists +from ngclearn.utils.model_utils import normalize_matrix +from matplotlib import pyplot as plt +import ngclearn.utils as utils + +def get_nodes(model): + nodes = model.get_components("W1", "W1ie", "W1ei", "z0", "z1e", "z1i", + "W2", "W2ie", "W2ei", "z2e", "z2i", "M") + map = {} ## node-name look-up hash table + for node in nodes: + map[node.name] = node + return nodes, map + +def build_model(seed=1234, in_dim=1, is_patch_model=False, algo="tistdp"): + window_length = 250 #300 + + dt = 1 + X_size = int(jnp.sqrt(in_dim)) + if is_patch_model: + hidden_size = 10 * 10 + out_size = 6 * 6 + else: + hidden_size = 25 * 25 + out_size = 15 * 15 + + # exc = 22.5 + # inh = 120. #60. #15. #10. + tau_m_e = 100. # ms (excitatory membrane time constant) + tau_m_i = 100. # ms (inhibitory membrane time constant) + # tau_theta = 1e5 + # theta_plus = 0.05 + # thr_jitter = 0. + + R1 = R2 = 1. + if algo == "tistdp": + R1 = 1. + R2 = 1. + elif algo == "evstdp": + R1 = 1. #12. #1. + R2 = 6. # 12. + elif algo == "trstdp": + R1 = 1. # 6. + R2 = 6. # 12. 
+ + px = py = X_size + hidx = hidy = int(jnp.sqrt(hidden_size)) + + dkey = random.PRNGKey(seed) + dkey, *subkeys = random.split(dkey, 12) + + with Context("model") as model: + M = Monitor("M", default_window_length=window_length) + ## layer 0 + z0 = PoissonCell("z0", n_units=in_dim, max_freq=63.75, key=subkeys[0]) + ## layer 1 + z1e = LIFCell("z1e", n_units=hidden_size, tau_m=tau_m_e, + resist_m=25, tau_theta=1e4, theta_plus=0.2, refract_time=5., + one_spike=True, key=subkeys[2]) + z1i = LIFCell("z1i", n_units=hidden_size, tau_m=tau_m_i, + resist_m=tau_m_i / dt, thr=-40., v_rest=-60., + v_reset=-45., tau_theta=0., refract_time=5., ) + W1ie = StaticSynapse("W1ie", shape=(hidden_size, hidden_size), + weight_init=dists.hollow(-10)) + W1ei = StaticSynapse("W1ei", shape=(hidden_size, hidden_size), + weight_init=dists.eye(22.5)) + ## layer 2 + z2e = LIFCell("z2e", n_units=out_size, tau_m=tau_m_e, + resist_m=500., thr_jitter=20., tau_theta=1e4, refract_time=5., + theta_plus=0.2, one_spike=True, key=subkeys[4]) + z2i = LIFCell("z2i", n_units=out_size, tau_m=tau_m_i, + resist_m=tau_m_i / dt, refract_time=5., + thr=-40., v_rest=-60., v_reset=-45., tau_theta=0.,) + ## if make z2e bigger, then make the inhibition weaker possibly + W2ie = StaticSynapse("W2ie", shape=(out_size, out_size), + weight_init=dists.hollow(-10)) + W2ei = StaticSynapse("W2ei", shape=(out_size, out_size), + weight_init=dists.eye(22.5)) + + tr0 = tr1 = tr2 = None + if algo == "tistdp": + # Working numbers 1.5, 0.2 + # Single pass -> alpha=0.025, beta=2, pre_decay=0.25 * 0.5, acc: 60% + print(" >> Equipping SNN with TI-STDP Adaptation") + W1 = TI_STDP_Synapse("W1", alpha=0.0075 * 0.5, beta=1.25, + pre_decay=0.75, + shape=((X_size ** 2), hidden_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TI_STDP_Synapse("W2", alpha=0.025, beta=2, + pre_decay=0.25 * 0.5, + shape=(hidden_size, out_size), + weight_init=dists.uniform(amin=0.025, + amax=0.8), + 
resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre << z0.tols + W1.post << z1e.tols + W2.pre << z1e.tols + W2.post << z2e.tols + elif algo == "trstdp": + print(" >> Equipping SNN with Trace/TR-STDP Adaptation") + tau_tr = 20. # 40. # 10. + trace_delta = 0. + x_tar1 = 0.3 + x_tar2 = 0.025 + Aplus = 1e-2 ## LTP learning rate (STDP); nu1 + Aminus = 1e-3 ## LTD learning rate (STDP); nu0 + mu = 1. # 0. + W1 = TraceSTDPSynapse("W1", shape=(in_dim, hidden_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar1, + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R1, key=subkeys[1]) + W2 = TraceSTDPSynapse("W2", shape=(hidden_size, out_size), mu=mu, + A_plus=Aplus, A_minus=Aminus, eta=1., + pretrace_target=x_tar2, resist_scale=R2, + weight_init=dists.uniform(amin=0.025, amax=0.8), + key=subkeys[3]) + ## traces + tr0 = VarTrace("tr0", n_units=in_dim, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr1 = VarTrace("tr1", n_units=hidden_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + tr2 = VarTrace("tr2", n_units=out_size, tau_tr=tau_tr, + decay_type="exp", + a_delta=trace_delta) + # wire cells z0 and z1e to their respective traces + tr0.inputs << z0.outputs + tr1.inputs << z1e.s + tr2.inputs << z2e.s + # wire relevant compartment statistics to synaptic cables W1 and W2 + W1.preTrace << tr0.trace + W1.preSpike << z0.outputs + W1.postTrace << tr1.trace + W1.postSpike << z1e.s + + W2.preTrace << tr1.trace + W2.preSpike << z1e.s + W2.postTrace << tr2.trace + W2.postSpike << z2e.s + elif algo == "evstdp": + print(" >> Equipping SNN with EV-STDP Adaptation") + ## EV-STDP meta-parameters + eta_w = 0.01 # 0.002 #1. + Aplus1 = 1. #0.0055 * 2 + Aminus1 = 0.3 #0.25 * 0.0055 + Aplus2 = 1. #0.0055 * 2 + Aminus2 = 0.075 # 0.06 # 0.3 #0.05 * 0.0055 + lmbda = 0. 
# 0.01 + W1 = EventSTDPSynapse("W1", shape=(in_dim, hidden_size), eta=eta_w, + A_plus=Aplus1, A_minus=Aminus1, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(0.025, 0.8), + resist_scale=R1, key=subkeys[1]) + W2 = EventSTDPSynapse("W2", shape=(hidden_size, out_size), + eta=eta_w, A_plus=Aplus2, A_minus=Aminus2, + lmbda=lmbda, w_bound=1., presyn_win_len=2., + weight_init=dists.uniform(amin=0.025, amax=0.8), + resist_scale=R2, key=subkeys[6]) + # wire relevant plasticity statistics to synaptic cables W1 and W2 + W1.pre_tols << z0.tols + W1.postSpike << z1e.s + W2.pre_tols << z1e.tols + W2.postSpike << z2e.s + else: + print("ERROR: algorithm ", algo, " provided is not supported!") + exit() + + ## layer 0 to layer 1 + W1.inputs << z0.outputs + W1ie.inputs << z1i.s + z1e.j << summation(W1.outputs, W1ie.outputs) + W1ei.inputs << z1e.s + z1i.j << W1ei.outputs + ## layer 1 to layer 2 + W2.inputs << z1e.s_raw + W2ie.inputs << z2i.s + z2e.j << summation(W2.outputs, W2ie.outputs) + W2ei.inputs << z2e.s + z2i.j << W2ei.outputs + + ## wire statistics into global monitor + ## layer 1 stats + M << z1e.s + M << z1e.j + M << z1e.v + ## layer 2 stats + M << z2e.s + M << z2e.j + M << z2e.v + + if algo == "trstdp": + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + tr0, tr1, tr2, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, + compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, tr0, + tr1, tr2, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + else: + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, M, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, 
+ compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True, raster_name="exp/raster_plot"): + viz_block([W1.weights.value, W2.weights.value], + [(px, py), (hidx, hidy)], name + "_block", padding=2, + low_rez=low_rez) + + fig, ax = plt.subplots(3, 2, sharex=True, figsize=(15, 8)) + + for i in range(out_size): + ax[1][0].plot([i for i in range(window_length)], + M.view(z1e.v)[:, :, i]) + ax[0][0].plot([i for i in range(window_length)], + M.view(z1e.j)[:, :, i]) + + ax[1][1].plot([i for i in range(window_length)], + M.view(z2e.v)[:, :, i]) + ax[0][1].plot([i for i in range(window_length)], + M.view(z2e.j)[:, :, i]) + utils.viz.raster.create_raster_plot(M.view(z1e.s), ax=ax[2][0]) + utils.viz.raster.create_raster_plot(M.view(z2e.s), ax=ax[2][1]) + # plt.show() + plt.savefig(raster_name) + plt.close() + + @scanner + def observe(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + current_state = model.evolve(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format(jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + #M.halt_all() + return model + +def load_from_disk(model_directory, param_dir="/custom", disable_adaptation=True): + with Context("model") as model: + model.load_from_dir(model_directory, custom_folder=param_dir) + nodes = model.get_components("W1", "W1ie", "W1ei", "W2", "W2ie", "W2ei", + "z0", "z1e", "z1i", "z2e", 
"z2i") + (W1, W1ie, W1ei, W2, W2ie, W2ei,z0, z1e, z1i, z2e, z2i) = nodes + if disable_adaptation: + z1e.tau_theta = 0. ## disable homeostatic adaptation + z2e.tau_theta = 0. ## disable homeostatic adaptation + + advance, adv_args = model.compile_by_key( + W1, W1ie, W1ei, W2, W2ie, W2ei, + z0, z1e, z1i, z2e, z2i, + compile_key="advance_state") + evolve, evolve_args = model.compile_by_key(W1, W2, compile_key="evolve") + reset, reset_args = model.compile_by_key( + z0, z1e, z1i, z2e, z2i, W1, W2, W1ie, W1ei, W2ie, W2ei, + compile_key="reset") + model.wrap_and_add_command(jit(model.reset), name="reset") + + @model.dynamicCommand + def clamp(x): + z0.inputs.set(x) + + @model.dynamicCommand + def viz(name, low_rez=True): + viz_block([W1.weights.value, W2.weights.value], + [(28, 28), (10, 10)], name + "_block", padding=2, + low_rez=low_rez) + + @scanner + def infer(current_state, args): + _t, _dt = args + current_state = model.advance_state(current_state, t=_t, dt=_dt) + return current_state, (current_state[z1e.s_raw.path], + current_state[z2e.s_raw.path]) + + @model.dynamicCommand + def showStats(i): + print(f"\n~~~~~Iteration {str(i)}~~~~~~") + print("W1:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W1.weights.value), jnp.amax(W1.weights.value), + jnp.mean(W1.weights.value), + jnp.linalg.norm(W1.weights.value))) + print("W2:\n min {} ; max {} \n mu {} ; norm {}".format( + jnp.amin(W2.weights.value), jnp.amax(W2.weights.value), + jnp.mean(W2.weights.value), + jnp.linalg.norm(W2.weights.value))) + return model + diff --git a/exhibits/time-integrated-stdp/train.py b/exhibits/time-integrated-stdp/train.py new file mode 100755 index 0000000..1b88c7e --- /dev/null +++ b/exhibits/time-integrated-stdp/train.py @@ -0,0 +1,206 @@ +from jax import numpy as jnp, random +import sys, getopt as gopt, optparse, time +from ngclearn.utils.io_utils import makedir +from ngclearn.utils.patch_utils import generate_patch_set + +import matplotlib #.pyplot as plt 
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+cmap = plt.cm.jet
+
+def save_parameters(model_dir, nodes): ## model context saving routine
+    ## creates model_dir (via `makedir`, imported earlier in this file) and asks
+    ## each node to serialize itself into that directory
+    makedir(model_dir)
+    for node in nodes:
+        node.save(model_dir) ## call node's local save function
+
+################################################################################
+# read in general program arguments
+options, remainder = gopt.getopt(
+    sys.argv[1:], '', ["dataX=", "dataY=", "n_samples=", "n_iter=", "verbosity=",
+                       "bind_target=", "exp_dir=", "model_type=", "seed=",
+                       "use_patches=", "model_case="]
+)
+
+model_case = "snn_case1"
+seed = 1234
+exp_dir = "exp/"
+model_type = "tistdp"
+n_iter = 1 # 10 ## total number passes through dataset
+n_samples = -1
+## bind_target = total-sample index after which label-binding statistics start
+## accumulating (see the `n_total_samp_seen >= bind_target` checks below)
+bind_target = 40000
+use_patches = False
+dataX = "../../data/mnist/trainX.npy"
+dataY = "../../data/mnist/trainY.npy"
+verbosity = 0 ## verbosity level (0 - fairly minimal, 1 - prints multiple lines on I/O)
+# NOTE(review): each branch below tests `opt in ("--flag")` -- the parens do NOT
+# make a tuple, so this is substring matching on a string. It only works because
+# getopt returns exact long-option names; `opt == "--flag"` would be safer.
+for opt, arg in options:
+    if opt in ("--dataX"):
+        dataX = arg.strip()
+    elif opt in ("--dataY"):
+        dataY = arg.strip()
+    elif opt in ("--verbosity"):
+        verbosity = int(arg.strip())
+    elif opt in ("--n_samples"):
+        n_samples = int(arg.strip())
+    elif opt in ("--n_iter"):
+        n_iter = int(arg.strip())
+    elif opt in ("--bind_target"):
+        bind_target = int(arg.strip())
+    elif opt in ("--exp_dir"):
+        exp_dir = arg.strip()
+    elif opt in ("--model_type"):
+        model_type = arg.strip()
+    elif opt in ("--model_case"):
+        model_case = arg.strip()
+    elif opt in ("--seed"):
+        seed = int(arg.strip())
+    elif opt in ("--use_patches"):
+        use_patches = int(arg.strip()) #(arg.strip().lower() == "true")
+        use_patches = (use_patches == 1) ## CLI passes 1/0, not true/false
+
+## select which case-study module supplies build_model/get_nodes
+if model_case == "snn_case1":
+    print(" >> Setting up Case 1 model!")
+    from snn_case1 import build_model, get_nodes
+elif model_case == "snn_case2":
+    print(" >> Setting up Case 2 model!")
+    from snn_case2 import build_model, get_nodes
+else:
+    print("Error: No other model case studies supported! (", model_case, " invalid)")
+    exit()
+
+print(">> X: {} Y: {}".format(dataX, dataY))
+
+## load dataset
+_X = jnp.load(dataX)
+_Y = jnp.load(dataY)
+if n_samples > 0:
+    _X = _X[0:n_samples, :]
+    _Y = _Y[0:n_samples, :]
+    print("-> Fitting model to only {} samples".format(n_samples))
+n_batches = _X.shape[0] ## num batches is = to num samples (online learning)
+
+## basic simulation hyper-parameter/configuration values go here
+viz_mod = 1000 #10000 ## save fields/raster plots every viz_mod patterns
+mb_size = 1 ## locked to batch sizes of 1
+patch_shape = (28, 28) ## same as image_shape
+num_patches = 10
+if use_patches:
+    patch_shape = (10, 10)
+in_dim = patch_shape[0] * patch_shape[1]
+
+T = 250 #300 ## num time steps to simulate (stimulus presentation window length)
+dt = 1. ## integration time constant
+
+################################################################################
+print("--- Building Model ---")
+dkey = random.PRNGKey(seed)
+dkey, *subkeys = random.split(dkey, 3)
+## Create model
+makedir(exp_dir)
+## get model from imported header file
+model = build_model(seed, in_dim=in_dim, is_patch_model=use_patches, algo=model_type)
+nodes, node_map = get_nodes(model)
+model.save_to_json(exp_dir, model_type)
+################################################################################
+print("--- Starting Simulation ---")
+
+sim_start_time = time.time() ## start time profiling
+
+model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True,
+          raster_name="{}{}".format(exp_dir, "raster_plot"))
+
+print("------------------------------------")
+model.showStats(-1)
+
+## snapshot the untrained (epoch "0") parameters before any adaptation
+model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, 0)
+save_parameters(model_dir, nodes)
+
+## enter main adaptation loop over data patterns
+## class_responses: (num label classes, num "z2e" units) accumulator of spike
+## counts per class, used after training to bind a label to each output unit
+class_responses = jnp.zeros((_Y.shape[1], node_map.get("z2e").n_units))
+num_bound = 0
+n_total_samp_seen = 0
+for i in range(n_iter):
+    ## reshuffle the dataset each pass/epoch
+    dkey, *subkeys = random.split(dkey, 2)
+    ptrs = random.permutation(subkeys[0], _X.shape[0])
+    X = _X[ptrs, :]
+    Y = _Y[ptrs, :]
+
+    tstart = time.time()
+    n_samps_seen = 0
+    for j in range(n_batches):
+        idx = j
+        Xb = X[idx: idx + mb_size, :]
+        Yb = Y[idx: idx + mb_size, :]
+
+        if use_patches:
+            ## generate a set of patches from current pattern
+            X_patches = generate_patch_set(Xb, patch_shape, num_patches, center=False)
+            for p in range(X_patches.shape[0]): # within a batch of patches, adapt SNN
+                xs = jnp.expand_dims(X_patches[p, :], axis=0)
+                flag = jnp.sum(xs)
+                if flag > 0.: ## skip all-zero (empty) patches
+                    model.reset()
+                    model.clamp(xs)
+                    spikes1, spikes2 = model.observe(
+                        jnp.array([[dt * k, dt] for k in range(T)]))
+
+                    if n_total_samp_seen >= bind_target:
+                        ## accumulate per-class spike counts of layer-2 spikes
+                        responses = Yb.T * jnp.sum(spikes2, axis=0)
+                        class_responses = class_responses + responses
+                        num_bound += 1
+        else:
+            model.reset()
+            model.clamp(Xb)
+            spikes1, spikes2 = model.observe(jnp.array([[dt * k, dt] for k in range(T)]))
+            # print(tr1)
+            # print("...")
+            # print(jnp.sum(spikes1, axis=0))
+            # print(jnp.sum(spikes2, axis=0))
+            # exit()
+
+            if n_total_samp_seen >= bind_target:
+                ## accumulate per-class spike counts of layer-2 spikes
+                responses = Yb.T * jnp.sum(spikes2, axis=0)
+                class_responses = class_responses + responses
+                num_bound += 1
+
+        n_samps_seen += Xb.shape[0]
+        n_total_samp_seen += Xb.shape[0]
+
+        print("\r Seen {} images (Binding {})...".format(n_samps_seen, num_bound), end="")
+        if (j+1) % viz_mod == 0: ## save intermediate receptive fields
+            tend = time.time()
+            print()
+            print(" -> Time = {} s".format(tend - tstart))
+            tstart = tend + 0.
+            model.showStats(i)
+            model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=True,
+                      raster_name="{}{}".format(exp_dir, "raster_plot"))
+
+            ## save a running/current overridden copy of NPZ parameters
+            model_dir = "{}{}/custom".format(exp_dir, model_type)
+            save_parameters(model_dir, nodes)
+
+    ## end of iteration/epoch
+    ## save a snapshot of the NPZ parameters at this particular epoch
+    model_dir = "{}{}/custom_snapshot{}".format(exp_dir, model_type, i)
+    save_parameters(model_dir, nodes)
+print()
+
+# print(" >> Producing final rec-fields / raster plots!")
+# model.viz(name="{}{}".format(exp_dir, "recFields"), low_rez=False,
+#           raster_name="{}{}".format(exp_dir, "raster_plot"))
+
+## argmax over the class axis -> winning label index for each z2e unit
+class_responses = jnp.argmax(class_responses, axis=0, keepdims=True)
+print("---- Max Class Responses ----")
+print(class_responses)
+print(class_responses.shape)
+jnp.save("{}binded_labels.npy".format(exp_dir), class_responses)
+
+## stop time profiling
+sim_end_time = time.time()
+sim_time = sim_end_time - sim_start_time
+sim_time_hr = (sim_time/3600.0) # convert time to hours
+
+print("------------------------------------")
+print(" Trial.sim_time = {} h ({} sec)".format(sim_time_hr, sim_time))
+
+
diff --git a/exhibits/time-integrated-stdp/train_models.sh b/exhibits/time-integrated-stdp/train_models.sh
new file mode 100755
index 0000000..f2fddbc
--- /dev/null
+++ b/exhibits/time-integrated-stdp/train_models.sh
@@ -0,0 +1,56 @@
+#!/bin/bash
+
+## get in user-provided program args
+GPU_ID=$1 #1
+MODEL=$2 # evstdp trstdp tistdp stdp
+CASE_STUDY=$3 ## snn_case1 snn_case2
+USE_PATCHES=0 ## set to false for this simulation study
+
+if [[ "$MODEL" != "evstdp" && "$MODEL" != "trstdp" && "$MODEL" != "tistdp" && "$MODEL" != "stdp" ]]; then
+    echo "Invalid Arg: $MODEL -- only 'evstdp', 'trstdp', 'tistdp', 'stdp', 'patch_evstdp' models supported!"
+    exit 1
+fi
+if [[ "$CASE_STUDY" != "snn_case1" && "$CASE_STUDY" != "snn_case2" ]]; then
+    echo "Invalid Case provided: $CASE_STUDY -- only 'snn_case1', 'snn_case2' supported!"
+    exit 1
+fi
+echo " >>>> Setting up $MODEL on GPU $GPU_ID"
+
+SEEDS=(1234 77 811)
+N_ITER=20
+if [[ "$CASE_STUDY" == "snn_case2" ]]; then
+    N_ITER=1 ## Case 2 models focus on online learning; so only 1 pass allowed
+fi
+
+if (( USE_PATCHES == 1 )) ; then
+    echo ">> Patch-level modeling configured!"
+    N_ITER=1
+fi
+N_SAMPLES=50000
+BIND_COUNT=10000
+## BIND_TARGET simplifies to N_ITER*N_SAMPLES - BIND_COUNT, i.e. label binding
+## starts exactly BIND_COUNT samples before the end of training
+BIND_TARGET=$((($N_ITER - 1) * $N_SAMPLES + ($N_SAMPLES - $BIND_COUNT)))
+
+DATA_X="../../data/mnist/trainX.npy"
+DATA_Y="../../data/mnist/trainY.npy"
+#DEV_X="../../data/mnist/testX.npy" # validX.npy
+#DEV_Y="../../data/mnist/testY.npy" # validY.npy
+
+if (( N_ITER * N_SAMPLES < BIND_COUNT )) ; then
+    echo "Not enough samples to reach bind target!"
+    exit 1
+fi
+
+## one trial per seed; each gets its own experiment output directory
+for seed in "${SEEDS[@]}"
+do
+    EXP_DIR="exp_$MODEL""_$seed/"
+    echo " > Running Simulation/Model: $EXP_DIR"
+
+    # NOTE(review): unquoted glob -- `rm -r $EXP_DIR*` also removes any sibling
+    # path sharing this prefix; `rm -rf "$EXP_DIR"` would be safer.
+    rm -r $EXP_DIR*
+    ## train model
+    CUDA_VISIBLE_DEVICES=$GPU_ID python train.py --dataX=$DATA_X --dataY=$DATA_Y \
+                                                 --n_iter=$N_ITER --bind_target=$BIND_TARGET \
+                                                 --n_samples=$N_SAMPLES --exp_dir=$EXP_DIR \
+                                                 --model_type=$MODEL --seed=$seed \
+                                                 --use_patches=$USE_PATCHES \
+                                                 --model_case=$CASE_STUDY
+done
diff --git a/exhibits/time-integrated-stdp/viz_codes.py b/exhibits/time-integrated-stdp/viz_codes.py
new file mode 100755
index 0000000..4c22b43
--- /dev/null
+++ b/exhibits/time-integrated-stdp/viz_codes.py
@@ -0,0 +1,29 @@
+from jax import numpy as jnp
+import sys, getopt as gopt, optparse
+from ngclearn.utils.viz.dim_reduce import extract_tsne_latents, plot_latents
+
+
+# read in general program arguments
+options, remainder = gopt.getopt(sys.argv[1:], '', ["labels_fname=", "codes_fname=",
+                                                    "plot_fname="])
+
+plot_fname = "codes.jpg"
+labels_fname = "../../data/mnist/testY.npy"
+codes_fname = "exp/test_codes.npy"
+# NOTE(review): `opt in ("--flag")` is substring matching on a string (no tuple);
+# it works only because getopt yields exact option names -- `==` would be safer.
+for opt, arg in options:
+    if opt in ("--labels_fname"):
+        labels_fname = arg.strip()
+    elif opt in ("--codes_fname"):
+        codes_fname = arg.strip()
+    elif opt in ("--plot_fname"):
+        plot_fname = arg.strip()
+
+labels = jnp.load(labels_fname)
+print("Lab.shape: ", labels.shape)
+codes = jnp.load(codes_fname)
+print("Codes.shape: ", codes.shape)
+
+## visualize the above data via the t-SNE algorithm
+tsne_codes = extract_tsne_latents(codes)
+print("tSNE-codes.shape = ", tsne_codes.shape)
+plot_latents(tsne_codes, labels, plot_fname=plot_fname)
\ No newline at end of file