Skip to content

Commit

Permalink
Modify templates for models with a numeric solver
Browse files Browse the repository at this point in the history
  • Loading branch information
pnbabu committed Oct 17, 2024
1 parent 5f6653a commit f281427
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 9 deletions.
1 change: 1 addition & 0 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ def analyse_transform_neurons(self, neurons: List[ASTModel]) -> None:
Logger.log_message(None, code, message, None, LoggingLevel.INFO)
spike_updates, post_spike_updates, equations_with_delay_vars, equations_with_vector_vars = self.analyse_neuron(neuron)
neuron.spike_updates = spike_updates
ASTUtils.print_spike_update_expressions(neuron)
neuron.post_spike_updates = post_spike_updates
neuron.equations_with_delay_vars = equations_with_delay_vars
neuron.equations_with_vector_vars = equations_with_vector_vars
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def _print_function_call_format_string(self, function_call: ASTFunctionCall) ->

if function_name == PredefinedFunctions.TIME_RESOLUTION:
# context dependent; we assume the template contains the necessary definitions
return 'h'
return 'NESTGPUTimeResolution'

if function_name == PredefinedFunctions.TIME_STEPS:
return '(int)round({!s}/NESTGPUTimeResolution)'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,41 @@ __device__
{%- include "directives/Block.jinja2" %}
{%- endfor %}
{%- endfilter %}
{%- endif %}

/**
* Begin NESTML generated code for the onReceive block(s)
**/
{% for blk in neuron.get_on_receive_blocks() %}
{%- set inport = blk.get_port_name() %}
if (var[N_SCAL_VAR + i_{{ inport }}])
{
{%- set ast = blk.get_block() %}
{%- filter indent(6, True) -%}
{%- include "directives/Block.jinja2" %}
{%- endfilter %}
var[N_SCAL_VAR + i_{{ inport }}] = 0; // reset the value
}
{%- endfor %}

/**
* Begin NESTML generated code for the onCondition block(s)
**/
{% if neuron.get_on_condition_blocks() %}
{%- for block in neuron.get_on_condition_blocks() %}
if ({{ printer.print(block.get_cond_expr()) }})
{
{%- set ast = block.get_block() %}
{%- if ast.print_comment('*') | length > 1 %}
/*
{{ast.print_comment('*')}}
*/
{%- endif %}
{%- filter indent(6) %}
{%- include "directives/Block.jinja2" %}
{%- endfilter %}
}
{%- endfor %}
{%- endif %}
}

Expand Down Expand Up @@ -290,7 +325,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
group_param_ = new float[N_GROUP_PARAM];
{%- endif %}

{%- if not uses_numeric_solver %}
{%- if uses_analytic_solver %}
AllocParamArr();
AllocVarArr();
{%- endif %}
Expand Down Expand Up @@ -337,7 +372,8 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,

{%- if uses_numeric_solver %}
{# TODO: automatically determine "PSConInit_E" #}
port_weight_arr_ = GetParamArr() + GetScalParamIdx("PSConInit_E");
# port_weight_arr_ = GetParamArr() + GetScalParamIdx("PSConInit_E");
port_weight_arr_ = GetParamArr()
port_weight_arr_step_ = n_param_;
port_weight_port_step_ = 1;
{%- else %}
Expand All @@ -361,7 +397,7 @@ int {{ neuronName }}::Init(int i_node_0, int n_node, int /*n_port*/,
return 0;
}

{%- if not uses_numeric_solver %}
{%- if uses_analytic_solver %}
int {{ neuronName }}::Free()
{
FreeVarArr();
Expand Down
12 changes: 7 additions & 5 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2569,6 +2569,13 @@ def initial_value_or_zero(cls, astnode: ASTModel, var):
return "0"


@classmethod
def print_spike_update_expressions(cls, neuron:ASTModel):
import pdb;
for port, update_exp in neuron.spike_updates.items():
pdb.set_trace()
print(f"Update expression for {port} is {update_exp}")

@classmethod
def get_first_spike_port_from_spike_updates(cls, neuron: ASTModel) -> ASTVariable:
# Get the first variable in the sorted spike update expressions list
Expand All @@ -2586,11 +2593,6 @@ def get_first_excitatory_port(cls, neuron: ASTModel) -> str:
# There is no port marked excitatory, return the first port name
return neuron.get_spike_input_ports()[0].get_symbol_name()

# @classmethod
# def get_port_qualifier_by_port_name(cls, neuron: ASTModel, port_name: str):
# for port in neuron.get_input_blocks()[0].get_input_ports():
# if port.get

@classmethod
def get_exc_spike_variable(cls, neuron: ASTModel) -> ASTVariable:
for block in neuron.get_on_receive_blocks():
Expand Down

0 comments on commit f281427

Please sign in to comment.