Skip to content

Commit

Permalink
add explicit output parameters to spiking output port
Browse files Browse the repository at this point in the history
  • Loading branch information
C.A.P. Linssen committed Oct 17, 2024
1 parent 7ea96ec commit 8953256
Show file tree
Hide file tree
Showing 24 changed files with 1,230 additions and 878 deletions.
2 changes: 1 addition & 1 deletion models/synapses/neuromodulated_stdp_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ model neuromodulated_stdp_synapse:
mod_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(mod_spikes):
n += 1. / tau_n
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/noisy_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ model noisy_synapse:
pre_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(pre_spikes):
# temporary variable for the "weight" that will be transmitted
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/static_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ model static_synapse:
pre_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(pre_spikes):
emit_spike(w, d)
2 changes: 1 addition & 1 deletion models/synapses/stdp_nn_pre_centered_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ model stdp_nn_pre_centered_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace = 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_nn_restr_symm_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ model stdp_nn_restr_symm_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace = 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_nn_symm_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ model stdp_nn_symm_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace = 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ model stdp_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
post_trace += 1
Expand Down
2 changes: 1 addition & 1 deletion models/synapses/stdp_triplet_synapse.nestml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ model stdp_triplet_synapse:
post_spikes <- spike

output:
spike
spike(weight real, delay ms)

onReceive(post_spikes):
# potentiate synapse
Expand Down
51 changes: 43 additions & 8 deletions pynestml/cocos/co_co_output_port_defined_if_emit_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from pynestml.cocos.co_co import CoCo
from pynestml.meta_model.ast_function_call import ASTFunctionCall
from pynestml.meta_model.ast_model import ASTModel
from pynestml.symbols.predefined_functions import PredefinedFunctions
from pynestml.utils.ast_utils import ASTUtils
from pynestml.utils.logger import Logger, LoggingLevel
from pynestml.utils.messages import Messages
from pynestml.visitors.ast_visitor import ASTVisitor
Expand Down Expand Up @@ -60,22 +62,55 @@ def visit_function_call(self, node: ASTFunctionCall):
"""
assert self.neuron is not None
func_name = node.get_name()
if func_name == 'emit_spike':
if func_name == PredefinedFunctions.EMIT_SPIKE:
output_blocks = self.neuron.get_output_blocks()
if not output_blocks:

# exactly one output block should be defined
if len(output_blocks) == 0:
code, message = Messages.get_block_not_defined_correctly('output', missing=True)
Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR,
code=code, message=message)
return

spike_output_exists = False
for output_block in output_blocks:
if output_block.is_spike():
spike_output_exists = True
break
if len(output_blocks) > 1:
code, message = Messages.get_block_not_defined_correctly('output', missing=False)
Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR,
code=code, message=message)
return

if not spike_output_exists:
assert len(output_blocks) == 1

if not output_blocks[0].is_spike():
code, message = Messages.get_emit_spike_function_but_no_output_port()
Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
error_position=node.get_source_position())
return

# check types
if len(node.get_args()) != len(output_blocks[0].get_attributes()):
code, message = Messages.get_output_port_type_differs()
Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
error_position=node.get_source_position())
return

for emit_spike_arg, output_block_attr in zip(node.get_args(), output_blocks[0].get_attributes()):

emit_spike_arg_type_sym = emit_spike_arg.type
output_block_attr_type_sym = output_block_attr.get_data_type().get_type_symbol()

if emit_spike_arg_type_sym.equals(output_block_attr_type_sym):
continue

if emit_spike_arg_type_sym.is_castable_to(output_block_attr_type_sym):
# types are not equal, but castable
code, message = Messages.get_implicit_cast_rhs_to_lhs(output_block_attr_type_sym.print_symbol(),
emit_spike_arg_type_sym.print_symbol())
Logger.log_message(error_position=node.get_source_position(),
code=code, message=message, log_level=LoggingLevel.WARNING)
continue
else:
# types are not equal and not castable
code, message = Messages.get_output_port_type_differs()
Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR,
error_position=node.get_source_position())
return
8 changes: 8 additions & 0 deletions pynestml/codegeneration/printers/nestml_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,14 @@ def print_output_block(self, node: ASTOutputBlock) -> str:
ret += print_n_spaces(self.indent) + "output:\n"
ret += print_n_spaces(self.indent + 4)
ret += "spike" if node.is_spike() else "continuous"
if node.get_attributes():
ret += "("
for i, attr in enumerate(node.get_attributes()):
ret += self.print(attr)
if i < len(node.get_attributes()) - 1:
ret += ", "

ret += ")"
ret += print_sl_comment(node.in_comment)
ret += "\n"
return ret
Expand Down
Loading

0 comments on commit 8953256

Please sign in to comment.