Skip to content

Commit

Permalink
WIP circuit sim classes
Browse files Browse the repository at this point in the history
  • Loading branch information
clbarnes committed Apr 12, 2022
1 parent 9ab2c8d commit f0c8863
Showing 1 changed file with 164 additions and 141 deletions.
305 changes: 164 additions & 141 deletions django/applications/catmaid/static/js/widgets/circuit-simulation.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,162 @@
(function (CATMAID) {
"use strict";

function zip(arrays) {
const minLen = arrays.reduce(
(accum, curr) => Math.min(accum, curr.length),
Infinity
);
const out = [];
for (let i = 0; i < minLen; ++i) {
out.push(arrays.map((a) => a[i]));
}
return out;
const nullish = CATMAID.tools.nullish;

class LogisticActivation extends ActivationFn {
constructor(slope, threshold, scale) {
this.slope = slope;
this.threshold = threshold;
this.scale = nullish(scale, 1.0);
}

compute(x) {
if (x == 0) {
// Logistic functions are asymptotic, and so will always produce
// (very small) outputs even with no inputs when translated up.
// This ensures that no input -> no output.
return 0;
}
return this.scale / (1.0 + Math.exp(-this.slope * (x - this.threshold)));
}

static fromObj(obj) {
return new LogisticActivation(obj.slope, obj.threshold, obj.scale);
}
}

class Stimulation {
/**
*
* @param {number} strength - strength of stimulus
* @param {number?} start - time at which stimulus starts, default -inf
* @param {number?} stop - time at which stimulus stops, default inf
*/
constructor(strength, start, stop) {
this.strength = strength;
this.start = nullish(start, -Infinity);
this.stop = nullish(stop, Infinity);
}

/**
*
* @param {number} t - timepoint
* @returns {number} Amount of stimulation at the give timepoint
*/
atTime(t) {
return (t >= this.start && t < this.stop) ? this.strength : 0;
}

static fromObj(obj) {
return new Stimulation(obj.strength, obj.start, obj.stop);
}
}

class Unit {
constructor(name, color, activation, tau, tonic, stims) {
this.name = name;
this.color = color;
this.activation = activation;
this.tau = tau;
this.tonic = tonic;
this.stims = stims;
}

/** Does not include tonic */
inputFromStims(t) {
return this.stims.reduce((prev, current) => prev + current.getStimulation(t), 0);
}

/** Includes tonic bias, i.e. self-excitation */
inputFromPartners(weights, rates) {
return weights.reduce(
(prev, current, idx) => prev + current * rates[idx],
this.tonic,
);
}

dy_dt(rate, t, weights, rates) {
let independent = this.inputFromStims(t) - rate;
return (independent + this.activation.compute(this.inputFromPartners(weights, rates))) / this.tau;
}

static fromObj(obj) {
return new Unit(
obj.name, obj.color,
LogisticActivation.fromObj(obj.activation),
obj.tau, obj.tonic,
obj.stims.map(Stimulation.fromObj),
);
}
}

class Circuit {
constructor(units, weights) {
this.units = units;
this.weights = weights;
}

dy_dt(t, rates) {
return this.units.map(
(unit, idx) => unit.dy_dt(rates[idx], t, this.weights[idx], rates)
);
}

solve(length) {
return numeric.dopri_nonnegative(
0, length, this.units.map(() => 0), this.dy_dt
);
}

getLines(solution) {
const lines = this.units.map(function (p) {
return { name: p.name, color: p.color, stroke_width: "3", xy: [] };
});

CATMAID.tools.zip(solution.x, solution.y).forEach((ty) => {
let t = ty[0];
let y = ty[1];
for (let i = 0; i < y.length; ++i) {
lines[i].xy.push({ x: t, y: y[i] });
}
});

return lines;
}

static fromObj(obj) {
return new Circuit(obj.units.map(Unit.fromObj), obj.weights);
}
}

CATMAID.CircuitSimulation = class CircuitSimulation extends InstanceRegistry {
constructor() {
super();
this.widgetID = this.registerInstance();
this.idPrefix = `circuitsim${this.widgetID}-`;
this.showParameterUi = true;
this.skelSource = new CATMAID.BasicSkeletonSource(this.getName(), {
owner: this,
});

// An array of objects with members:
// - name: of the unit
// - color: string, #-prefixed hex color for the plot
// - w: array of connection weights with the other units (in order)
// - k: slope of logistic function
// - th: threshold of logistic function
// - I_tonic: bias/ tonic stimulation
// - I_stim: current while stimulus is active
// - I_stim_start: first time point of stimulus
// - I_stim_end: last time point of stimulus
// - scaling: multiplies the logistic activation function
// - tau: time constant; divides dy/dt to represent how quickly unit responds
this.units = [];
this.circuit = null;

this.sol = null;
this.lines = [];
}

getSubId(id) {
return this.idPrefix + id;
}

getName() {
return "Circuit Simulation " + this.widgetID;
}

getWidgetConfiguration() {
return {
controlsID: "circuit_simulation_buttons" + this.widgetID,
contentID: "circuit_simulation_div" + this.widgetID,
controlsID: this.getSubId("controls"),
contentID: this.getSubId("content"),
createControls: function (controls) {
var CS = this;
var tabs = CATMAID.DOM.addTabGroup(controls, CS.widgetID, [
Expand All @@ -60,7 +170,7 @@

const fileButton = controls.appendChild(
CATMAID.DOM.createFileButton(
"load-json-dialog-" + this.widgetID,
this.getSubId("loadjson"),
false,
(event) => CS.loadJson(event.target.files)
)
Expand All @@ -81,7 +191,7 @@
["Save results", CS.saveResults.bind(CS)],
[
CATMAID.DOM.createNumericField(
"cs_time" + CS.widgetID,
CS.getSubId("maxtime"),
"Time:",
"Amount of simulated time, in arbitrary units",
"1000",
Expand Down Expand Up @@ -111,18 +221,32 @@
};
}

clearCache() {
this.circuit = null;
this.sol = null;
this.lines = null;
}

cacheCircuit() {
this.sol = null;
}

saveResults() {
if (!this.sol) {
CATMAID.warn("No results to save");
return;
if (!this.circuit) {
this.cacheCircuit();
}

const out = {
time: this.sol.x,
units: this.units,
rates: zip(this.sol.y),
circuit: this.circuit,
results: null,
};

if (!!this.sol) {
out.results = { time: this.sol.x, rates: CATMAID.tools.zip(...this.sol.y) };
}
else {
CATMAID.warn("No results to save");
}

const timestamp = CATMAID.tools.dateToString(null, "T", "");
const defaultFilename = `circuit-simulation_${timestamp}.json`;
saveAs(
Expand All @@ -144,12 +268,7 @@
let parsed;
try {
parsed = JSON.parse(event.target.result);
// allow results JSON or just units member to be used
if (Array.isArray(parsed)) {
this.units = parsed;
} else {
this.units = parsed.units;
}
this.circuit = Circuit.fromObj(parsed.circuit);
} catch (err) {
CATMAID.handleError(err);
return;
Expand All @@ -159,9 +278,7 @@
}

clear() {
this.units = [];
this.sol = null;
this.lines = [];
this.clearCache();
this.redraw();
}

Expand Down Expand Up @@ -235,104 +352,10 @@
}

run() {
// params: vector of maps, one map per unit containing everything that is constant in the stimulation for that unit:
// * w: weights in the circuit graph.
// * k: slope of the logistic function of each unit.
// * th: threshold of the logistic function of each unit.
// * I_tonic: bias of the inputs of each unit.
// * I_stim: vector that directly controls the output of a unit for a specific time period. (The optogenetic stimulus, so to speak.)
// * t_stim_start: first time point at which I_stim is applied.
// * t_stim_end: last time point at which I_stim is applied.
// * scaling: multiplies the logistic
// * tau: divide the entire dydt output, representing how fast the output of the unit responds to its input.
//
// Also:
// * name
// * color
//

// Logistic function: all parameters are scalars
// x: sum of weights * rates for inputs
// k: slope
// th: threshold
// Returns a scalar
var logistic = function (x, k, th) {
if (x == 0) {
// Logistic functions are asymptotic, and so will always produce
// (very small) outputs even with no inputs when translated up.
// This ensures that no input -> no output.
return 0;
}
return 1.0 / (1.0 + Math.exp(-k * (x - th)));
};

// Compute the external stimulation to the unit, if any
var stim = function (t, p) {
return t >= p.I_stim_start && t <= p.I_stim_end ? p.I_stim : 0;
};

// Compute the 'x' in the logistic function
// by multiplying the weight of the connection times the rate of that input neuron,
// summing all, and then adding the baseline bias, I_tonic.
var input = function (vw, vr, I_tonic) {
var s = 0;
for (var i = 0; i < vw.length; ++i) {
s += vw[i] * vr[i];
}
return s + I_tonic;
};

// A two-argument function that runs at every simulated time instance
// (units is bound to this.units, returning a two-arg function.)
// t: (scalar) current time point
// vr: (vector) current firing rates of all units
// Returns a vector: each item is a solution for each unit.
var dydt = function (units, t, vr) {
return vr.map(function (r, i) {
var p = units[i];
return (
(-r +
stim(t, p) +
p.scaling * logistic(input(p.w, vr, p.I_tonic), p.k, p.th)) /
p.tau
);
});
}.bind(null, this.units);

// numeric.dopri parameters: (the ODE solver)
// x0: initial time of the simulation.
// x1: final time of the simulation.
// y0: initial state (of the rates in our case).
// f: dydt function, executing at every step of the simulation.
// tol: (optional) tolerance, default 1e-6.
// maxit: (optional) maximum number of iterations, default 1000.
// event: (optional) the integration stops if the event function foes from negative to positive.

var x0 = 0;
var x1 = Number($("#cs_time" + this.widgetID).val());
var y0 = this.units.map(function () {
return 0;
}); // a vector full of zeros

var sol = numeric.dopri_nonnegative(x0, x1, y0, dydt, 1e-6, 1000, null);

// Extract one line per unit, for plotting
const lines = this.units.map(function (p) {
return { name: p.name, color: p.color, stroke_width: "3", xy: [] };
});

zip([sol.x, sol.y]).forEach((ty) => {
let t = ty[0];
let y = ty[1];
for (let i = 0; i < y.length; ++i) {
lines[i].xy.push({ x: t, y: y[i] });
}
});

this.lines = lines;

// Store the result for analysis from the command line
this.sol = sol;
this.cacheCircuit();
var x1 = Number($("#" + this.getSubId("maxtime")).val());
this.sol = this.circuit.solve(x1);
this.lines = this.circuit.getLines(this.sol);

this.redraw();
}
Expand Down

0 comments on commit f0c8863

Please sign in to comment.