From f281427a55d18e72ca46bad8aa4e7352c126814e Mon Sep 17 00:00:00 2001
From: Pooja Babu
Date: Thu, 17 Oct 2024 12:53:03 +0200
Subject: [PATCH] Modify templates for models with a numeric solver
---
.../codegeneration/nest_code_generator.py | 1 +
.../nest_gpu_numeric_function_call_printer.py | 2 +-
.../point_neuron/@NEURON_NAME@.cu.jinja2 | 42 +++++++++++++++++--
pynestml/utils/ast_utils.py | 12 +++---
4 files changed, 48 insertions(+), 9 deletions(-)
diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py
index 66e0c9e13..f4250e8db 100644
--- a/pynestml/codegeneration/nest_code_generator.py
+++ b/pynestml/codegeneration/nest_code_generator.py
@@ -307,6 +307,7 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None:
Logger.log_message(None, code, message, None, LoggingLevel.INFO)
spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron)
neuron.spike_updates = spike_updates
+ ASTUtils.print_spike_update_expressions(neuron)
neuron.post_spike_updates = post_spike_updates
neuron.equations_with_delay_vars = equations_with_delay_vars
neuron.equations_with_vector_vars = equations_with_vector_vars
diff --git a/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py b/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py
index 894e9ed55..472409712 100644
--- a/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py
+++ b/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py
@@ -34,7 +34,7 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) ->
if function_name == PredefinedFunctions.TIME_RESOLUTION:
# context dependent; we assume the template contains the necessary definitions
- return 'h'
+ return 'NESTGPUTimeResolution'
if function_name == PredefinedFunctions.TIME_STEPS:
return '(int)round({!s}/NESTGPUTimeResolution)'
diff --git a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2
index bd9d6a290..a9ab317fa 100644
--- a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2
+++ b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2
@@ -226,6 +226,41 @@ __device__
{%- include "directives/Block.jinja2" %}
{%- endfor %}
{%- endfilter %}
+{%- endif %}
+
+ /**
+ * Begin NESTML generated code for the onReceive block(s)
+ **/
+{% for blk in neuron.get_on_receive_blocks() %}
+{%- set inport = blk.get_port_name() %}
+ if (var[N_SCAL_VAR + i_{{ inport }}])
+ {
+{%- set ast = blk.get_block() %}
+{%- filter indent(6, True) -%}
+{%- include "directives/Block.jinja2" %}
+{%- endfilter %}
+ var[N_SCAL_VAR + i_{{ inport }}] = 0; // reset the value
+ }
+{%- endfor %}
+
+ /**
+ * Begin NESTML generated code for the onCondition block(s)
+ **/
+{% if neuron.get_on_condition_blocks() %}
+{%- for block in neuron.get_on_condition_blocks() %}
+ if ({{ printer.print(block.get_cond_expr()) }})
+ {
+{%- set ast = block.get_block() %}
+{%- if ast.print_comment('*') | length > 1 %}
+/*
+ {{ast.print_comment('*')}}
+ */
+{%- endif %}
+{%- filter indent(6) %}
+{%- include "directives/Block.jinja2" %}
+{%- endfilter %}
+ }
+{%- endfor %}
{%- endif %}
}
@@ -290,7 +325,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
group_param_ = new float[N_GROUP_PARAM];
{%- endif %}
-{%- if not uses_numeric_solver %}
+{%- if uses_analytic_solver %}
AllocParamArr();
AllocVarArr();
{%- endif %}
@@ -337,7 +372,8 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
{%- if uses_numeric_solver %}
{# TODO: automatically determine "PSConInit_E" #}
- port_weight_arr_ = GetParamArr() + GetScalParamIdx("PSConInit_E");
+ # port_weight_arr_ = GetParamArr() + GetScalParamIdx("PSConInit_E");
+ port_weight_arr_ = GetParamArr()
port_weight_arr_step_ = n_param_;
port_weight_port_step_ = 1;
{%- else %}
@@ -361,7 +397,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
return 0;
}
-{%- if not uses_numeric_solver %}
+{%- if uses_analytic_solver %}
int {{ neuronName }}::Free()
{
FreeVarArr();
diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py
index cec41a5b7..d4afcde0b 100644
--- a/pynestml/utils/ast_utils.py
+++ b/pynestml/utils/ast_utils.py
@@ -2569,6 +2569,13 @@ def initial_value_or_zero(cls, astnode: ASTModel, var):
return "0"
+ @classmethod
+ def print_spike_update_expressions(cls, neuron:ASTModel):
+ import pdb;
+ for port, update_exp in neuron.spike_updates.items():
+ pdb.set_trace()
+ print(f"Update expression for {port} is {update_exp}")
+
@classmethod
def get_first_spike_port_from_spike_updates(cls, neuron: ASTModel) -> ASTVariable:
# Get the first variable in the sorted spike update expressions list
@@ -2586,11 +2593,6 @@ def get_first_excitatory_port(cls, neuron: ASTModel) -> str:
# There is no port marked excitatory, return the first port name
return neuron.get_spike_input_ports()[0].get_symbol_name()
- # @classmethod
- # def get_port_qualifier_by_port_name(cls, neuron: ASTModel, port_name: str):
- # for port in neuron.get_input_blocks()[0].get_input_ports():
- # if port.get
-
@classmethod
def get_exc_spike_variable(cls, neuron: ASTModel) -> ASTVariable:
for block in neuron.get_on_receive_blocks():