From f281427a55d18e72ca46bad8aa4e7352c126814e Mon Sep 17 00:00:00 2001 From: Pooja Babu Date: Thu, 17 Oct 2024 12:53:03 +0200 Subject: [PATCH] Modify templates for models with a numeric solver --- .../codegeneration/nest_code_generator.py | 1 + .../nest_gpu_numeric_function_call_printer.py | 2 +- .../point_neuron/@NEURON_NAME@.cu.jinja2 | 42 +++++++++++++++++-- pynestml/utils/ast_utils.py | 12 +++--- 4 files changed, 48 insertions(+), 9 deletions(-) diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 66e0c9e13..f4250e8db 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -307,6 +307,7 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None: Logger.log_message(None, code, message, None, LoggingLevel.INFO) spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron) neuron.spike_updates = spike_updates + ASTUtils.print_spike_update_expressions(neuron) neuron.post_spike_updates = post_spike_updates neuron.equations_with_delay_vars = equations_with_delay_vars neuron.equations_with_vector_vars = equations_with_vector_vars diff --git a/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py b/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py index 894e9ed55..472409712 100644 --- a/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py +++ b/pynestml/codegeneration/printers/nest_gpu_numeric_function_call_printer.py @@ -34,7 +34,7 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) -> if function_name == PredefinedFunctions.TIME_RESOLUTION: # context dependent; we assume the template contains the necessary definitions - return 'h' + return 'NESTGPUTimeResolution' if function_name == PredefinedFunctions.TIME_STEPS: return '(int)round({!s}/NESTGPUTimeResolution)' diff --git a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 index bd9d6a290..a9ab317fa 100644 --- a/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 +++ b/pynestml/codegeneration/resources_nest_gpu/point_neuron/@NEURON_NAME@.cu.jinja2 @@ -226,6 +226,41 @@ __device__ {%- include "directives/Block.jinja2" %} {%- endfor %} {%- endfilter %} +{%- endif %} + + /** + * Begin NESTML generated code for the onReceive block(s) + **/ +{% for blk in neuron.get_on_receive_blocks() %} +{%- set inport = blk.get_port_name() %} + if (var[N_SCAL_VAR + i_{{ inport }}]) + { +{%- set ast = blk.get_block() %} +{%- filter indent(6, True) -%} +{%- include "directives/Block.jinja2" %} +{%- endfilter %} + var[N_SCAL_VAR + i_{{ inport }}] = 0; // reset the value + } +{%- endfor %} + + /** + * Begin NESTML generated code for the onCondition block(s) + **/ +{% if neuron.get_on_condition_blocks() %} +{%- for block in neuron.get_on_condition_blocks() %} + if ({{ printer.print(block.get_cond_expr()) }}) + { +{%- set ast = block.get_block() %} +{%- if ast.print_comment('*') | length > 1 %} +/* + {{ast.print_comment('*')}} + */ +{%- endif %} +{%- filter indent(6) %} +{%- include "directives/Block.jinja2" %} +{%- endfilter %} + } +{%- endfor %} {%- endif %} } @@ -290,7 +325,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, group_param_ = new float[N_GROUP_PARAM]; {%- endif %} -{%- if not uses_numeric_solver %} +{%- if uses_analytic_solver %} AllocParamArr(); AllocVarArr(); {%- endif %} @@ -337,7 +372,8 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, {%- if uses_numeric_solver %} {# TODO: automatically determine "PSConInit_E" #} - port_weight_arr_ = GetParamArr() + GetScalParamIdx("PSConInit_E"); + # port_weight_arr_ = GetParamArr() + GetScalParamIdx("PSConInit_E"); + port_weight_arr_ = GetParamArr() port_weight_arr_step_ = n_param_; port_weight_port_step_ = 1; {%- else %} @@ -361,7 +397,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/, return 0; } -{%- if not uses_numeric_solver %} +{%- if uses_analytic_solver %} int {{ neuronName }}::Free() { FreeVarArr(); diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index cec41a5b7..d4afcde0b 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -2569,6 +2569,13 @@ def initial_value_or_zero(cls, astnode: ASTModel, var): return "0" + @classmethod + def print_spike_update_expressions(cls, neuron:ASTModel): + import pdb; + for port, update_exp in neuron.spike_updates.items(): + pdb.set_trace() + print(f"Update expression for {port} is {update_exp}") + @classmethod def get_first_spike_port_from_spike_updates(cls, neuron: ASTModel) -> ASTVariable: # Get the first variable in the sorted spike update expressions list @@ -2586,11 +2593,6 @@ def get_first_excitatory_port(cls, neuron: ASTModel) -> str: # There is no port marked excitatory, return the first port name return neuron.get_spike_input_ports()[0].get_symbol_name() - # @classmethod - # def get_port_qualifier_by_port_name(cls, neuron: ASTModel, port_name: str): - # for port in neuron.get_input_blocks()[0].get_input_ports(): - # if port.get - @classmethod def get_exc_spike_variable(cls, neuron: ASTModel) -> ASTVariable: for block in neuron.get_on_receive_blocks():