diff --git a/django/applications/catmaid/static/js/widgets/circuit-simulation.js b/django/applications/catmaid/static/js/widgets/circuit-simulation.js index a07d26bb2d..f496452e1b 100644 --- a/django/applications/catmaid/static/js/widgets/circuit-simulation.js +++ b/django/applications/catmaid/static/js/widgets/circuit-simulation.js @@ -5,52 +5,162 @@ (function (CATMAID) { "use strict"; - function zip(arrays) { - const minLen = arrays.reduce( - (accum, curr) => Math.min(accum, curr.length), - Infinity - ); - const out = []; - for (let i = 0; i < minLen; ++i) { - out.push(arrays.map((a) => a[i])); - } - return out; + const nullish = CATMAID.tools.nullish; + + class LogisticActivation extends ActivationFn { + constructor(slope, threshold, scale) { + this.slope = slope; + this.threshold = threshold; + this.scale = nullish(scale, 1.0); + } + + compute(x) { + if (x == 0) { + // Logistic functions are asymptotic, and so will always produce + // (very small) outputs even with no inputs when translated up. + // This ensures that no input -> no output. + return 0; + } + return this.scale / (1.0 + Math.exp(-this.slope * (x - this.threshold))); + } + + static fromObj(obj) { + return new LogisticActivation(obj.slope, obj.threshold, obj.scale); + } + } + + class Stimulation { + /** + * + * @param {number} strength - strength of stimulus + * @param {number?} start - time at which stimulus starts, default -inf + * @param {number?} stop - time at which stimulus stops, default inf + */ + constructor(strength, start, stop) { + this.strength = strength; + this.start = nullish(start, -Infinity); + this.stop = nullish(stop, Infinity); + } + + /** + * + * @param {number} t - timepoint + * @returns {number} Amount of stimulation at the give timepoint + */ + atTime(t) { + return (t >= this.start && t < this.stop) ? this.strength : 0; + } + + static fromObj(obj) { + return new Stimulation(obj.strength, obj.start, obj.stop); + } + } + + class Unit { + constructor(name, color, activation, tau, tonic, stims) { + this.name = name; + this.color = color; + this.activation = activation; + this.tau = tau; + this.tonic = tonic; + this.stims = stims; + } + + /** Does not include tonic */ + inputFromStims(t) { + return this.stims.reduce((prev, current) => prev + current.getStimulation(t), 0); + } + + /** Includes tonic bias, i.e. self-excitation */ + inputFromPartners(weights, rates) { + return weights.reduce( + (prev, current, idx) => prev + current * rates[idx], + this.tonic, + ); + } + + dy_dt(rate, t, weights, rates) { + let independent = this.inputFromStims(t) - rate; + return (independent + this.activation.compute(this.inputFromPartners(weights, rates))) / this.tau; + } + + static fromObj(obj) { + return new Unit( + obj.name, obj.color, + LogisticActivation.fromObj(obj.activation), + obj.tau, obj.tonic, + obj.stims.map(Stimulation.fromObj), + ); + } + } + + class Circuit { + constructor(units, weights) { + this.units = units; + this.weights = weights; + } + + dy_dt(t, rates) { + return this.units.map( + (unit, idx) => unit.dy_dt(rates[idx], t, this.weights[idx], rates) + ); + } + + solve(length) { + return numeric.dopri_nonnegative( + 0, length, this.units.map(() => 0), this.dy_dt + ); + } + + getLines(solution) { + const lines = this.units.map(function (p) { + return { name: p.name, color: p.color, stroke_width: "3", xy: [] }; + }); + + CATMAID.tools.zip(solution.x, solution.y).forEach((ty) => { + let t = ty[0]; + let y = ty[1]; + for (let i = 0; i < y.length; ++i) { + lines[i].xy.push({ x: t, y: y[i] }); + } + }); + + return lines; + } + + static fromObj(obj) { + return new Circuit(obj.units.map(Unit.fromObj), obj.weights); + } } CATMAID.CircuitSimulation = class CircuitSimulation extends InstanceRegistry { constructor() { super(); + this.widgetID = this.registerInstance(); + this.idPrefix = `circuitsim${this.widgetID}-`; this.showParameterUi = true; this.skelSource = new CATMAID.BasicSkeletonSource(this.getName(), { owner: this, }); - // An array of objects with members: - // - name: of the unit - // - color: string, #-prefixed hex color for the plot - // - w: array of connection weights with the other units (in order) - // - k: slope of logistic function - // - th: threshold of logistic function - // - I_tonic: bias/ tonic stimulation - // - I_stim: current while stimulus is active - // - I_stim_start: first time point of stimulus - // - I_stim_end: last time point of stimulus - // - scaling: multiplies the logistic activation function - // - tau: time constant; divides dy/dt to represent how quickly unit responds - this.units = []; + this.circuit = null; this.sol = null; this.lines = []; } + getSubId(id) { + return this.idPrefix + id; + } + getName() { return "Circuit Simulation " + this.widgetID; } getWidgetConfiguration() { return { - controlsID: "circuit_simulation_buttons" + this.widgetID, - contentID: "circuit_simulation_div" + this.widgetID, + controlsID: this.getSubId("controls"), + contentID: this.getSubId("content"), createControls: function (controls) { var CS = this; var tabs = CATMAID.DOM.addTabGroup(controls, CS.widgetID, [ @@ -60,7 +170,7 @@ const fileButton = controls.appendChild( CATMAID.DOM.createFileButton( - "load-json-dialog-" + this.widgetID, + this.getSubId("loadjson"), false, (event) => CS.loadJson(event.target.files) ) @@ -81,7 +191,7 @@ ["Save results", CS.saveResults.bind(CS)], [ CATMAID.DOM.createNumericField( - "cs_time" + CS.widgetID, + CS.getSubId("maxtime"), "Time:", "Amount of simulated time, in arbitrary units", "1000", @@ -111,18 +221,32 @@ }; } + clearCache() { + this.circuit = null; + this.sol = null; + this.lines = null; + } + + cacheCircuit() { + this.sol = null; + } + saveResults() { - if (!this.sol) { - CATMAID.warn("No results to save"); - return; + if (!this.circuit) { + this.cacheCircuit(); } - const out = { - time: this.sol.x, - units: this.units, - rates: zip(this.sol.y), + circuit: this.circuit, + results: null, }; + if (!!this.sol) { + out.results = { time: this.sol.x, rates: CATMAID.tools.zip(...this.sol.y) }; + } + else { + CATMAID.warn("No results to save"); + } + const timestamp = CATMAID.tools.dateToString(null, "T", ""); const defaultFilename = `circuit-simulation_${timestamp}.json`; saveAs( @@ -144,12 +268,7 @@ let parsed; try { parsed = JSON.parse(event.target.result); - // allow results JSON or just units member to be used - if (Array.isArray(parsed)) { - this.units = parsed; - } else { - this.units = parsed.units; - } + this.circuit = Circuit.fromObj(parsed.circuit); } catch (err) { CATMAID.handleError(err); return; @@ -159,9 +278,7 @@ } clear() { - this.units = []; - this.sol = null; - this.lines = []; + this.clearCache(); this.redraw(); } @@ -235,104 +352,10 @@ } run() { - // params: vector of maps, one map per unit containing everything that is constant in the stimulation for that unit: - // * w: weights in the circuit graph. - // * k: slope of the logistic function of each unit. - // * th: threshold of the logistic function of each unit. - // * I_tonic: bias of the inputs of each unit. - // * I_stim: vector that directly controls the output of a unit for a specific time period. (The optogenetic stimulus, so to speak.) - // * t_stim_start: first time point at which I_stim is applied. - // * t_stim_end: last time point at which I_stim is applied. - // * scaling: multiplies the logistic - // * tau: divide the entire dydt output, representing how fast the output of the unit responds to its input. - // - // Also: - // * name - // * color - // - - // Logistic function: all parameters are scalars - // x: sum of weights * rates for inputs - // k: slope - // th: threshold - // Returns a scalar - var logistic = function (x, k, th) { - if (x == 0) { - // Logistic functions are asymptotic, and so will always produce - // (very small) outputs even with no inputs when translated up. - // This ensures that no input -> no output. - return 0; - } - return 1.0 / (1.0 + Math.exp(-k * (x - th))); - }; - - // Compute the external stimulation to the unit, if any - var stim = function (t, p) { - return t >= p.I_stim_start && t <= p.I_stim_end ? p.I_stim : 0; - }; - - // Compute the 'x' in the logistic function - // by multiplying the weight of the connection times the rate of that input neuron, - // summing all, and then adding the baseline bias, I_tonic. - var input = function (vw, vr, I_tonic) { - var s = 0; - for (var i = 0; i < vw.length; ++i) { - s += vw[i] * vr[i]; - } - return s + I_tonic; - }; - - // A two-argument function that runs at every simulated time instance - // (units is bound to this.units, returning a two-arg function.) - // t: (scalar) current time point - // vr: (vector) current firing rates of all units - // Returns a vector: each item is a solution for each unit. - var dydt = function (units, t, vr) { - return vr.map(function (r, i) { - var p = units[i]; - return ( - (-r + - stim(t, p) + - p.scaling * logistic(input(p.w, vr, p.I_tonic), p.k, p.th)) / - p.tau - ); - }); - }.bind(null, this.units); - - // numeric.dopri parameters: (the ODE solver) - // x0: initial time of the simulation. - // x1: final time of the simulation. - // y0: initial state (of the rates in our case). - // f: dydt function, executing at every step of the simulation. - // tol: (optional) tolerance, default 1e-6. - // maxit: (optional) maximum number of iterations, default 1000. - // event: (optional) the integration stops if the event function foes from negative to positive. - - var x0 = 0; - var x1 = Number($("#cs_time" + this.widgetID).val()); - var y0 = this.units.map(function () { - return 0; - }); // a vector full of zeros - - var sol = numeric.dopri_nonnegative(x0, x1, y0, dydt, 1e-6, 1000, null); - - // Extract one line per unit, for plotting - const lines = this.units.map(function (p) { - return { name: p.name, color: p.color, stroke_width: "3", xy: [] }; - }); - - zip([sol.x, sol.y]).forEach((ty) => { - let t = ty[0]; - let y = ty[1]; - for (let i = 0; i < y.length; ++i) { - lines[i].xy.push({ x: t, y: y[i] }); - } - }); - - this.lines = lines; - - // Store the result for analysis from the command line - this.sol = sol; + this.cacheCircuit(); + var x1 = Number($("#" + this.getSubId("maxtime")).val()); + this.sol = this.circuit.solve(x1); + this.lines = this.circuit.getLines(this.sol); this.redraw(); }