From 9fc69b46b47c0b7b9d8684ff312f774902d55561 Mon Sep 17 00:00:00 2001 From: clinssen Date: Thu, 10 Oct 2024 16:29:41 +0200 Subject: [PATCH] Run context condition checks only once, after model parsing (#1105) --- .../stdp_third_factor_active_dendrite.ipynb | 2 +- pynestml/cocos/co_co_all_variables_defined.py | 52 +- pynestml/cocos/co_co_function_unique.py | 1 + pynestml/cocos/co_co_illegal_expression.py | 6 +- .../co_co_no_kernels_except_in_convolve.py | 49 +- pynestml/cocos/co_co_v_comp_exists.py | 3 - pynestml/cocos/co_cos_manager.py | 13 +- pynestml/codegeneration/builder.py | 4 +- .../codegeneration/nest_code_generator.py | 10 +- .../nest_compartmental_code_generator.py | 11 +- .../python_standalone_code_generator.py | 1 - .../spinnaker_code_generator.py | 9 +- pynestml/frontend/frontend_configuration.py | 4 +- pynestml/frontend/pynestml_frontend.py | 110 +-- pynestml/meta_model/ast_model.py | 37 +- pynestml/symbols/symbol.py | 2 +- pynestml/symbols/type_symbol.py | 5 +- pynestml/symbols/unit_type_symbol.py | 25 +- .../synapse_post_neuron_transformer.py | 7 +- pynestml/utils/ast_utils.py | 17 +- pynestml/utils/logger.py | 75 +- pynestml/utils/mechs_info_enricher.py | 64 +- pynestml/utils/messages.py | 28 +- pynestml/utils/model_parser.py | 6 + pynestml/utils/type_caster.py | 35 +- ...ign_implicit_conversion_factors_visitor.py | 326 ++++++++ pynestml/visitors/ast_builder_visitor.py | 14 +- .../visitors/ast_function_call_visitor.py | 1 - pynestml/visitors/ast_symbol_table_visitor.py | 13 +- tests/cocos_test.py | 698 ------------------ tests/function_parameter_templating_test.py | 57 -- tests/nest_compartmental_tests/test__cocos.py | 92 ++- .../nest_delay_based_variables_test.py | 28 +- tests/nest_tests/non_linear_dendrite_test.py | 2 - .../integrate_odes_test_params.nestml | 1 - .../integrate_odes_test_params2.nestml | 10 + tests/nest_tests/test_integrate_odes.py | 30 +- tests/test_cocos.py | 403 ++++++++++ tests/test_function_parameter_templating.py | 36 + tests/test_unit_system.py | 164 ++++ tests/unit_system_test.py | 177 ----- 41 files changed, 1348 insertions(+), 1280 deletions(-) create mode 100644 pynestml/visitors/assign_implicit_conversion_factors_visitor.py delete mode 100644 tests/cocos_test.py delete mode 100644 tests/function_parameter_templating_test.py create mode 100644 tests/nest_tests/resources/integrate_odes_test_params2.nestml create mode 100644 tests/test_cocos.py create mode 100644 tests/test_function_parameter_templating.py create mode 100644 tests/test_unit_system.py delete mode 100644 tests/unit_system_test.py diff --git a/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb b/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb index 147f5b4a1..3922d0ac8 100644 --- a/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb +++ b/doc/tutorials/stdp_third_factor_active_dendrite/stdp_third_factor_active_dendrite.ipynb @@ -1347,7 +1347,7 @@ " NESTCodeGeneratorUtils.generate_code_for(nestml_neuron_model,\n", " nestml_synapse_model,\n", " codegen_opts=codegen_opts,\n", - " logging_level=\"INFO\") # try \"INFO\" or \"DEBUG\" for more debug information" + " logging_level=\"WARNING\") # try \"INFO\" or \"DEBUG\" for more debug information" ] }, { diff --git a/pynestml/cocos/co_co_all_variables_defined.py b/pynestml/cocos/co_co_all_variables_defined.py index e41b0727e..3ec7f16c4 100644 --- a/pynestml/cocos/co_co_all_variables_defined.py +++ b/pynestml/cocos/co_co_all_variables_defined.py @@ -41,11 +41,10 @@ class CoCoAllVariablesDefined(CoCo): """ @classmethod - def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False): + def check_co_co(cls, node: ASTModel): """ Checks if this coco applies for the handed over neuron. Models which contain undefined variables are not correct. :param node: a single neuron instance. - :param after_ast_rewrite: indicates whether this coco is checked after the code generator has done rewriting of the abstract syntax tree. If True, checks are not as rigorous. Use False where possible. """ # for each variable in all expressions, check if the variable has been defined previously expression_collector_visitor = ASTExpressionCollectorVisitor() @@ -62,31 +61,23 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False): # test if the symbol has been defined at least if symbol is None: - if after_ast_rewrite: # after ODE-toolbox transformations, convolutions are replaced by state variables, so cannot perform this check properly - symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE) - if symbol2 is not None: - # an inline expression defining this variable name (ignoring differential order) exists - if "__X__" in str(symbol2): # if this variable was the result of a convolution... + # for inline expressions, also allow derivatives of that kernel to appear + inline_expr_names = [] + inline_exprs = [] + for equations_block in node.get_equations_blocks(): + inline_expr_names.extend([inline_expr.variable_name for inline_expr in equations_block.get_inline_expressions()]) + inline_exprs.extend(equations_block.get_inline_expressions()) + + if var.get_name() in inline_expr_names: + inline_expr_idx = inline_expr_names.index(var.get_name()) + inline_expr = inline_exprs[inline_expr_idx] + from pynestml.utils.ast_utils import ASTUtils + if ASTUtils.inline_aliases_convolution(inline_expr): + symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE) + if symbol2 is not None: + # actually, no problem detected, skip error + # XXX: TODO: check that differential order is less than or equal to that of the kernel continue - else: - # for kernels, also allow derivatives of that kernel to appear - - inline_expr_names = [] - inline_exprs = [] - for equations_block in node.get_equations_blocks(): - inline_expr_names.extend([inline_expr.variable_name for inline_expr in equations_block.get_inline_expressions()]) - inline_exprs.extend(equations_block.get_inline_expressions()) - - if var.get_name() in inline_expr_names: - inline_expr_idx = inline_expr_names.index(var.get_name()) - inline_expr = inline_exprs[inline_expr_idx] - from pynestml.utils.ast_utils import ASTUtils - if ASTUtils.inline_aliases_convolution(inline_expr): - symbol2 = node.get_scope().resolve_to_symbol(var.get_name(), SymbolKind.VARIABLE) - if symbol2 is not None: - # actually, no problem detected, skip error - # XXX: TODO: check that differential order is less than or equal to that of the kernel - continue # check if this symbol is actually a type, e.g. "mV" in the expression "(1 + 2) * mV" symbol2 = var.get_scope().resolve_to_symbol(var.get_complete_name(), SymbolKind.TYPE) @@ -106,9 +97,14 @@ def check_co_co(cls, node: ASTModel, after_ast_rewrite: bool = False): # in this case its ok if it is recursive or defined later on continue + if symbol.is_predefined: + continue + + if symbol.block_type == BlockType.LOCAL and symbol.get_referenced_object().get_source_position().before(var.get_source_position()): + continue + # check if it has been defined before usage, except for predefined symbols, input ports and variables added by the AST transformation functions - if (not symbol.is_predefined) \ - and symbol.block_type != BlockType.INPUT \ + if symbol.block_type != BlockType.INPUT \ and not symbol.get_referenced_object().get_source_position().is_added_source_position(): # except for parameters, those can be defined after if ((not symbol.get_referenced_object().get_source_position().before(var.get_source_position())) diff --git a/pynestml/cocos/co_co_function_unique.py b/pynestml/cocos/co_co_function_unique.py index 15643c0ad..bf0f2be60 100644 --- a/pynestml/cocos/co_co_function_unique.py +++ b/pynestml/cocos/co_co_function_unique.py @@ -65,4 +65,5 @@ def check_co_co(cls, model: ASTModel): log_level=LoggingLevel.ERROR, message=message, code=code) checked.append(funcA) + checked_funcs_names.append(func.get_name()) diff --git a/pynestml/cocos/co_co_illegal_expression.py b/pynestml/cocos/co_co_illegal_expression.py index b78396e3b..c362d0dc5 100644 --- a/pynestml/cocos/co_co_illegal_expression.py +++ b/pynestml/cocos/co_co_illegal_expression.py @@ -18,13 +18,13 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from pynestml.meta_model.ast_inline_expression import ASTInlineExpression -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.meta_model.ast_declaration import ASTDeclaration from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.symbols.error_type_symbol import ErrorTypeSymbol from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.utils.ast_source_location import ASTSourceLocation from pynestml.utils.logger import LoggingLevel, Logger from pynestml.utils.logging_helper import LoggingHelper from pynestml.utils.messages import Messages diff --git a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py index 18b862292..e318ae566 100644 --- a/pynestml/cocos/co_co_no_kernels_except_in_convolve.py +++ b/pynestml/cocos/co_co_no_kernels_except_in_convolve.py @@ -22,11 +22,14 @@ from typing import List from pynestml.cocos.co_co import CoCo +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_external_variable import ASTExternalVariable from pynestml.meta_model.ast_function_call import ASTFunctionCall from pynestml.meta_model.ast_kernel import ASTKernel from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_node import ASTNode from pynestml.meta_model.ast_variable import ASTVariable +from pynestml.symbols.predefined_functions import PredefinedFunctions from pynestml.symbols.symbol import SymbolKind from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages @@ -89,24 +92,44 @@ def visit_variable(self, node: ASTNode): if not (isinstance(node, ASTExternalVariable) and node.get_alternate_name()): code, message = Messages.get_no_variable_found(kernelName) Logger.log_message(node=self.__neuron_node, code=code, message=message, log_level=LoggingLevel.ERROR) + continue + if not symbol.is_kernel(): continue + if node.get_complete_name() == kernelName: - parent = node.get_parent() - if parent is not None: + parent = node + correct = False + while parent is not None and not isinstance(parent, ASTModel): + parent = parent.get_parent() + assert parent is not None + + if isinstance(parent, ASTDeclaration): + for lhs_var in parent.get_variables(): + if kernelName == lhs_var.get_complete_name(): + # kernel name appears on lhs of declaration, assume it is initial state + correct = True + parent = None # break out of outer loop + break + if isinstance(parent, ASTKernel): - continue - grandparent = parent.get_parent() - if grandparent is not None and isinstance(grandparent, ASTFunctionCall): - grandparent_func_name = grandparent.get_name() - if grandparent_func_name == 'convolve': - continue - code, message = Messages.get_kernel_outside_convolve(kernelName) - Logger.log_message(code=code, - message=message, - log_level=LoggingLevel.ERROR, - error_position=node.get_source_position()) + # kernel name is used inside kernel definition, e.g. for a node ``g``, it appears in ``kernel g'' = -1/tau**2 * g - 2/tau * g'`` + correct = True + break + + if isinstance(parent, ASTFunctionCall): + func_name = parent.get_name() + if func_name == PredefinedFunctions.CONVOLVE: + # kernel name is used inside convolve call + correct = True + + if not correct: + code, message = Messages.get_kernel_outside_convolve(kernelName) + Logger.log_message(code=code, + message=message, + log_level=LoggingLevel.ERROR, + error_position=node.get_source_position()) class KernelCollectingVisitor(ASTVisitor): diff --git a/pynestml/cocos/co_co_v_comp_exists.py b/pynestml/cocos/co_co_v_comp_exists.py index 4ef08c0ec..51308f2cc 100644 --- a/pynestml/cocos/co_co_v_comp_exists.py +++ b/pynestml/cocos/co_co_v_comp_exists.py @@ -43,9 +43,6 @@ def check_co_co(cls, neuron: ASTModel): Models which are supposed to be compartmental but do not contain state variable called v_comp are not correct. :param neuron: a single neuron instance. - :param after_ast_rewrite: indicates whether this coco is checked - after the code generator has done rewriting of the abstract syntax tree. - If True, checks are not as rigorous. Use False where possible. """ from pynestml.codegeneration.nest_compartmental_code_generator import NESTCompartmentalCodeGenerator diff --git a/pynestml/cocos/co_cos_manager.py b/pynestml/cocos/co_cos_manager.py index 7f1bbf244..c90ffa2b1 100644 --- a/pynestml/cocos/co_cos_manager.py +++ b/pynestml/cocos/co_cos_manager.py @@ -68,6 +68,7 @@ from pynestml.cocos.co_co_priorities_correctly_specified import CoCoPrioritiesCorrectlySpecified from pynestml.meta_model.ast_model import ASTModel from pynestml.frontend.frontend_configuration import FrontendConfiguration +from pynestml.utils.logger import Logger class CoCosManager: @@ -122,12 +123,12 @@ def check_state_variables_initialized(cls, model: ASTModel): CoCoStateVariablesInitialized.check_co_co(model) @classmethod - def check_variables_defined_before_usage(cls, model: ASTModel, after_ast_rewrite: bool) -> None: + def check_variables_defined_before_usage(cls, model: ASTModel) -> None: """ Checks that all variables are defined before being used. :param model: a single model. """ - CoCoAllVariablesDefined.check_co_co(model, after_ast_rewrite) + CoCoAllVariablesDefined.check_co_co(model) @classmethod def check_v_comp_requirement(cls, neuron: ASTModel): @@ -409,17 +410,19 @@ def check_co_co_nest_random_functions_legally_used(cls, model: ASTModel): CoCoNestRandomFunctionsLegallyUsed.check_co_co(model) @classmethod - def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bool = False): + def check_cocos(cls, model: ASTModel, after_ast_rewrite: bool = False): """ Checks all context conditions. :param model: a single model object. """ + Logger.set_current_node(model) + cls.check_each_block_defined_at_most_once(model) cls.check_function_defined(model) cls.check_variables_unique_in_scope(model) cls.check_inline_expression_not_assigned_to(model) cls.check_state_variables_initialized(model) - cls.check_variables_defined_before_usage(model, after_ast_rewrite) + cls.check_variables_defined_before_usage(model) if FrontendConfiguration.get_target_platform().upper() == 'NEST_COMPARTMENTAL': # XXX: TODO: refactor this out; define a ``cocos_from_target_name()`` in the frontend instead. cls.check_v_comp_requirement(model) @@ -459,3 +462,5 @@ def post_symbol_table_builder_checks(cls, model: ASTModel, after_ast_rewrite: bo cls.check_co_co_priorities_correctly_specified(model) cls.check_resolution_func_legally_used(model) cls.check_input_port_size_type(model) + + Logger.set_current_node(None) diff --git a/pynestml/codegeneration/builder.py b/pynestml/codegeneration/builder.py index 2e6757c1a..a9f98bf58 100644 --- a/pynestml/codegeneration/builder.py +++ b/pynestml/codegeneration/builder.py @@ -20,12 +20,12 @@ # along with NEST. If not, see . from __future__ import annotations -import subprocess -import os from typing import Any, Mapping, Optional from abc import ABCMeta, abstractmethod +import os +import subprocess from pynestml.exceptions.invalid_target_exception import InvalidTargetException from pynestml.frontend.frontend_configuration import FrontendConfiguration diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 66e0c9e13..9090d5e34 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -177,7 +177,7 @@ def run_nest_target_specific_cocos(self, neurons: Sequence[ASTModel], synapses: # Check if the random number functions are used in the right blocks CoCosManager.check_co_co_nest_random_functions_legally_used(model) - if Logger.has_errors(model): + if Logger.has_errors(model.name): raise Exception("Error(s) occurred during code generation") if self.get_option("neuron_synapse_pairs"): @@ -196,7 +196,7 @@ def run_nest_target_specific_cocos(self, neurons: Sequence[ASTModel], synapses: delay_variable = self.get_option("delay_variable")[synapse_name_stripped] CoCoNESTSynapseDelayNotAssignedTo.check_co_co(delay_variable, model) - if Logger.has_errors(model): + if Logger.has_errors(model.name): raise Exception("Error(s) occurred during code generation") def setup_printers(self): @@ -384,6 +384,9 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di if not used_in_eq: self.non_equations_state_variables[neuron.get_name()].append(var) + # cache state variables before symbol table update for the sake of delay variables + state_vars_before_update = neuron.get_state_symbols() + ASTUtils.remove_initial_values_for_kernels(neuron) kernels = ASTUtils.remove_kernel_definitions_from_equations_block(neuron) ASTUtils.update_initial_values_for_odes(neuron, [analytic_solver, numeric_solver]) @@ -398,7 +401,6 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di neuron = ASTUtils.add_declarations_to_internals( neuron, self.analytic_solver[neuron.get_name()]["propagators"]) - state_vars_before_update = neuron.get_state_symbols() self.update_symbol_table(neuron) # Update the delay parameter parameters after symbol table update @@ -908,8 +910,8 @@ def update_symbol_table(self, neuron) -> None: """ SymbolTable.delete_model_scope(neuron.get_name()) symbol_table_visitor = ASTSymbolTableVisitor() - symbol_table_visitor.after_ast_rewrite_ = True neuron.accept(symbol_table_visitor) + CoCosManager.check_cocos(neuron, after_ast_rewrite=True) SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) def get_spike_update_expressions(self, neuron: ASTModel, kernel_buffers, solver_dicts, delta_factors) -> Tuple[Dict[str, ASTAssignment], Dict[str, ASTAssignment]]: diff --git a/pynestml/codegeneration/nest_compartmental_code_generator.py b/pynestml/codegeneration/nest_compartmental_code_generator.py index 4711bc497..00f061775 100644 --- a/pynestml/codegeneration/nest_compartmental_code_generator.py +++ b/pynestml/codegeneration/nest_compartmental_code_generator.py @@ -18,14 +18,18 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -import shutil + from typing import Any, Dict, List, Mapping, Optional import datetime import os from jinja2 import TemplateRuntimeError + +from odetoolbox import analysis + import pynestml +from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.codegeneration.code_generator import CodeGenerator from pynestml.codegeneration.nest_assignments_helper import NestAssignmentsHelper from pynestml.codegeneration.nest_declarations_helper import NestDeclarationsHelper @@ -53,9 +57,9 @@ from pynestml.meta_model.ast_variable import ASTVariable from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.symbol import SymbolKind +from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer from pynestml.utils.ast_vector_parameter_setter_and_printer import ASTVectorParameterSetterAndPrinter from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory -from pynestml.transformers.inline_expression_expansion_transformer import InlineExpressionExpansionTransformer from pynestml.utils.mechanism_processing import MechanismProcessing from pynestml.utils.channel_processing import ChannelProcessing from pynestml.utils.concentration_processing import ConcentrationProcessing @@ -72,7 +76,6 @@ from pynestml.utils.synapse_processing import SynapseProcessing from pynestml.visitors.ast_random_number_generator_visitor import ASTRandomNumberGeneratorVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor -from odetoolbox import analysis class NESTCompartmentalCodeGenerator(CodeGenerator): @@ -740,8 +743,8 @@ def update_symbol_table(self, neuron, kernel_buffers): """ SymbolTable.delete_model_scope(neuron.get_name()) symbol_table_visitor = ASTSymbolTableVisitor() - symbol_table_visitor.after_ast_rewrite_ = True neuron.accept(symbol_table_visitor) + CoCosManager.check_cocos(neuron, after_ast_rewrite=True) SymbolTable.add_model_scope(neuron.get_name(), neuron.get_scope()) def _get_ast_variable(self, neuron, var_name) -> Optional[ASTVariable]: diff --git a/pynestml/codegeneration/python_standalone_code_generator.py b/pynestml/codegeneration/python_standalone_code_generator.py index f44123743..d6afaa095 100644 --- a/pynestml/codegeneration/python_standalone_code_generator.py +++ b/pynestml/codegeneration/python_standalone_code_generator.py @@ -111,7 +111,6 @@ def setup_printers(self): # GSL printers self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None) - print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer)) self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None) self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=PythonSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer, constant_printer=self._constant_printer, diff --git a/pynestml/codegeneration/spinnaker_code_generator.py b/pynestml/codegeneration/spinnaker_code_generator.py index 2a8fed7de..0fab611ed 100644 --- a/pynestml/codegeneration/spinnaker_code_generator.py +++ b/pynestml/codegeneration/spinnaker_code_generator.py @@ -137,7 +137,6 @@ def setup_printers(self): # GSL printers self._gsl_variable_printer = PythonSteppingFunctionVariablePrinter(None) - print("In Python code generator: created self._gsl_variable_printer = " + str(self._gsl_variable_printer)) self._gsl_function_call_printer = PythonSteppingFunctionFunctionCallPrinter(None) self._gsl_printer = PythonExpressionPrinter(simple_expression_printer=SpinnakerPythonSimpleExpressionPrinter( variable_printer=self._gsl_variable_printer, @@ -216,14 +215,8 @@ def generate_code(self, models: Sequence[ASTModel]) -> None: for model in models: cloned_model = model.clone() cloned_model.accept(ASTSymbolTableVisitor()) + CoCosManager.check_cocos(cloned_model) cloned_models.append(cloned_model) self.codegen_cpp.generate_code(cloned_models) - - cloned_models = [] - for model in models: - cloned_model = model.clone() - cloned_model.accept(ASTSymbolTableVisitor()) - cloned_models.append(cloned_model) - self.codegen_py.generate_code(cloned_models) diff --git a/pynestml/frontend/frontend_configuration.py b/pynestml/frontend/frontend_configuration.py index 173534c95..aae1fc29a 100644 --- a/pynestml/frontend/frontend_configuration.py +++ b/pynestml/frontend/frontend_configuration.py @@ -244,8 +244,8 @@ def handle_module_name(cls, module_name): @classmethod def handle_target_platform(cls, target_platform: Optional[str]): - if target_platform is None or target_platform.upper() == 'NONE': - target_platform = '' # make sure `target_platform` is always a string + if target_platform is None: + target_platform = "NONE" # make sure `target_platform` is always a string from pynestml.frontend.pynestml_frontend import get_known_targets diff --git a/pynestml/frontend/pynestml_frontend.py b/pynestml/frontend/pynestml_frontend.py index 59b9bb5f8..058fe8ca8 100644 --- a/pynestml/frontend/pynestml_frontend.py +++ b/pynestml/frontend/pynestml_frontend.py @@ -41,6 +41,8 @@ from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages from pynestml.utils.model_parser import ModelParser +from pynestml.visitors.ast_parent_visitor import ASTParentVisitor +from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor def get_known_targets(): @@ -131,10 +133,10 @@ def code_generator_from_target_name(target_name: str, options: Optional[Mapping[ return SpiNNakerCodeGenerator(options) if target_name.upper() == "NONE": - # dummy/null target: user requested to not generate any code + # dummy/null target: user requested to not generate any code (for instance, when just doing validation of a model) code, message = Messages.get_no_code_generated() Logger.log_message(None, code, message, None, LoggingLevel.INFO) - return CodeGenerator("", options) + return CodeGenerator(options) # cannot reach here due to earlier assert -- silence static checker warnings assert "Unknown code generator requested: " + target_name @@ -193,12 +195,17 @@ def generate_target(input_path: Union[str, Sequence[str]], target_platform: str, Enable development mode: code generation is attempted even for models that contain errors, and extra information is rendered in the generated code. codegen_opts : Optional[Mapping[str, Any]] A dictionary containing additional options for the target code generator. + + Return + ------ + errors_occurred + Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models. """ configure_front_end(input_path, target_platform, target_path, install_path, logging_level, module_name, store_log, suffix, dev, codegen_opts) - if not process() == 0: - raise Exception("Error(s) occurred while processing the model") + + return process() def configure_front_end(input_path: Union[str, Sequence[str]], target_platform: str, target_path=None, @@ -373,34 +380,36 @@ def generate_nest_compartmental_target(input_path: Union[str, Sequence[str]], ta def main() -> int: - """ + r""" Entry point for the command-line application. Returns ------- - The process exit code: 0 for success, > 0 for failure + exit_code + The process exit code: 0 for success, > 0 for failure """ try: FrontendConfiguration.parse_config(sys.argv[1:]) except InvalidPathException as e: print(e) + return 1 + # the default Python recursion limit is 1000, which might not be enough in practice when running an AST visitor on a deep tree, e.g. containing an automatically generated expression sys.setrecursionlimit(10000) + # after all argument have been collected, start the actual processing return int(process()) -def get_parsed_models(): +def get_parsed_models() -> List[ASTModel]: r""" Handle the parsing and validation of the NESTML files Returns ------- - models: Sequence[ASTModel] + models List of correctly parsed models - errors_occurred : bool - Flag indicating whether errors occurred during processing """ # init log dir create_report_dir() @@ -417,36 +426,25 @@ def get_parsed_models(): for nestml_file in nestml_files: parsed_unit = ModelParser.parse_file(nestml_file) - if parsed_unit is None: - # Parsing error in the NESTML model, return True - return [], True - - compilation_units.append(parsed_unit) - - if len(compilation_units) > 0: - # generate a list of all models - models: Sequence[ASTModel] = [] - for compilationUnit in compilation_units: - models.extend(compilationUnit.get_model_list()) + if parsed_unit: + compilation_units.append(parsed_unit) - # check that no models with duplicate names have been defined - CoCosManager.check_no_duplicate_compilation_unit_names(models) + # generate a list of all models + models: Sequence[ASTModel] = [] + for compilation_unit in compilation_units: + CoCosManager.check_model_names_unique(compilation_unit) + models.extend(compilation_unit.get_model_list()) - # now exclude those which are broken, i.e. have errors. - for model in models: - if Logger.has_errors(model): - code, message = Messages.get_model_contains_errors(model.get_name()) - Logger.log_message(node=model, code=code, message=message, - error_position=model.get_source_position(), - log_level=LoggingLevel.WARNING) - return [model], True + # check that no models with duplicate names have been defined + CoCosManager.check_no_duplicate_compilation_unit_names(models) - return models, False + return models def transform_models(transformers, models): for transformer in transformers: models = transformer.transform(models) + return models @@ -454,14 +452,14 @@ def generate_code(code_generators, models): code_generators.generate_code(models) -def process(): +def process() -> bool: r""" The main toolchain workflow entry point. For all models: parse, validate, transform, generate code and build. - Returns - ------- - errors_occurred : bool - Flag indicating whether errors occurred during processing + Return + ------ + errors_occurred + Flag indicating whether errors occurred during processing. False if processing was successful; True if errors occurred in any of the models. """ # initialise model transformers @@ -481,20 +479,42 @@ def process(): if opt_key in unused_opts_transformer.keys() and opt_key in unused_opts_codegen.keys() and opt_key in unused_opts_builder.keys(): raise CodeGeneratorOptionsException("The code generator option \"" + opt_key + "\" does not exist.") - models, errors_occurred = get_parsed_models() + models = get_parsed_models() + + # validation -- check cocos for models that do not have errors already + excluded_models = [] + for model in models: + if not Logger.has_errors(model.name): + CoCosManager.check_cocos(model) + + if Logger.has_errors(model.name): + code, message = Messages.get_model_contains_errors(model.get_name()) + Logger.log_message(node=model, code=code, message=message, + error_position=model.get_source_position(), + log_level=LoggingLevel.WARNING) + excluded_models.append(model) + + # exclude models that have errors + models = list(set(models) - set(excluded_models)) + + if len(models) == 0: + return True # there is no model code to generate, return error condition + + # transformation(s) + models = transform_models(transformers, models) - if not errors_occurred: - models = transform_models(transformers, models) - generate_code(code_generator, models) + # generate code + generate_code(code_generator, models) - # perform build - if _builder is not None: - _builder.build() + # perform build + if _builder is not None: + _builder.build() if FrontendConfiguration.store_log: store_log_to_file() - return errors_occurred + # return a boolean indicating whether errors occurred + return len(Logger.get_all_messages_of_level(LoggingLevel.ERROR)) > 0 def init_predefined(): diff --git a/pynestml/meta_model/ast_model.py b/pynestml/meta_model/ast_model.py index 834e56897..c4b7374bf 100644 --- a/pynestml/meta_model/ast_model.py +++ b/pynestml/meta_model/ast_model.py @@ -459,23 +459,27 @@ def add_to_internals_block(self, declaration: ASTDeclaration, index: int = -1) - Adds the handed over declaration the internals block :param declaration: a single declaration """ - assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now" from pynestml.utils.ast_utils import ASTUtils + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + + assert len(self.get_internals_blocks()) <= 1, "Only one internals block supported for now" + if not self.get_internals_blocks(): ASTUtils.create_internal_block(self) + n_declarations = len(self.get_internals_blocks()[0].get_declarations()) if n_declarations == 0: index = 0 else: index = 1 + (index % len(self.get_internals_blocks()[0].get_declarations())) + self.get_internals_blocks()[0].get_declarations().insert(index, declaration) declaration.update_scope(self.get_internals_blocks()[0].get_scope()) - from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor - from pynestml.visitors.ast_parent_visitor import ASTParentVisitor symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.INTERNALS) - declaration.accept(symtable_vistor) - self.get_internals_blocks()[0].accept(ASTParentVisitor()) + self.accept(ASTParentVisitor()) + self.accept(symtable_vistor) symtable_vistor.block_type_stack.pop() def add_to_state_block(self, declaration: ASTDeclaration) -> None: @@ -483,24 +487,26 @@ def add_to_state_block(self, declaration: ASTDeclaration) -> None: Adds the handed over declaration to an arbitrary state block. A state block will be created if none exists. :param declaration: a single declaration. """ - assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now" + from pynestml.symbols.symbol import SymbolKind from pynestml.utils.ast_utils import ASTUtils + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + + assert len(self.get_state_blocks()) <= 1, "Only one internals block supported for now" + if not self.get_state_blocks(): ASTUtils.create_state_block(self) + self.get_state_blocks()[0].get_declarations().append(declaration) declaration.update_scope(self.get_state_blocks()[0].get_scope()) - from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor - from pynestml.visitors.ast_parent_visitor import ASTParentVisitor symtable_vistor = ASTSymbolTableVisitor() symtable_vistor.block_type_stack.push(BlockType.STATE) - declaration.accept(symtable_vistor) - self.get_state_blocks()[0].accept(ASTParentVisitor()) + self.accept(ASTParentVisitor()) + self.accept(symtable_vistor) symtable_vistor.block_type_stack.pop() - from pynestml.symbols.symbol import SymbolKind - assert declaration.get_variables()[0].get_scope().resolve_to_symbol( - declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None - assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), - SymbolKind.VARIABLE) is not None + + assert declaration.get_variables()[0].get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None + assert declaration.get_scope().resolve_to_symbol(declaration.get_variables()[0].get_name(), SymbolKind.VARIABLE) is not None def print_comment(self, prefix: str = "") -> str: """ @@ -566,7 +572,6 @@ def get_spike_input_port_names(self) -> List[str]: """ Returns a list of all spike input ports defined in the model. """ - print("get_spike_input_port_names = " + str([port.get_symbol_name() for port in self.get_spike_input_ports()])) return [port.get_symbol_name() for port in self.get_spike_input_ports()] def get_continuous_input_ports(self) -> List[VariableSymbol]: diff --git a/pynestml/symbols/symbol.py b/pynestml/symbols/symbol.py index 1e294566b..c73435c6d 100644 --- a/pynestml/symbols/symbol.py +++ b/pynestml/symbols/symbol.py @@ -18,8 +18,8 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from abc import ABCMeta, abstractmethod +from abc import ABCMeta, abstractmethod from enum import Enum diff --git a/pynestml/symbols/type_symbol.py b/pynestml/symbols/type_symbol.py index 7047cdbca..a3eb28a12 100644 --- a/pynestml/symbols/type_symbol.py +++ b/pynestml/symbols/type_symbol.py @@ -18,11 +18,11 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . + from abc import ABCMeta, abstractmethod from pynestml.symbols.symbol import Symbol from pynestml.utils.logger import Logger, LoggingLevel -from pynestml.utils.messages import Messages class TypeSymbol(Symbol): @@ -198,6 +198,7 @@ def is_castable_to(self, _other_type): def binary_operation_not_defined_error(self, _operator, _other): from pynestml.symbols.error_type_symbol import ErrorTypeSymbol + from pynestml.utils.messages import Messages result = ErrorTypeSymbol() code, message = Messages.get_binary_operation_not_defined( lhs=self.print_nestml_type(), operator=_operator, rhs=_other.print_nestml_type()) @@ -208,6 +209,7 @@ def binary_operation_not_defined_error(self, _operator, _other): def unary_operation_not_defined_error(self, _operator): from pynestml.symbols.error_type_symbol import ErrorTypeSymbol result = ErrorTypeSymbol() + from pynestml.utils.messages import Messages code, message = Messages.get_unary_operation_not_defined(_operator, self.print_symbol()) Logger.log_message(code=code, message=message, error_position=self.referenced_object.get_source_position(), @@ -226,6 +228,7 @@ def inverse_of_unit(cls, other): return result def warn_implicit_cast_from_to(self, _from, _to): + from pynestml.utils.messages import Messages code, message = Messages.get_implicit_cast_rhs_to_lhs(_to.print_symbol(), _from.print_symbol()) Logger.log_message(code=code, message=message, error_position=self.get_referenced_object().get_source_position(), diff --git a/pynestml/symbols/unit_type_symbol.py b/pynestml/symbols/unit_type_symbol.py index 37c43b035..1f9977de0 100644 --- a/pynestml/symbols/unit_type_symbol.py +++ b/pynestml/symbols/unit_type_symbol.py @@ -19,6 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import Optional from pynestml.symbols.type_symbol import TypeSymbol from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages @@ -131,12 +132,12 @@ def __sub__(self, other): def add_or_sub_another_unit(self, other): if self.equals(other): return other - else: - return self.attempt_magnitude_cast(other) + + return self.attempt_magnitude_cast(other) def attempt_magnitude_cast(self, other): if self.differs_only_in_magnitude(other): - factor = UnitTypeSymbol.get_conversion_factor(self.astropy_unit, other.astropy_unit) + factor = UnitTypeSymbol.get_conversion_factor(other.astropy_unit, self.astropy_unit) other.referenced_object.set_implicit_conversion_factor(factor) code, message = Messages.get_implicit_magnitude_conversion(self, other, factor) Logger.log_message(code=code, message=message, @@ -144,18 +145,20 @@ def attempt_magnitude_cast(self, other): log_level=LoggingLevel.INFO) return self - else: - return self.binary_operation_not_defined_error('+/-', other) - # TODO: change order of parameters to conform with the from_to scheme. - # TODO: Also rename to reflect that, i.e. get_conversion_factor_from_to + return self.binary_operation_not_defined_error('+/-', other) + @classmethod - def get_conversion_factor(cls, to, _from): + def get_conversion_factor(cls, _from, to) -> Optional[float]: """ - Calculates the conversion factor from _convertee_unit to target_unit. - Behaviour is only well-defined if both units have the same physical base type + Calculates the conversion factor from _convertee_unit to target_unit. Behaviour is only well-defined if both units have the same physical base type. """ - factor = (_from / to).si.scale + try: + factor = (_from / to).si.scale + except BaseException: + # this can fail in case of e.g. trying to convert from "1/s" to "2/s" + return None + return factor def is_castable_to(self, _other_type): diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index 68cc70a62..a06260824 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -23,6 +23,7 @@ from typing import Any, Sequence, Mapping, Optional, Union +from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.meta_model.ast_assignment import ASTAssignment from pynestml.meta_model.ast_equations_block import ASTEquationsBlock @@ -563,11 +564,6 @@ def mark_post_port(_expr=None): # replace occurrences of the variables in expressions in the original synapse with calls to the corresponding neuron getters # - # make sure the moved symbols can be resolved in the scope of the neuron (that's where ``ASTExternalVariable._altscope`` will be pointing to) - ast_symbol_table_visitor = ASTSymbolTableVisitor() - ast_symbol_table_visitor.after_ast_rewrite_ = True - new_neuron.accept(ast_symbol_table_visitor) - Logger.log_message( None, -1, "In synapse: replacing variables with suffixed external variable references", None, LoggingLevel.INFO) for state_var in syn_to_neuron_state_vars: @@ -609,7 +605,6 @@ def mark_post_port(_expr=None): new_neuron.accept(ASTParentVisitor()) new_synapse.accept(ASTParentVisitor()) ast_symbol_table_visitor = ASTSymbolTableVisitor() - ast_symbol_table_visitor.after_ast_rewrite_ = True new_neuron.accept(ast_symbol_table_visitor) new_synapse.accept(ast_symbol_table_visitor) diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index d3d6f6ef5..a3983694d 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -28,7 +28,6 @@ from pynestml.codegeneration.printers.ast_printer import ASTPrinter from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter -from pynestml.codegeneration.printers.nestml_printer import NESTMLPrinter from pynestml.frontend.frontend_configuration import FrontendConfiguration from pynestml.generated.PyNestMLLexer import PyNestMLLexer from pynestml.meta_model.ast_assignment import ASTAssignment @@ -66,7 +65,6 @@ from pynestml.utils.messages import Messages from pynestml.utils.string_utils import removesuffix from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor -from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_visitor import ASTVisitor @@ -1766,10 +1764,12 @@ def remove_initial_values_for_kernels(cls, model: ASTModel) -> None: @classmethod def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict]) -> None: """ - Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model - before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox. + Update initial values for original ODE declarations (e.g. V_m', g_ahp'') that are present in the model before ODE-toolbox processing, with the formatted variable names and initial values returned by ODE-toolbox. """ from pynestml.utils.model_parser import ModelParser + from pynestml.visitors.ast_parent_visitor import ASTParentVisitor + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor + assert len(model.get_equations_blocks()) == 1, "Only one equation block should be present" if not model.get_state_blocks(): @@ -1782,10 +1782,6 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict if cls.is_ode_variable(var.get_name(), model): assert cls.variable_in_solver(cls.to_ode_toolbox_processed_name(var_name), solver_dicts) - # replace the left-hand side variable name by the ode-toolbox format - var.set_name(cls.to_ode_toolbox_processed_name(var.get_complete_name())) - var.set_differential_order(0) - # replace the defining expression by the ode-toolbox result iv_expr = cls.get_initial_value_from_ode_toolbox_result( cls.to_ode_toolbox_processed_name(var_name), solver_dicts) @@ -1794,6 +1790,9 @@ def update_initial_values_for_odes(cls, model: ASTModel, solver_dicts: List[dict iv_expr.update_scope(state_block.get_scope()) iv_decl.set_expression(iv_expr) + model.accept(ASTParentVisitor()) + model.accept(ASTSymbolTableVisitor()) + @classmethod def integrate_odes_args_strs_from_function_call(cls, function_call: ASTFunctionCall): arg_names = [] @@ -2296,6 +2295,7 @@ def replace_convolve_calls_with_buffers_(cls, model: ASTModel, equations_block: r""" Replace all occurrences of `convolve(kernel[']^n, spike_input_port)` with the corresponding buffer variable, e.g. `g_E__X__spikes_exc[__d]^n` for a kernel named `g_E` and a spike input port named `spikes_exc`. """ + from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor def replace_function_call_through_var(_expr=None): if _expr.is_function_call() and _expr.get_function_call().get_name() == "convolve": @@ -2326,6 +2326,7 @@ def func(x): return replace_function_call_through_var(x) if isinstance(x, ASTSimpleExpression) else True equations_block.accept(ASTHigherOrderVisitor(func)) + equations_block.accept(ASTSymbolTableVisitor()) @classmethod def update_blocktype_for_common_parameters(cls, node): diff --git a/pynestml/utils/logger.py b/pynestml/utils/logger.py index 06e95b804..8404f1245 100644 --- a/pynestml/utils/logger.py +++ b/pynestml/utils/logger.py @@ -19,7 +19,7 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from typing import List, Mapping, Optional, Tuple +from typing import List, Mapping, Optional, Tuple, Union from collections import OrderedDict from enum import Enum @@ -75,6 +75,7 @@ class Logger: def init_logger(cls, logging_level: LoggingLevel): """ Initializes the logger. + :param logging_level: the logging level as required :type logging_level: LoggingLevel """ @@ -82,7 +83,6 @@ def init_logger(cls, logging_level: LoggingLevel): cls.curr_message = 0 cls.log = {} cls.log_frozen = False - return @classmethod def freeze_log(cls, do_freeze: bool = True): @@ -95,6 +95,7 @@ def freeze_log(cls, do_freeze: bool = True): def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]: """ Returns the overall log of messages. The structure of the log is: (NODE, LEVEL, MESSAGE) + :return: mapping from id to ASTNode, log level and message. """ return cls.log @@ -103,6 +104,7 @@ def get_log(cls) -> Mapping[int, Tuple[ASTNode, LoggingLevel, str]]: def set_log(cls, log, counter): """ Restores log from the 'log' variable + :param log: the log :param counter: the counter """ @@ -113,20 +115,19 @@ def set_log(cls, log, counter): def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: str = None, error_position: ASTSourceLocation = None, log_level: LoggingLevel = None): """ Logs the handed over message on the handed over node. If the current logging is appropriate, the message is also printed. + :param node: the node in which the error occurred :param code: a single error code - :type code: ErrorCode :param error_position: the position on which the error occurred. - :type error_position: SourcePosition :param message: a message. - :type message: str :param log_level: the corresponding log level. - :type log_level: LoggingLevel """ if cls.log_frozen: return + if cls.curr_message is None: cls.init_logger(LoggingLevel.INFO) + from pynestml.meta_model.ast_node import ASTNode from pynestml.utils.ast_source_location import ASTSourceLocation assert (node is None or isinstance(node, ASTNode)), \ @@ -134,15 +135,23 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st assert (error_position is None or isinstance(error_position, ASTSourceLocation)), \ '(PyNestML.Logger) Wrong type of error position provided (%s)!' % type(error_position) from pynestml.meta_model.ast_model import ASTModel + if isinstance(node, ASTModel): cls.log[cls.curr_message] = ( node.get_artifact_name(), node, log_level, code, error_position, message) - elif cls.current_node is not None: - cls.log[cls.curr_message] = (cls.current_node.get_artifact_name(), cls.current_node, + else: + if cls.current_node is not None: + artifact_name = cls.current_node.get_artifact_name() + else: + artifact_name = "" + + cls.log[cls.curr_message] = (artifact_name, cls.current_node, log_level, code, error_position, message) + cls.curr_message += 1 if cls.no_print: return + if cls.logging_level.value <= log_level.value: if isinstance(node, ASTInlineExpression): node_name = node.variable_name @@ -163,10 +172,9 @@ def log_message(cls, node: ASTNode = None, code: MessageCode = None, message: st def string_to_level(cls, string: str) -> LoggingLevel: """ Returns the logging level corresponding to the handed over string. If no such exits, returns None. + :param string: a single string representing the level. - :type string: str :return: a single logging level. - :rtype: LoggingLevel """ if string == 'DEBUG': return LoggingLevel.DEBUG @@ -183,7 +191,7 @@ def string_to_level(cls, string: str) -> LoggingLevel: if string == 'NO' or string == 'NONE': return LoggingLevel.NO - raise Exception('Tried to convert unknown string \"' + string + '\" to logging level') + raise Exception("Tried to convert unknown string '" + string + "' to logging level") @classmethod def level_to_string(cls, level: LoggingLevel) -> str: @@ -207,7 +215,7 @@ def level_to_string(cls, level: LoggingLevel) -> str: if level == LoggingLevel.NO: return 'NO' - raise Exception('Tried to convert unknown logging level \"' + str(level) + '\" to string') + raise Exception("Tried to convert unknown logging level '" + str(level) + "' to string") @classmethod def set_logging_level(cls, level: LoggingLevel) -> None: @@ -218,79 +226,89 @@ def set_logging_level(cls, level: LoggingLevel) -> None: """ if cls.log_frozen: return + cls.logging_level = level @classmethod def set_current_node(cls, node: Optional[ASTNode]) -> None: """ - Sets the handed over node as the currently processed one. This enables a retrieval of messages for a - specific node. - :param node: a single node instance + Sets the handed over node as the currently processed one. This enables a retrieval of messages for a specific node. + + :param node: a single node instance """ cls.current_node = node @classmethod - def get_all_messages_of_level_and_or_node(cls, node: ASTNode, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]: + def get_all_messages_of_level_and_or_node(cls, node: Union[ASTNode, str], level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]: """ - Returns all messages which have a certain logging level, or have been reported for a certain node, or - both. + Returns all messages which have a certain logging level, or have been reported for a certain node, or both. + :param node: a single node instance :param level: a logging level - :type level: LoggingLevel :return: a list of messages with their levels. - :rtype: list((str,Logging_Level) """ if level is None and node is None: return cls.get_log() + + if isinstance(node, str): + # search by artifact name + node_artifact_name = node + node = None + else: + # search by artifact class object + node_artifact_name = None + ret = list() for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values(): - if (level == logLevel if level is not None else True) and ( - node if node is not None else True) and ( - node.get_artifact_name() == artifactName if node is not None else True): + if (level == logLevel if level is not None else True) and (node if node is not None else True) and (node_artifact_name == artifactName if node is not None else True): ret.append((node, logLevel, message)) + return ret @classmethod def get_all_messages_of_level(cls, level: LoggingLevel) -> List[Tuple[ASTNode, LoggingLevel, str]]: """ Returns all messages which have a certain logging level. + :param level: a logging level - :type level: LoggingLevel :return: a list of messages with their levels. - :rtype: list((str,Logging_Level) """ if level is None: return cls.get_log() + ret = list() for (artifactName, node, logLevel, code, errorPosition, message) in cls.log.values(): if level == logLevel: ret.append((node, logLevel, message)) + return ret @classmethod def get_all_messages_of_node(cls, node: ASTNode) -> List[Tuple[ASTNode, LoggingLevel, str]]: """ Returns all messages which have been reported for a certain node. + :param node: a single node instance :return: a list of messages with their levels. - :rtype: list((str,Logging_Level) """ if node is None: return cls.get_log() + ret = list() for (artifactName, node_i, logLevel, code, errorPosition, message) in cls.log.values(): if (node_i == node if node is not None else True) and \ (node.get_artifact_name() == artifactName if node is not None else True): ret.append((node, logLevel, message)) + return ret @classmethod def has_errors(cls, node: ASTNode) -> bool: """ Indicates whether the handed over node, thus the corresponding model, has errors. + :param node: a single node instance. :return: True if errors detected, otherwise False - :rtype: bool """ return len(cls.get_all_messages_of_level_and_or_node(node, LoggingLevel.ERROR)) > 0 @@ -311,6 +329,7 @@ def get_json_format(cls) -> str: (node.get_name() if node is not None else 'GLOBAL') + '", ' + \ '"severity":"' \ + str(logLevel.name) + '", ' + if code is not None: ret += '"code":"' + \ code.name + \ @@ -323,10 +342,12 @@ def get_json_format(cls) -> str: '", ' + \ '"message":"' + str(message).replace('"', "'") + '"}' ret += ',' + if len(cls.log.keys()) == 0: parsed = json.loads('[]', object_pairs_hook=OrderedDict) else: ret = ret[:-1] # delete the last "," ret += ']' parsed = json.loads(ret, object_pairs_hook=OrderedDict) + return json.dumps(parsed, indent=2, sort_keys=False) diff --git a/pynestml/utils/mechs_info_enricher.py b/pynestml/utils/mechs_info_enricher.py index 456ece178..ea645a02c 100644 --- a/pynestml/utils/mechs_info_enricher.py +++ b/pynestml/utils/mechs_info_enricher.py @@ -22,13 +22,14 @@ from collections import defaultdict from pynestml.meta_model.ast_model import ASTModel +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.symbol import SymbolKind +from pynestml.utils.ast_vector_parameter_setter_and_printer_factory import ASTVectorParameterSetterAndPrinterFactory from pynestml.visitors.ast_parent_visitor import ASTParentVisitor from pynestml.visitors.ast_symbol_table_visitor import ASTSymbolTableVisitor from pynestml.utils.ast_utils import ASTUtils -from pynestml.visitors.ast_visitor import ASTVisitor from pynestml.utils.model_parser import ModelParser -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.symbol import SymbolKind +from pynestml.visitors.ast_visitor import ASTVisitor class MechsInfoEnricher: @@ -57,33 +58,6 @@ def transform_ode_solutions(cls, neuron, mechs_info): solution_transformed["states"] = defaultdict() solution_transformed["propagators"] = defaultdict() - for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): - variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, - SymbolKind.VARIABLE) - - expression = ModelParser.parse_expression(rhs_str) - # pretend that update expressions are in "equations" block, - # which should always be present, as synapses have been - # defined to get here - expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) - expression.accept(ASTSymbolTableVisitor()) - - update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ - variable_name] - update_expr_ast = ModelParser.parse_expression( - update_expr_str) - # pretend that update expressions are in "equations" block, - # which should always be present, as differential equations - # must have been defined to get here - update_expr_ast.update_scope( - neuron.get_equations_blocks()[0].get_scope()) - update_expr_ast.accept(ASTSymbolTableVisitor()) - - solution_transformed["states"][variable_name] = { - "ASTVariable": variable, - "init_expression": expression, - "update_expression": update_expr_ast, - } for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["propagators"].items(): prop_variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, SymbolKind.VARIABLE) @@ -118,6 +92,36 @@ def transform_ode_solutions(cls, neuron, mechs_info): PredefinedFunctions.TIME_RESOLUTION: mechanism_info["time_resolution_var"] = variable + for variable_name, rhs_str in ode_info["ode_toolbox_output"][ode_solution_index]["initial_values"].items(): + variable = neuron.get_equations_blocks()[0].get_scope().resolve_to_symbol(variable_name, + SymbolKind.VARIABLE) + + expression = ModelParser.parse_expression(rhs_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as synapses have been + # defined to get here + expression.update_scope(neuron.get_equations_blocks()[0].get_scope()) + expression.accept(ASTSymbolTableVisitor()) + + update_expr_str = ode_info["ode_toolbox_output"][ode_solution_index]["update_expressions"][ + variable_name] + update_expr_ast = ModelParser.parse_expression( + update_expr_str) + # pretend that update expressions are in "equations" block, + # which should always be present, as differential equations + # must have been defined to get here + update_expr_ast.update_scope( + neuron.get_scope()) + update_expr_ast.accept(ASTParentVisitor()) + update_expr_ast.accept(ASTSymbolTableVisitor()) + neuron.accept(ASTSymbolTableVisitor()) + + solution_transformed["states"][variable_name] = { + "ASTVariable": variable, + "init_expression": expression, + "update_expression": update_expr_ast, + } + mechanism_info["ODEs"][ode_var_name]["transformed_solutions"].append(solution_transformed) neuron.accept(ASTParentVisitor()) diff --git a/pynestml/utils/messages.py b/pynestml/utils/messages.py index 60ef0ce09..90efd06b7 100644 --- a/pynestml/utils/messages.py +++ b/pynestml/utils/messages.py @@ -18,11 +18,15 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from enum import Enum + +from __future__ import annotations + from typing import Tuple -from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from collections.abc import Iterable +from enum import Enum + +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression from pynestml.meta_model.ast_function import ASTFunction @@ -159,8 +163,8 @@ def get_input_path_not_found(cls, path): return MessageCode.INPUT_PATH_NOT_FOUND, message @classmethod - def get_unknown_target(cls, target): - message = 'Unknown target ("%s")' % (target) + def get_unknown_target_platform(cls, target: str): + message = "Unknown target: '" + target + "'" return MessageCode.UNKNOWN_TARGET, message @classmethod @@ -314,22 +318,13 @@ def get_different_type_rhs_lhs( return MessageCode.CAST_NOT_POSSIBLE, message @classmethod - def get_type_different_from_expected(cls, expected_type, got_type): + def get_type_different_from_expected(cls, expected_type, got_type) -> Tuple[MessageCode, str]: """ Returns a message indicating that the received type is different from the expected one. :param expected_type: the expected type - :type expected_type: TypeSymbol :param got_type: the actual type - :type got_type: type_symbol :return: a message - :rtype: (MessageCode,str) """ - from pynestml.symbols.type_symbol import TypeSymbol - assert (expected_type is not None and isinstance(expected_type, TypeSymbol)), \ - '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type( - expected_type) - assert (got_type is not None and isinstance(got_type, TypeSymbol)), \ - '(PyNestML.Utils.Message) Not a type symbol provided (%s)!' % type(got_type) message = 'Actual type different from expected. Expected: \'%s\', got: \'%s\'!' % ( expected_type.print_symbol(), got_type.print_symbol()) return MessageCode.TYPE_DIFFERENT_FROM_EXPECTED, message @@ -431,11 +426,10 @@ def get_module_generated(cls, path: str) -> Tuple[MessageCode, str]: return MessageCode.MODULE_SUCCESSFULLY_GENERATED, message @classmethod - def get_variable_used_before_declaration(cls, variable_name): + def get_variable_used_before_declaration(cls, variable_name: str): """ Returns a message indicating that a variable is used before declaration. :param variable_name: a variable name - :type variable_name: str :return: a message :rtype: (MessageCode,str) """ @@ -702,7 +696,7 @@ def get_model_redeclared(cls, name: str) -> Tuple[MessageCode, str]: '(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name) assert (name is not None and isinstance(name, str)), \ '(PyNestML.Utils.Message) Not a string provided (%s)!' % type(name) - message = 'model \'%s\' redeclared!' % name + message = 'Model \'%s\' redeclared!' % name return MessageCode.MODEL_REDECLARED, message @classmethod diff --git a/pynestml/utils/model_parser.py b/pynestml/utils/model_parser.py index 7fabf361e..d11618119 100644 --- a/pynestml/utils/model_parser.py +++ b/pynestml/utils/model_parser.py @@ -24,6 +24,7 @@ from antlr4 import CommonTokenStream, FileStream, InputStream from antlr4.error.ErrorStrategy import BailErrorStrategy, DefaultErrorStrategy from antlr4.error.ErrorListener import ConsoleErrorListener +from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.generated.PyNestMLLexer import PyNestMLLexer from pynestml.generated.PyNestMLParser import PyNestMLParser @@ -69,6 +70,7 @@ from pynestml.utils.error_listener import NestMLErrorListener from pynestml.utils.logger import Logger, LoggingLevel from pynestml.utils.messages import Messages +from pynestml.visitors.assign_implicit_conversion_factors_visitor import AssignImplicitConversionFactorsVisitor from pynestml.visitors.ast_builder_visitor import ASTBuilderVisitor from pynestml.visitors.ast_higher_order_visitor import ASTHigherOrderVisitor from pynestml.visitors.ast_parent_visitor import ASTParentVisitor @@ -142,10 +144,14 @@ def parse_file(cls, file_path=None): for model in ast.get_model_list(): model.accept(ASTSymbolTableVisitor()) SymbolTable.add_model_scope(model.get_name(), model.get_scope()) + Logger.set_current_node(model) + model.accept(AssignImplicitConversionFactorsVisitor()) + Logger.set_current_node(None) # store source paths for model in ast.get_model_list(): model.file_path = file_path + ast.file_path = file_path return ast diff --git a/pynestml/utils/type_caster.py b/pynestml/utils/type_caster.py index 34e4e6ccc..2f7827bad 100644 --- a/pynestml/utils/type_caster.py +++ b/pynestml/utils/type_caster.py @@ -28,12 +28,11 @@ class TypeCaster: @staticmethod def do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression): """ - determine conversion factor from rhs to lhs, register it with the relevant expression + Determine conversion factor from rhs to lhs, register it with the relevant expression """ _containing_expression.set_implicit_conversion_factor( - UnitTypeSymbol.get_conversion_factor(_lhs_type_symbol.astropy_unit, - _rhs_type_symbol.astropy_unit)) - _containing_expression.type = _lhs_type_symbol + UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit, + _lhs_type_symbol.astropy_unit)) code, message = Messages.get_implicit_magnitude_conversion(_lhs_type_symbol, _rhs_type_symbol, _containing_expression.get_implicit_conversion_factor()) Logger.log_message(code=code, message=message, @@ -41,22 +40,30 @@ def do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _cont log_level=LoggingLevel.INFO) @staticmethod - def try_to_recover_or_error(_lhs_type_symbol, _rhs_type_symbol, _containing_expression): + def try_to_recover_or_error(_lhs_type_symbol, _rhs_type_symbol, _containing_expression, set_implicit_conversion_factor_on_lhs=False): if _rhs_type_symbol.is_castable_to(_lhs_type_symbol): if isinstance(_lhs_type_symbol, UnitTypeSymbol) \ and isinstance(_rhs_type_symbol, UnitTypeSymbol): - conversion_factor = UnitTypeSymbol.get_conversion_factor( - _lhs_type_symbol.astropy_unit, _rhs_type_symbol.astropy_unit) - if not conversion_factor == 1.: + conversion_factor = UnitTypeSymbol.get_conversion_factor(_rhs_type_symbol.astropy_unit, _lhs_type_symbol.astropy_unit) + + if conversion_factor is None: + # error during conversion + code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol) + Logger.log_message(error_position=_containing_expression.get_source_position(), + code=code, message=message, log_level=LoggingLevel.ERROR) + return + + if set_implicit_conversion_factor_on_lhs and not conversion_factor == 1.: # the units are mutually convertible, but require a factor unequal to 1 (e.g. mV and A*Ohm) - TypeCaster.do_magnitude_conversion_rhs_to_lhs( - _rhs_type_symbol, _lhs_type_symbol, _containing_expression) + TypeCaster.do_magnitude_conversion_rhs_to_lhs(_rhs_type_symbol, _lhs_type_symbol, _containing_expression) + # the units are mutually convertible (e.g. V and A*Ohm) code, message = Messages.get_implicit_cast_rhs_to_lhs(_rhs_type_symbol.print_symbol(), _lhs_type_symbol.print_symbol()) Logger.log_message(error_position=_containing_expression.get_source_position(), code=code, message=message, log_level=LoggingLevel.INFO) - else: - code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol) - Logger.log_message(error_position=_containing_expression.get_source_position(), - code=code, message=message, log_level=LoggingLevel.ERROR) + return + + code, message = Messages.get_type_different_from_expected(_lhs_type_symbol, _rhs_type_symbol) + Logger.log_message(error_position=_containing_expression.get_source_position(), + code=code, message=message, log_level=LoggingLevel.ERROR) diff --git a/pynestml/visitors/assign_implicit_conversion_factors_visitor.py b/pynestml/visitors/assign_implicit_conversion_factors_visitor.py new file mode 100644 index 000000000..0fe4b93a7 --- /dev/null +++ b/pynestml/visitors/assign_implicit_conversion_factors_visitor.py @@ -0,0 +1,326 @@ +# -*- coding: utf-8 -*- +# +# assign_implicit_conversion_factors_visitor.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from typing import Sequence, Union + +from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt +from pynestml.meta_model.ast_declaration import ASTDeclaration +from pynestml.meta_model.ast_inline_expression import ASTInlineExpression +from pynestml.meta_model.ast_model import ASTModel +from pynestml.meta_model.ast_node import ASTNode +from pynestml.meta_model.ast_small_stmt import ASTSmallStmt +from pynestml.meta_model.ast_stmt import ASTStmt +from pynestml.symbols.error_type_symbol import ErrorTypeSymbol +from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.symbols.symbol import SymbolKind +from pynestml.symbols.template_type_symbol import TemplateTypeSymbol +from pynestml.symbols.variadic_type_symbol import VariadicTypeSymbol +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.ast_utils import ASTUtils +from pynestml.utils.logger import LoggingLevel, Logger +from pynestml.utils.logging_helper import LoggingHelper +from pynestml.utils.messages import Messages +from pynestml.utils.type_caster import TypeCaster +from pynestml.visitors.ast_visitor import ASTVisitor + + +class AssignImplicitConversionFactorsVisitor(ASTVisitor): + r""" + Assign implicit conversion factors in expressions. + """ + + def visit_model(self, model: ASTModel): + self.__assign_return_types(model) + + def visit_declaration(self, node): + """ + Visits a single declaration and asserts that type of lhs is equal to type of rhs. + :param node: a single declaration. + :type node: ASTDeclaration + """ + assert isinstance(node, ASTDeclaration) + if node.has_expression(): + if node.get_expression().get_source_position().equals(ASTSourceLocation.get_added_source_position()): + # no type checks are executed for added nodes, since we assume correctness + return + lhs_type = node.get_data_type().get_type_symbol() + rhs_type = node.get_expression().type + if isinstance(rhs_type, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + if self.__types_do_not_match(lhs_type, rhs_type): + TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression(), + set_implicit_conversion_factor_on_lhs=True) + + def visit_inline_expression(self, node): + """ + Visits a single inline expression and asserts that type of lhs is equal to type of rhs. + """ + assert isinstance(node, ASTInlineExpression) + lhs_type = node.get_data_type().get_type_symbol() + rhs_type = node.get_expression().type + if isinstance(rhs_type, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + + if self.__types_do_not_match(lhs_type, rhs_type): + TypeCaster.try_to_recover_or_error(lhs_type, rhs_type, node.get_expression(), + set_implicit_conversion_factor_on_lhs=True) + + def visit_assignment(self, node): + """ + Visits a single expression and assures that type(lhs) == type(rhs). + :param node: a single assignment. + :type node: ASTAssignment + """ + from pynestml.meta_model.ast_assignment import ASTAssignment + assert isinstance(node, ASTAssignment) + + if node.get_source_position().equals(ASTSourceLocation.get_added_source_position()): + # no type checks are executed for added nodes, since we assume correctness + return + if node.is_direct_assignment: # case a = b is simple + self.handle_simple_assignment(node) + else: + self.handle_compound_assignment(node) # e.g. a *= b + + def handle_compound_assignment(self, node): + rhs_expr = node.get_expression() + lhs_variable_symbol = node.get_variable().resolve_in_own_scope() + rhs_type_symbol = rhs_expr.type + + if lhs_variable_symbol is None: + code, message = Messages.get_equation_var_not_in_state_block(node.get_variable().get_complete_name()) + Logger.log_message(code=code, message=message, error_position=node.get_source_position(), + log_level=LoggingLevel.ERROR) + return + + if isinstance(rhs_type_symbol, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + + lhs_type_symbol = lhs_variable_symbol.get_type_symbol() + + if node.is_compound_product: + if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol * rhs_type_symbol, + node.get_expression(), + set_implicit_conversion_factor_on_lhs=True) + return + return + + if node.is_compound_quotient: + if self.__types_do_not_match(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_type_symbol, lhs_type_symbol / rhs_type_symbol, + node.get_expression(), + set_implicit_conversion_factor_on_lhs=True) + return + return + + assert node.is_compound_sum or node.is_compound_minus + if self.__types_do_not_match(lhs_type_symbol, rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_type_symbol, rhs_type_symbol, + node.get_expression(), + set_implicit_conversion_factor_on_lhs=True) + + @staticmethod + def __types_do_not_match(lhs_type_symbol, rhs_type_symbol): + if lhs_type_symbol is None: + return True + + return not lhs_type_symbol.equals(rhs_type_symbol) + + def handle_simple_assignment(self, node): + from pynestml.symbols.symbol import SymbolKind + lhs_variable_symbol = node.get_scope().resolve_to_symbol(node.get_variable().get_complete_name(), + SymbolKind.VARIABLE) + + rhs_type_symbol = node.get_expression().type + if isinstance(rhs_type_symbol, ErrorTypeSymbol): + LoggingHelper.drop_missing_type_error(node) + return + + if lhs_variable_symbol is not None and self.__types_do_not_match(lhs_variable_symbol.get_type_symbol(), + rhs_type_symbol): + TypeCaster.try_to_recover_or_error(lhs_variable_symbol.get_type_symbol(), rhs_type_symbol, + node.get_expression(), + set_implicit_conversion_factor_on_lhs=True) + + def visit_function_call(self, node): + """ + Check consistency for a single function call: check if the called function has been declared, whether the number and types of arguments correspond to the declaration, etc. + + :param node: a single function call. + :type node: ASTFunctionCall + """ + func_name = node.get_name() + + if func_name == 'convolve': + return + + symbol = node.get_scope().resolve_to_symbol(node.get_name(), SymbolKind.FUNCTION) + + if symbol is None and ASTUtils.is_function_delay_variable(node): + return + + # first check if the function has been declared + if symbol is None: + code, message = Messages.get_function_not_declared(node.get_name()) + Logger.log_message(error_position=node.get_source_position(), log_level=LoggingLevel.ERROR, + code=code, message=message) + return + + # check if the number of arguments is the same as in the symbol; accept anything for variadic types + is_variadic: bool = len(symbol.get_parameter_types()) == 1 and isinstance(symbol.get_parameter_types()[0], VariadicTypeSymbol) + if (not is_variadic) and len(node.get_args()) != len(symbol.get_parameter_types()): + code, message = Messages.get_wrong_number_of_args(str(node), len(symbol.get_parameter_types()), + len(node.get_args())) + Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR, + error_position=node.get_source_position()) + return + + # finally check if the call is correctly typed + expected_types = symbol.get_parameter_types() + actual_args = node.get_args() + actual_types = [arg.type for arg in actual_args] + for actual_arg, actual_type, expected_type in zip(actual_args, actual_types, expected_types): + if isinstance(actual_type, ErrorTypeSymbol): + code, message = Messages.get_type_could_not_be_derived(actual_arg) + Logger.log_message(code=code, message=message, log_level=LoggingLevel.ERROR, + error_position=actual_arg.get_source_position()) + return + + if isinstance(expected_type, VariadicTypeSymbol): + # variadic type symbol accepts anything + return + + if not actual_type.equals(expected_type) and not isinstance(expected_type, TemplateTypeSymbol): + TypeCaster.try_to_recover_or_error(expected_type, actual_type, actual_arg, + set_implicit_conversion_factor_on_lhs=True) + + def __assign_return_types(self, _node): + for userDefinedFunction in _node.get_functions(): + symbol = userDefinedFunction.get_scope().resolve_to_symbol(userDefinedFunction.get_name(), + SymbolKind.FUNCTION) + # first ensure that the block contains at least one statement + if symbol is not None and len(userDefinedFunction.get_block().get_stmts()) > 0: + # now check that the last statement is a return + self.__check_return_recursively(userDefinedFunction, + symbol.get_return_type(), + userDefinedFunction.get_block().get_stmts(), + False) + # now if it does not have a statement, but uses a return type, it is an error + elif symbol is not None and userDefinedFunction.has_return_type() and \ + not symbol.get_return_type().equals(PredefinedTypes.get_void_type()): + code, message = Messages.get_no_return() + Logger.log_message(node=_node, code=code, message=message, + error_position=userDefinedFunction.get_source_position(), + log_level=LoggingLevel.ERROR) + + def __check_return_recursively(self, processed_function, type_symbol=None, stmts=None, ret_defined: bool = False) -> None: + """ + For a handed over statement, it checks if the statement is a return statement and if it is typed according to the handed over type symbol. + :param type_symbol: a single type symbol + :type type_symbol: type_symbol + :param stmts: a list of statements, either simple or compound + :type stmts: list(ASTSmallStmt,ASTCompoundStmt) + :param ret_defined: indicates whether a ret has already been defined after this block of stmt, thus is not + necessary. Implies that the return has been defined in the higher level block + """ + # in order to ensure that in the sub-blocks, a return is not necessary, we check if the last one in this + # block is a return statement, thus it is not required to have a return in the sub-blocks, but optional + last_statement = stmts[len(stmts) - 1] + ret_defined = False or ret_defined + if (len(stmts) > 0 and isinstance(last_statement, ASTStmt) + and last_statement.is_small_stmt() + and last_statement.small_stmt.is_return_stmt()): + ret_defined = True + + # now check that returns are there if necessary and correctly typed + for c_stmt in stmts: + if c_stmt.is_small_stmt(): + stmt = c_stmt.small_stmt + else: + stmt = c_stmt.compound_stmt + + # if it is a small statement, check if it is a return statement + if isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt(): + # first check if the return is the last one in this block of statements + if stmts.index(c_stmt) != (len(stmts) - 1): + code, message = Messages.get_not_last_statement('Return') + Logger.log_message(error_position=stmt.get_source_position(), + code=code, message=message, + log_level=LoggingLevel.WARNING) + + # now check that it corresponds to the declared type + if stmt.get_return_stmt().has_expression() and type_symbol is PredefinedTypes.get_void_type(): + code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(), + stmt.get_return_stmt().get_expression().type) + Logger.log_message(error_position=stmt.get_source_position(), + message=message, code=code, log_level=LoggingLevel.ERROR) + + # if it is not void check if the type corresponds to the one stated + if not stmt.get_return_stmt().has_expression() and \ + not type_symbol.equals(PredefinedTypes.get_void_type()): + code, message = Messages.get_type_different_from_expected(PredefinedTypes.get_void_type(), + type_symbol) + Logger.log_message(error_position=stmt.get_source_position(), + message=message, code=code, log_level=LoggingLevel.ERROR) + + if stmt.get_return_stmt().has_expression(): + type_of_return = stmt.get_return_stmt().get_expression().type + if isinstance(type_of_return, ErrorTypeSymbol): + code, message = Messages.get_type_could_not_be_derived(processed_function.get_name()) + Logger.log_message(error_position=stmt.get_source_position(), + code=code, message=message, log_level=LoggingLevel.ERROR) + elif not type_of_return.equals(type_symbol): + TypeCaster.try_to_recover_or_error(type_symbol, type_of_return, + stmt.get_return_stmt().get_expression(), + set_implicit_conversion_factor_on_lhs=True) + elif isinstance(stmt, ASTCompoundStmt): + # otherwise it is a compound stmt, thus check recursively + if stmt.is_if_stmt(): + self.__check_return_recursively(processed_function, + type_symbol, + stmt.get_if_stmt().get_if_clause().get_block().get_stmts(), + ret_defined) + for else_ifs in stmt.get_if_stmt().get_elif_clauses(): + self.__check_return_recursively(processed_function, + type_symbol, else_ifs.get_block().get_stmts(), ret_defined) + if stmt.get_if_stmt().has_else_clause(): + self.__check_return_recursively(processed_function, + type_symbol, + stmt.get_if_stmt().get_else_clause().get_block().get_stmts(), + ret_defined) + elif stmt.is_while_stmt(): + self.__check_return_recursively(processed_function, + type_symbol, stmt.get_while_stmt().get_block().get_stmts(), + ret_defined) + elif stmt.is_for_stmt(): + self.__check_return_recursively(processed_function, + type_symbol, stmt.get_for_stmt().get_block().get_stmts(), + ret_defined) + # now, if a return statement has not been defined in the corresponding higher level block, we have to ensure that it is defined here + elif not ret_defined and stmts.index(c_stmt) == (len(stmts) - 1): + if not (isinstance(stmt, ASTSmallStmt) and stmt.is_return_stmt()): + code, message = Messages.get_no_return() + Logger.log_message(error_position=stmt.get_source_position(), log_level=LoggingLevel.ERROR, + code=code, message=message) diff --git a/pynestml/visitors/ast_builder_visitor.py b/pynestml/visitors/ast_builder_visitor.py index 0e766d530..bfc4dd902 100644 --- a/pynestml/visitors/ast_builder_visitor.py +++ b/pynestml/visitors/ast_builder_visitor.py @@ -52,16 +52,17 @@ def visitNestMLCompilationUnit(self, ctx): models = list() for child in ctx.model(): models.append(self.visit(child)) + # extract the name of the artifact from the context if hasattr(ctx.start.source[1], 'fileName'): artifact_name = ntpath.basename(ctx.start.source[1].fileName) else: artifact_name = 'parsed_from_string' + compilation_unit = ASTNodeFactory.create_ast_nestml_compilation_unit(list_of_models=models, source_position=create_source_pos(ctx), artifact_name=artifact_name) - # first ensure certain properties of the model - CoCosManager.check_model_names_unique(compilation_unit) + return compilation_unit # Visit a parse tree produced by PyNESTMLParser#datatype. @@ -387,15 +388,6 @@ def visitDeclaration(self, ctx): expression = self.visit(ctx.rhs) if ctx.rhs is not None else None invariant = self.visit(ctx.invariant) if ctx.invariant is not None else None - # print("Visiting variable \"" + str(str(ctx.NAME())) + "\"...") - # # check if this variable was decorated as homogeneous - # import pynestml.generated.PyNestMLLexer - # is_homogeneous = any([isinstance(ch, pynestml.generated.PyNestMLParser.PyNestMLParser.AnyDecoratorContext) \ - # and len(ch.getTokens(pynestml.generated.PyNestMLLexer.PyNestMLLexer.DECORATOR_HOMOGENEOUS)) > 0 \ - # for ch in ctx.parentCtx.children]) - # if is_homogeneous: - # print("\t----> is homogeneous") - declaration = ASTNodeFactory.create_ast_declaration(is_recordable=is_recordable, variables=variables, data_type=data_type, diff --git a/pynestml/visitors/ast_function_call_visitor.py b/pynestml/visitors/ast_function_call_visitor.py index 7d7bf75c4..e4ec8650e 100644 --- a/pynestml/visitors/ast_function_call_visitor.py +++ b/pynestml/visitors/ast_function_call_visitor.py @@ -94,7 +94,6 @@ def visit_simple_expression(self, node: ASTSimpleExpression) -> None: # return type of the convolve function is the type of the second parameter multiplied by the unit of time (s) if function_name == PredefinedFunctions.CONVOLVE: - # Deviations from the assumptions made here are handled in the convolveCoco buffer_parameter = node.get_function_call().get_args()[1] if buffer_parameter.get_variable() is not None: diff --git a/pynestml/visitors/ast_symbol_table_visitor.py b/pynestml/visitors/ast_symbol_table_visitor.py index 011182543..bc85d4cdd 100644 --- a/pynestml/visitors/ast_symbol_table_visitor.py +++ b/pynestml/visitors/ast_symbol_table_visitor.py @@ -19,7 +19,6 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . -from pynestml.cocos.co_cos_manager import CoCosManager from pynestml.meta_model.ast_model import ASTModel from pynestml.meta_model.ast_model_body import ASTModelBody from pynestml.meta_model.ast_namespace_decorator import ASTNamespaceDecorator @@ -53,7 +52,6 @@ def __init__(self): self.symbol_stack = Stack() self.scope_stack = Stack() self.block_type_stack = Stack() - self.after_ast_rewrite_ = False def visit_model(self, node: ASTModel) -> None: """ @@ -79,10 +77,6 @@ def visit_model(self, node: ASTModel) -> None: node.get_scope().add_symbol(types[symbol]) def endvisit_model(self, node: ASTModel): - # before following checks occur, we need to ensure several simple properties - CoCosManager.post_symbol_table_builder_checks( - node, after_ast_rewrite=self.after_ast_rewrite_) - # update the equations for equation_block in node.get_equations_blocks(): ASTUtils.assign_ode_to_variables(equation_block) @@ -287,8 +281,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None: namespace_decorators = {} for d in node.get_decorators(): if isinstance(d, ASTNamespaceDecorator): - namespace_decorators[str(d.get_namespace())] = str( - d.get_name()) + namespace_decorators[str(d.get_namespace())] = str(d.get_name()) else: decorators.append(d) @@ -296,6 +289,7 @@ def visit_declaration(self, node: ASTDeclaration) -> None: block_type = None if not self.block_type_stack.is_empty(): block_type = self.block_type_stack.top() + for var in node.get_variables(): # for all variables declared create a new symbol var.update_scope(node.get_scope()) @@ -324,11 +318,14 @@ def visit_declaration(self, node: ASTDeclaration) -> None: symbol.set_comment(node.get_comment()) node.get_scope().add_symbol(symbol) var.set_type_symbol(type_symbol) + # the data type node.get_data_type().update_scope(node.get_scope()) + # the rhs update if node.has_expression(): node.get_expression().update_scope(node.get_scope()) + # the invariant update if node.has_invariant(): node.get_invariant().update_scope(node.get_scope()) diff --git a/tests/cocos_test.py b/tests/cocos_test.py deleted file mode 100644 index f557faaf0..000000000 --- a/tests/cocos_test.py +++ /dev/null @@ -1,698 +0,0 @@ -# -*- coding: utf-8 -*- -# -# cocos_test.py -# -# This file is part of NEST. -# -# Copyright (C) 2004 The NEST Initiative -# -# NEST is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 2 of the License, or -# (at your option) any later version. -# -# NEST is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with NEST. If not, see . - -from __future__ import print_function - -import os -import unittest - -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables -from pynestml.utils.logger import LoggingLevel, Logger -from pynestml.utils.model_parser import ModelParser - - -class CoCosTest(unittest.TestCase): - - def setUp(self): - Logger.init_logger(LoggingLevel.INFO) - SymbolTable.initialize_symbol_table( - ASTSourceLocation( - start_line=0, - start_column=0, - end_line=0, - end_column=0)) - PredefinedUnits.register_units() - PredefinedTypes.register_types() - PredefinedVariables.register_variables() - PredefinedFunctions.register_functions() - - def test_invalid_element_defined_after_usage(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVariableDefinedAfterUsage.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_element_defined_after_usage(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableDefinedAfterUsage.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_element_in_same_line(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoElementInSameLine.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_element_in_same_line(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoElementInSameLine.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_integrate_odes_called_if_equations_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_integrate_odes_called_if_equations_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_element_not_defined_in_scope(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVariableNotDefined.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 5) - - def test_valid_element_not_defined_in_scope(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableNotDefined.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), - 0) - - def test_variable_with_same_name_as_unit(self): - Logger.set_logging_level(LoggingLevel.NO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableWithSameNameAsUnit.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), - 3) - - def test_invalid_variable_redeclaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVariableRedeclared.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_variable_redeclaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVariableRedeclared.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_each_block_unique(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoEachBlockUnique.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_each_block_unique(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoEachBlockUnique.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_function_unique_and_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoFunctionNotUnique.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5) - - def test_valid_function_unique_and_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoFunctionNotUnique.nestml')) - self.assertEqual( - len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_inline_expressions_have_rhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInlineExpressionHasNoRhs.nestml')) - assert model is None - - def test_valid_inline_expressions_have_rhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInlineExpressionHasNoRhs.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_inline_expression_has_several_lhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInlineExpressionWithSeveralLhs.nestml')) - assert model is None - - def test_valid_inline_expression_has_several_lhs(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInlineExpressionWithSeveralLhs.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_no_values_assigned_to_input_ports(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoValueAssignedToInputPort.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_no_values_assigned_to_input_ports(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoValueAssignedToInputPort.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_order_of_equations_correct(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoNoOrderOfEquations.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_order_of_equations_correct(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoNoOrderOfEquations.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_numerator_of_unit_one(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoUnitNumeratorNotOne.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 2) - - def test_valid_numerator_of_unit_one(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoUnitNumeratorNotOne.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_names_of_neurons_unique(self): - Logger.init_logger(LoggingLevel.INFO) - ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoMultipleNeuronsWithEqualName.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 1) - - def test_valid_names_of_neurons_unique(self): - Logger.init_logger(LoggingLevel.INFO) - ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoMultipleNeuronsWithEqualName.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)), 0) - - def test_invalid_no_nest_collision(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoNestNamespaceCollision.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_no_nest_collision(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoNestNamespaceCollision.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_redundant_input_port_keywords_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInputPortWithRedundantTypes.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_redundant_input_port_keywords_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInputPortWithRedundantTypes.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_parameters_assigned_only_in_parameters_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoParameterAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_parameters_assigned_only_in_parameters_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoParameterAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_inline_expressions_assigned_only_in_declaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoAssignmentToInlineExpression.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_invalid_internals_assigned_only_in_internals_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInternalAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_internals_assigned_only_in_internals_block(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInternalAssignedOutsideBlock.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_function_with_wrong_arg_number_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_function_with_wrong_arg_number_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_init_values_have_rhs_and_ode(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInitValuesWithoutOde.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2) - - def test_valid_init_values_have_rhs_and_ode(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInitValuesWithoutOde.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 2) - - def test_invalid_incorrect_return_stmt_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoIncorrectReturnStatement.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4) - - def test_valid_incorrect_return_stmt_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoIncorrectReturnStatement.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_ode_vars_outside_init_block_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOdeVarNotInInitialValues.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_ode_vars_outside_init_block_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoOdeVarNotInInitialValues.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_convolve_correctly_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoConvolveNotCorrectlyProvided.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 3) - - def test_valid_convolve_correctly_defined(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoConvolveNotCorrectlyProvided.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_in_non_vector_declaration_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorInNonVectorDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_vector_in_non_vector_declaration_detected(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorInNonVectorDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_parameter_declaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorParameterDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_vector_parameter_declaration(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorParameterDeclaration.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_parameter_type(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorParameterType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_vector_parameter_type(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorParameterType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_vector_parameter_size(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorDeclarationSize.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_vector_parameter_size(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorDeclarationSize.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_convolve_correctly_parameterized(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoConvolveNotCorrectlyParametrized.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_convolve_correctly_parameterized(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoConvolveNotCorrectlyParametrized.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 0) - - def test_invalid_invariant_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoInvariantNotBool.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_invariant_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoInvariantNotBool.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoIllegalExpression.nestml')) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 6) - - def test_valid_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoIllegalExpression.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_compound_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 5) - - def test_valid_compound_expression_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_ode_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOdeIncorrectlyTyped.nestml')) - self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)) > 0) - - def test_valid_ode_correctly_typed(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoOdeCorrectlyTyped.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_output_block_defined_if_emit_call(self): - """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined""" - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOutputPortDefinedIfEmitCall.nestml')) - self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)) > 0) - - def test_invalid_output_port_defined_if_emit_call(self): - """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined""" - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoOutputPortDefinedIfEmitCall-2.nestml')) - self.assertTrue(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)) > 0) - - def test_valid_output_port_defined_if_emit_call(self): - """test that no error is raised when the output block is missing, but not emit_spike() functions are called""" - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoOutputPortDefinedIfEmitCall.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_valid_coco_kernel_type(self): - """ - Test the functionality of CoCoKernelType. - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoKernelType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_coco_kernel_type(self): - """ - Test the functionality of CoCoKernelType. - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoKernelType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_invalid_coco_kernel_type_initial_values(self): - """ - Test the functionality of CoCoKernelType. - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoKernelTypeInitialValues.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 4) - - def test_valid_coco_state_variables_initialized(self): - """ - Test that the CoCo condition is applicable for all the variables in the state block initialized with a value - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoStateVariablesInitialized.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_coco_state_variables_initialized(self): - """ - Test that the CoCo condition is applicable for all the variables in the state block not initialized - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoStateVariablesInitialized.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_invalid_co_co_priorities_correctly_specified(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoPrioritiesCorrectlySpecified.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) - - def test_valid_co_co_priorities_correctly_specified(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoPrioritiesCorrectlySpecified.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_co_co_resolution_legally_used(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoResolutionLegallyUsed.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 2) - - def test_valid_co_co_resolution_legally_used(self): - """ - """ - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoResolutionLegallyUsed.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_valid_co_co_vector_input_port(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), - 'CoCoVectorInputPortSizeAndType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - - def test_invalid_co_co_vector_input_port(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), - 'CoCoVectorInputPortSizeAndType.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 1) diff --git a/tests/function_parameter_templating_test.py b/tests/function_parameter_templating_test.py deleted file mode 100644 index e3cb89e41..000000000 --- a/tests/function_parameter_templating_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -# -# function_parameter_templating_test.py -# -# This file is part of NEST. -# -# Copyright (C) 2004 The NEST Initiative -# -# NEST is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 2 of the License, or -# (at your option) any later version. -# -# NEST is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with NEST. If not, see . - -import os -import unittest - -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.utils.logger import Logger, LoggingLevel -from pynestml.utils.model_parser import ModelParser - -# minor setup steps required -SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) -PredefinedUnits.register_units() -PredefinedTypes.register_types() -PredefinedVariables.register_variables() -PredefinedFunctions.register_functions() - - -class FunctionParameterTemplatingTest(unittest.TestCase): - """ - This test is used to test the correct derivation of types when functions use templated type parameters. - """ - - def test(self): - Logger.init_logger(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), - "resources", "FunctionParameterTemplatingTest.nestml")))) - self.assertEqual(len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], - LoggingLevel.ERROR)), 7) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/nest_compartmental_tests/test__cocos.py b/tests/nest_compartmental_tests/test__cocos.py index dc4daa28c..7ee55f8a1 100644 --- a/tests/nest_compartmental_tests/test__cocos.py +++ b/tests/nest_compartmental_tests/test__cocos.py @@ -21,41 +21,39 @@ from __future__ import print_function +from typing import Optional + import os import pytest -from pynestml.frontend.frontend_configuration import FrontendConfiguration - -from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.meta_model.ast_model import ASTModel from pynestml.symbol_table.symbol_table import SymbolTable from pynestml.symbols.predefined_functions import PredefinedFunctions from pynestml.symbols.predefined_types import PredefinedTypes from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.utils.ast_source_location import ASTSourceLocation from pynestml.utils.logger import LoggingLevel, Logger from pynestml.utils.model_parser import ModelParser -@pytest.fixture -def setUp(): - Logger.init_logger(LoggingLevel.INFO) - SymbolTable.initialize_symbol_table( - ASTSourceLocation( - start_line=0, - start_column=0, - end_line=0, - end_column=0)) - PredefinedUnits.register_units() - PredefinedTypes.register_types() - PredefinedVariables.register_variables() - PredefinedFunctions.register_functions() - FrontendConfiguration.target_platform = "NEST_COMPARTMENTAL" - - class TestCoCos: - def test_invalid_cm_variables_declared(self, setUp): - model = ModelParser.parse_file( + @pytest.fixture(scope="module", autouse=True) + def setUp(self): + SymbolTable.initialize_symbol_table( + ASTSourceLocation( + start_line=0, + start_column=0, + end_line=0, + end_column=0)) + PredefinedUnits.register_units() + PredefinedTypes.register_types() + PredefinedVariables.register_variables() + PredefinedFunctions.register_functions() + + def test_invalid_cm_variables_declared(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -63,11 +61,10 @@ def test_invalid_cm_variables_declared(self, setUp): 'invalid')), 'CoCoCmVariablesDeclared.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 5 + model, LoggingLevel.ERROR)) == 6 - def test_valid_cm_variables_declared(self, setUp): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( + def test_valid_cm_variables_declared(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -75,12 +72,12 @@ def test_valid_cm_variables_declared(self, setUp): 'valid')), 'CoCoCmVariablesDeclared.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 0 + model, LoggingLevel.ERROR)) == 0 # it is currently not enforced for the non-cm parameter block, but cm # needs that - def test_invalid_cm_variable_has_rhs(self, setUp): - model = ModelParser.parse_file( + def test_invalid_cm_variable_has_rhs(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -88,11 +85,11 @@ def test_invalid_cm_variable_has_rhs(self, setUp): 'invalid')), 'CoCoCmVariableHasRhs.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 2 + model, LoggingLevel.ERROR)) == 2 - def test_valid_cm_variable_has_rhs(self, setUp): + def test_valid_cm_variable_has_rhs(self): Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -100,12 +97,12 @@ def test_valid_cm_variable_has_rhs(self, setUp): 'valid')), 'CoCoCmVariableHasRhs.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 0 + model, LoggingLevel.ERROR)) == 0 # it is currently not enforced for the non-cm parameter block, but cm # needs that - def test_invalid_cm_v_comp_exists(self, setUp): - model = ModelParser.parse_file( + def test_invalid_cm_v_comp_exists(self): + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -113,11 +110,11 @@ def test_invalid_cm_v_comp_exists(self, setUp): 'invalid')), 'CoCoCmVcompExists.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 4 + model, LoggingLevel.ERROR)) == 4 - def test_valid_cm_v_comp_exists(self, setUp): + def test_valid_cm_v_comp_exists(self): Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( + model = self._parse_and_validate_model( os.path.join( os.path.realpath( os.path.join( @@ -125,4 +122,23 @@ def test_valid_cm_v_comp_exists(self, setUp): 'valid')), 'CoCoCmVcompExists.nestml')) assert len(Logger.get_all_messages_of_level_and_or_node( - model.get_model_list()[0], LoggingLevel.ERROR)) == 0 + model, LoggingLevel.ERROR)) == 0 + + def _parse_and_validate_model(self, fname: str) -> Optional[str]: + from pynestml.frontend.pynestml_frontend import generate_target + + Logger.init_logger(LoggingLevel.DEBUG) + + try: + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + except BaseException: + return None + + ast_compilation_unit = ModelParser.parse_file(fname) + if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0: + return None + + model: ASTModel = ast_compilation_unit.get_model_list()[0] + model_name = model.get_name() + + return model_name diff --git a/tests/nest_tests/nest_delay_based_variables_test.py b/tests/nest_tests/nest_delay_based_variables_test.py index 51f863e19..a11c280f2 100644 --- a/tests/nest_tests/nest_delay_based_variables_test.py +++ b/tests/nest_tests/nest_delay_based_variables_test.py @@ -19,13 +19,12 @@ # You should have received a copy of the GNU General Public License # along with NEST. If not, see . +from typing import List + import numpy as np import os -from typing import List import pytest -import nest - try: import matplotlib import matplotlib.pyplot as plt @@ -34,15 +33,12 @@ except BaseException: TEST_PLOTS = False +import nest + from pynestml.codegeneration.nest_tools import NESTTools from pynestml.frontend.pynestml_frontend import generate_nest_target -target_path = "target_delay" -logging_level = "DEBUG" -suffix = "_nestml" - - def plot_fig(times, recordable_events_delay: dict, recordable_events: dict, filename: str): fig, axes = plt.subplots(len(recordable_events), 1, figsize=(7, 9), sharex=True) for i, recordable_name in enumerate(recordable_events_delay.keys()): @@ -86,6 +82,9 @@ def run_simulation(neuron_model_name: str, module_name: str, recordables: List[s ("DelayDifferentialEquationsWithNumericSolver.nestml", "dde_numeric_nestml", ["x", "z"]), ("DelayDifferentialEquationsWithMixedSolver.nestml", "dde_mixed_nestml", ["x", "z"])]) def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, recordables: List[str]): + target_path = "target_delay" + logging_level = "DEBUG" + suffix = "_nestml" input_path = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", file_name))) module_name = neuron_model_name + "_module" print("Module name: ", module_name) @@ -112,16 +111,3 @@ def test_dde_with_analytic_solver(file_name: str, neuron_model_name: str, record if neuron_model_name == "dde_analytic_nestml": np.testing.assert_allclose(recordable_events_delay[recordables[1]][int(delay):], recordable_events[recordables[1]][:-int(delay)]) - - @pytest.fixture(scope="function", autouse=True) - def cleanup(self): - # Run the test - yield - - # clean up - import shutil - if self.target_path: - try: - shutil.rmtree(self.target_path) - except Exception: - pass diff --git a/tests/nest_tests/non_linear_dendrite_test.py b/tests/nest_tests/non_linear_dendrite_test.py index 2da978976..42ff2e7e4 100644 --- a/tests/nest_tests/non_linear_dendrite_test.py +++ b/tests/nest_tests/non_linear_dendrite_test.py @@ -47,8 +47,6 @@ class NestNonLinearDendriteTest(unittest.TestCase): @pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"), reason="This test does not support NEST 2") def test_non_linear_dendrite(self): - MAX_SSE = 1E-12 - I_dend_alias_name = "I_dend" # synaptic current I_dend_internal_name = "I_kernel2__X__I_2" # alias for the synaptic current diff --git a/tests/nest_tests/resources/integrate_odes_test_params.nestml b/tests/nest_tests/resources/integrate_odes_test_params.nestml index d07fe8fd4..d6430e537 100644 --- a/tests/nest_tests/resources/integrate_odes_test_params.nestml +++ b/tests/nest_tests/resources/integrate_odes_test_params.nestml @@ -8,7 +8,6 @@ model integrate_odes_test: update: integrate_odes(2 * test_1) - integrate_odes(test_3) integrate_odes(100 ms) integrate_odes(test_1) integrate_odes(test_2) diff --git a/tests/nest_tests/resources/integrate_odes_test_params2.nestml b/tests/nest_tests/resources/integrate_odes_test_params2.nestml new file mode 100644 index 000000000..616401e48 --- /dev/null +++ b/tests/nest_tests/resources/integrate_odes_test_params2.nestml @@ -0,0 +1,10 @@ +""" +Model for testing the integrate_odes() function. +""" +model integrate_odes_test: + state: + test_1 real = 0. + test_2 real = 0. + + update: + integrate_odes(test_3) diff --git a/tests/nest_tests/test_integrate_odes.py b/tests/nest_tests/test_integrate_odes.py index 99b94c6ca..6ddb699b4 100644 --- a/tests/nest_tests/test_integrate_odes.py +++ b/tests/nest_tests/test_integrate_odes.py @@ -27,16 +27,9 @@ import nest -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables from pynestml.codegeneration.nest_tools import NESTTools -from pynestml.frontend.pynestml_frontend import generate_nest_target +from pynestml.frontend.pynestml_frontend import generate_nest_target, generate_target from pynestml.utils.logger import LoggingLevel, Logger -from pynestml.utils.model_parser import ModelParser try: import matplotlib @@ -227,12 +220,15 @@ def test_integrate_odes_nonlinear(self): def test_integrate_odes_params(self): r"""Test the integrate_odes() function, in particular with respect to the parameter types.""" - Logger.init_logger(LoggingLevel.INFO) - SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) - PredefinedUnits.register_units() - PredefinedTypes.register_types() - PredefinedVariables.register_variables() - PredefinedFunctions.register_functions() - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file(os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml")))) - assert len(Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)) == 6 + fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params.nestml"))) + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + + assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2 + + def test_integrate_odes_params2(self): + r"""Test the integrate_odes() function, in particular with respect to non-existent parameter variables.""" + + fname = os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join("resources", "integrate_odes_test_params2.nestml"))) + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + + assert len(Logger.get_all_messages_of_level_and_or_node("integrate_odes_test", LoggingLevel.ERROR)) == 2 diff --git a/tests/test_cocos.py b/tests/test_cocos.py new file mode 100644 index 000000000..731fb8d8a --- /dev/null +++ b/tests/test_cocos.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- +# +# test_cocos.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +from __future__ import print_function + +from typing import Optional + +import os +import pytest + +from pynestml.meta_model.ast_model import ASTModel +from pynestml.symbol_table.symbol_table import SymbolTable +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.symbols.predefined_units import PredefinedUnits +from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.logger import LoggingLevel, Logger +from pynestml.utils.model_parser import ModelParser + + +class TestCoCos: + + @pytest.fixture(scope="module", autouse=True) + def setUp(self): + SymbolTable.initialize_symbol_table( + ASTSourceLocation( + start_line=0, + start_column=0, + end_line=0, + end_column=0)) + PredefinedUnits.register_units() + PredefinedTypes.register_types() + PredefinedVariables.register_variables() + PredefinedFunctions.register_functions() + + def test_invalid_element_defined_after_usage(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableDefinedAfterUsage.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_element_defined_after_usage(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableDefinedAfterUsage.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_element_in_same_line(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoElementInSameLine.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_element_in_same_line(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoElementInSameLine.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_integrate_odes_called_if_equations_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_integrate_odes_called_if_equations_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIntegrateOdesCalledIfEquationsDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_element_not_defined_in_scope(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableNotDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 6 + + def test_valid_element_not_defined_in_scope(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableNotDefined.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_variable_with_same_name_as_unit(self): + Logger.set_logging_level(LoggingLevel.NO) + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableWithSameNameAsUnit.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 3 + + def test_invalid_variable_redeclaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVariableRedeclared.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_variable_redeclaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVariableRedeclared.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_each_block_unique(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoEachBlockUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_each_block_unique(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoEachBlockUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_function_unique_and_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionNotUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8 + + def test_valid_function_unique_and_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionNotUnique.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_inline_expressions_have_rhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionHasNoRhs.nestml')) + assert model is None + + def test_valid_inline_expressions_have_rhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionHasNoRhs.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_inline_expression_has_several_lhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInlineExpressionWithSeveralLhs.nestml')) + assert model is None + + def test_valid_inline_expression_has_several_lhs(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInlineExpressionWithSeveralLhs.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_no_values_assigned_to_input_ports(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoValueAssignedToInputPort.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_no_values_assigned_to_input_ports(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoValueAssignedToInputPort.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_order_of_equations_correct(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNoOrderOfEquations.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_order_of_equations_correct(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNoOrderOfEquations.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_numerator_of_unit_one(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoUnitNumeratorNotOne.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_numerator_of_unit_one(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoUnitNumeratorNotOne.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_names_of_neurons_unique(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoMultipleNeuronsWithEqualName.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 3 + + def test_valid_names_of_neurons_unique(self): + self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoMultipleNeuronsWithEqualName.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(None, LoggingLevel.ERROR)) == 0 + + def test_invalid_no_nest_collision(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoNestNamespaceCollision.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_no_nest_collision(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoNestNamespaceCollision.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_redundant_input_port_keywords_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInputPortWithRedundantTypes.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_redundant_input_port_keywords_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInputPortWithRedundantTypes.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_parameters_assigned_only_in_parameters_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoParameterAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_parameters_assigned_only_in_parameters_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoParameterAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_inline_expressions_assigned_only_in_declaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoAssignmentToInlineExpression.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_invalid_internals_assigned_only_in_internals_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInternalAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_internals_assigned_only_in_internals_block(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInternalAssignedOutsideBlock.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_function_with_wrong_arg_number_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_function_with_wrong_arg_number_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoFunctionCallNotConsistentWrongArgNumber.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_init_values_have_rhs_and_ode(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInitValuesWithoutOde.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 2 + + def test_valid_init_values_have_rhs_and_ode(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInitValuesWithoutOde.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.WARNING)) == 3 + + def test_invalid_incorrect_return_stmt_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIncorrectReturnStatement.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 8 + + def test_valid_incorrect_return_stmt_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIncorrectReturnStatement.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_ode_vars_outside_init_block_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeVarNotInInitialValues.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_ode_vars_outside_init_block_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeVarNotInInitialValues.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_convolve_correctly_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyProvided.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_convolve_correctly_defined(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyProvided.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_in_non_vector_declaration_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInNonVectorDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_vector_in_non_vector_declaration_detected(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInNonVectorDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_parameter_declaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_vector_parameter_declaration(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterDeclaration.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_parameter_type(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorParameterType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_vector_parameter_type(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorParameterType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_vector_parameter_size(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorDeclarationSize.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_vector_parameter_size(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorDeclarationSize.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_convolve_correctly_parameterized(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoConvolveNotCorrectlyParametrized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_convolve_correctly_parameterized(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoConvolveNotCorrectlyParametrized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_invariant_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoInvariantNotBool.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_invariant_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoInvariantNotBool.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoIllegalExpression.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoIllegalExpression.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_compound_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 10 + + def test_valid_compound_expression_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CompoundOperatorWithDifferentButCompatibleUnits.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_ode_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOdeIncorrectlyTyped.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0 + + def test_valid_ode_correctly_typed(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOdeCorrectlyTyped.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_output_block_defined_if_emit_call(self): + """test that an error is raised when the emit_spike() function is called by the neuron, but an output block is not defined""" + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0 + + def test_invalid_output_port_defined_if_emit_call(self): + """test that an error is raised when the emit_spike() function is called by the neuron, but a spiking output port is not defined""" + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoOutputPortDefinedIfEmitCall-2.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) > 0 + + def test_valid_output_port_defined_if_emit_call(self): + """test that no error is raised when the output block is missing, but not emit_spike() functions are called""" + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoOutputPortDefinedIfEmitCall.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_valid_coco_kernel_type(self): + """ + Test the functionality of CoCoKernelType. + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoKernelType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_coco_kernel_type(self): + """ + Test the functionality of CoCoKernelType. + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_invalid_coco_kernel_type_initial_values(self): + """ + Test the functionality of CoCoKernelType. + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoKernelTypeInitialValues.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 4 + + def test_valid_coco_state_variables_initialized(self): + """ + Test that the CoCo condition is applicable for all the variables in the state block initialized with a value + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoStateVariablesInitialized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_coco_state_variables_initialized(self): + """ + Test that the CoCo condition is applicable for all the variables in the state block not initialized + """ + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoStateVariablesInitialized.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_invalid_co_co_priorities_correctly_specified(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoPrioritiesCorrectlySpecified.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def test_valid_co_co_priorities_correctly_specified(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoPrioritiesCorrectlySpecified.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_co_co_resolution_legally_used(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoResolutionLegallyUsed.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 2 + + def test_valid_co_co_resolution_legally_used(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoResolutionLegallyUsed.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_valid_co_co_vector_input_port(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'valid')), 'CoCoVectorInputPortSizeAndType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 0 + + def test_invalid_co_co_vector_input_port(self): + model = self._parse_and_validate_model(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'invalid')), 'CoCoVectorInputPortSizeAndType.nestml')) + assert len(Logger.get_all_messages_of_level_and_or_node(model, LoggingLevel.ERROR)) == 1 + + def _parse_and_validate_model(self, fname: str) -> Optional[str]: + from pynestml.frontend.pynestml_frontend import generate_target + + Logger.init_logger(LoggingLevel.DEBUG) + + try: + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + except BaseException: + return None + + ast_compilation_unit = ModelParser.parse_file(fname) + if ast_compilation_unit is None or len(ast_compilation_unit.get_model_list()) == 0: + return None + + model: ASTModel = ast_compilation_unit.get_model_list()[0] + model_name = model.get_name() + + return model_name diff --git a/tests/test_function_parameter_templating.py b/tests/test_function_parameter_templating.py new file mode 100644 index 000000000..b93e06780 --- /dev/null +++ b/tests/test_function_parameter_templating.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# +# test_function_parameter_templating.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import os + +from pynestml.utils.logger import Logger, LoggingLevel +from pynestml.frontend.pynestml_frontend import generate_target + + +class TestFunctionParameterTemplating: + """ + This test is used to test the correct derivation of types when functions use templated type parameters. + """ + + def test(self): + fname = os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), "resources", "FunctionParameterTemplatingTest.nestml"))) + generate_target(input_path=fname, target_platform="NONE", logging_level="DEBUG") + assert len(Logger.get_all_messages_of_level_and_or_node("templated_function_parameters_type_test", LoggingLevel.ERROR)) == 5 diff --git a/tests/test_unit_system.py b/tests/test_unit_system.py new file mode 100644 index 000000000..2cad0b98d --- /dev/null +++ b/tests/test_unit_system.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# +# test_unit_system.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +import os +import pytest + +from pynestml.codegeneration.printers.constant_printer import ConstantPrinter +from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter +from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter +from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter +from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter +from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter +from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter +from pynestml.frontend.pynestml_frontend import generate_target +from pynestml.symbol_table.symbol_table import SymbolTable +from pynestml.symbols.predefined_functions import PredefinedFunctions +from pynestml.symbols.predefined_types import PredefinedTypes +from pynestml.symbols.predefined_units import PredefinedUnits +from pynestml.symbols.predefined_variables import PredefinedVariables +from pynestml.utils.ast_source_location import ASTSourceLocation +from pynestml.utils.logger import Logger, LoggingLevel +from pynestml.utils.model_parser import ModelParser + + +class TestUnitSystem: + r""" + Test class for units system. + """ + + @pytest.fixture(scope="class", autouse=True) + def setUp(self, request): + Logger.set_logging_level(LoggingLevel.INFO) + + SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) + + PredefinedUnits.register_units() + PredefinedTypes.register_types() + PredefinedVariables.register_variables() + PredefinedFunctions.register_functions() + + Logger.init_logger(LoggingLevel.INFO) + + variable_printer = NestMLVariablePrinter(None) + function_call_printer = NESTCppFunctionCallPrinter(None) + cpp_variable_printer = CppVariablePrinter(None) + self.printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer, + ConstantPrinter(), + function_call_printer)) + cpp_variable_printer._expression_printer = self.printer + variable_printer._expression_printer = self.printer + function_call_printer._expression_printer = self.printer + + request.cls.printer = self.printer + + def get_first_statement_in_update_block(self, model): + if model.get_model_list()[0].get_update_blocks()[0]: + return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0] + + return None + + def get_first_declaration_in_state_block(self, model): + assert len(model.get_model_list()[0].get_state_blocks()) == 1 + + return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0] + + def get_first_declared_function(self, model): + return model.get_model_list()[0].get_functions()[0] + + def print_rhs_of_first_assignment_in_update_block(self, model): + assignment = self.get_first_statement_in_update_block(model).small_stmt.get_assignment() + expression = assignment.get_expression() + + return self.printer.print(expression) + + def print_first_function_call_in_update_block(self, model): + function_call = self.get_first_statement_in_update_block(model).small_stmt.get_function_call() + + return self.printer.print(function_call) + + def print_rhs_of_first_declaration_in_state_block(self, model): + declaration = self.get_first_declaration_in_state_block(model) + expression = declaration.get_expression() + + return self.printer.print(expression) + + def print_first_return_statement_in_first_declared_function(self, model): + func = self.get_first_declared_function(model) + return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression() + return self.printer.print(return_expression) + + def test_expression_after_magnitude_conversion_in_direct_assignment(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleUnits.nestml')) + printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_rhs_expression == '(1000.0 * (10 * V))' + + def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml')) + printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_rhs_expression == '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))' + + def test_expression_after_magnitude_conversion_in_compound_assignment(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml')) + printed_rhs_expression = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_rhs_expression == '(0.001 * (1200 * mV))' + + def test_expression_after_magnitude_conversion_in_declaration(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml')) + printed_rhs_expression = self.print_rhs_of_first_declaration_in_state_block(model) + + assert printed_rhs_expression == '(1000.0 * (10 * V))' + + def test_expression_after_type_conversion_in_declaration(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithDifferentButCompatibleUnits.nestml')) + declaration = self.get_first_declaration_in_state_block(model) + from astropy import units as u + + assert declaration.get_expression().type.unit.unit == u.mV + + def test_declaration_with_same_variable_name_as_unit(self): + Logger.init_logger(LoggingLevel.DEBUG) + + generate_target(input_path=os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'DeclarationWithSameVariableNameAsUnit.nestml'), target_platform="NONE", logging_level="DEBUG") + + assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.ERROR)) == 0 + assert len(Logger.get_all_messages_of_level_and_or_node("BlockTest", LoggingLevel.WARNING)) == 3 + + def test_expression_after_magnitude_conversion_in_standalone_function_call(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionCallWithDifferentButCompatibleUnits.nestml')) + printed_function_call = self.print_first_function_call_in_update_block(model) + + assert printed_function_call == 'foo((1000.0 * (10 * V)))' + + def test_expression_after_magnitude_conversion_in_rhs_function_call(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml')) + printed_function_call = self.print_rhs_of_first_assignment_in_update_block(model) + + assert printed_function_call == 'foo((1000.0 * (10 * V)))' + + def test_return_stmt_after_magnitude_conversion_in_function_body(self): + model = ModelParser.parse_file(os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml')) + printed_return_stmt = self.print_first_return_statement_in_first_declared_function(model) + + assert printed_return_stmt == '(0.001 * (bar))' diff --git a/tests/unit_system_test.py b/tests/unit_system_test.py deleted file mode 100644 index 1f7817b91..000000000 --- a/tests/unit_system_test.py +++ /dev/null @@ -1,177 +0,0 @@ -# -*- coding: utf-8 -*- -# -# unit_system_test.py -# -# This file is part of NEST. -# -# Copyright (C) 2004 The NEST Initiative -# -# NEST is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 2 of the License, or -# (at your option) any later version. -# -# NEST is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with NEST. If not, see . - -import os -import unittest -from pynestml.codegeneration.printers.constant_printer import ConstantPrinter - -from pynestml.codegeneration.printers.cpp_expression_printer import CppExpressionPrinter -from pynestml.codegeneration.printers.cpp_simple_expression_printer import CppSimpleExpressionPrinter -from pynestml.codegeneration.printers.cpp_type_symbol_printer import CppTypeSymbolPrinter -from pynestml.codegeneration.printers.nestml_variable_printer import NestMLVariablePrinter -from pynestml.symbol_table.symbol_table import SymbolTable -from pynestml.symbols.predefined_functions import PredefinedFunctions -from pynestml.symbols.predefined_types import PredefinedTypes -from pynestml.symbols.predefined_units import PredefinedUnits -from pynestml.symbols.predefined_variables import PredefinedVariables -from pynestml.utils.ast_source_location import ASTSourceLocation -from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter -from pynestml.codegeneration.printers.nest_cpp_function_call_printer import NESTCppFunctionCallPrinter -from pynestml.codegeneration.printers.cpp_function_call_printer import CppFunctionCallPrinter -from pynestml.utils.logger import Logger, LoggingLevel -from pynestml.utils.model_parser import ModelParser - - -SymbolTable.initialize_symbol_table(ASTSourceLocation(start_line=0, start_column=0, end_line=0, end_column=0)) - -PredefinedUnits.register_units() -PredefinedTypes.register_types() -PredefinedVariables.register_variables() -PredefinedFunctions.register_functions() - -Logger.init_logger(LoggingLevel.INFO) - -type_symbol_printer = CppTypeSymbolPrinter() -variable_printer = NestMLVariablePrinter(None) -function_call_printer = NESTCppFunctionCallPrinter(None) -cpp_variable_printer = CppVariablePrinter(None) -printer = CppExpressionPrinter(CppSimpleExpressionPrinter(cpp_variable_printer, - ConstantPrinter(), - function_call_printer)) -cpp_variable_printer._expression_printer = printer -variable_printer._expression_printer = printer -function_call_printer._expression_printer = printer - - -def get_first_statement_in_update_block(model): - if model.get_model_list()[0].get_update_blocks()[0]: - return model.get_model_list()[0].get_update_blocks()[0].get_block().get_stmts()[0] - return None - - -def get_first_declaration_in_state_block(model): - assert len(model.get_model_list()[0].get_state_blocks()) == 1 - return model.get_model_list()[0].get_state_blocks()[0].get_declarations()[0] - - -def get_first_declared_function(model): - return model.get_model_list()[0].get_functions()[0] - - -def print_rhs_of_first_assignment_in_update_block(model): - assignment = get_first_statement_in_update_block(model).small_stmt.get_assignment() - expression = assignment.get_expression() - return printer.print(expression) - - -def print_first_function_call_in_update_block(model): - function_call = get_first_statement_in_update_block(model).small_stmt.get_function_call() - return printer.print(function_call) - - -def print_rhs_of_first_declaration_in_state_block(model): - declaration = get_first_declaration_in_state_block(model) - expression = declaration.get_expression() - return printer.print(expression) - - -def print_first_return_statement_in_first_declared_function(model): - func = get_first_declared_function(model) - return_expression = func.get_block().get_stmts()[0].small_stmt.get_return_stmt().get_expression() - return printer.print(return_expression) - - -class UnitSystemTest(unittest.TestCase): - """ - Test class for everything Unit related. - """ - - def setUp(self): - Logger.set_logging_level(LoggingLevel.INFO) - - def test_expression_after_magnitude_conversion_in_direct_assignment(self): - Logger.set_logging_level(LoggingLevel.INFO) - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DirectAssignmentWithDifferentButCompatibleUnits.nestml')) - printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model) - - self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))') - - def test_expression_after_nested_magnitude_conversion_in_direct_assignment(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DirectAssignmentWithDifferentButCompatibleNestedUnits.nestml')) - printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model) - - self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V + (0.001 * (5 * mV)) + 20 * V + (1000.0 * (1 * kV))))') - - def test_expression_after_magnitude_conversion_in_compound_assignment(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'CompoundAssignmentWithDifferentButCompatibleUnits.nestml')) - printed_rhs_expression = print_rhs_of_first_assignment_in_update_block(model) - self.assertEqual(printed_rhs_expression, '(0.001 * (1200 * mV))') - - def test_expression_after_magnitude_conversion_in_declaration(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DeclarationWithDifferentButCompatibleUnitMagnitude.nestml')) - printed_rhs_expression = print_rhs_of_first_declaration_in_state_block(model) - self.assertEqual(printed_rhs_expression, '(1000.0 * (10 * V))') - - def test_expression_after_type_conversion_in_declaration(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DeclarationWithDifferentButCompatibleUnits.nestml')) - declaration = get_first_declaration_in_state_block(model) - from astropy import units as u - self.assertTrue(declaration.get_expression().type.unit.unit == u.mV) - - def test_declaration_with_same_variable_name_as_unit(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'DeclarationWithSameVariableNameAsUnit.nestml')) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.ERROR)), 0) - self.assertEqual(len( - Logger.get_all_messages_of_level_and_or_node(model.get_model_list()[0], LoggingLevel.WARNING)), 3) - - def test_expression_after_magnitude_conversion_in_standalone_function_call(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'FunctionCallWithDifferentButCompatibleUnits.nestml')) - printed_function_call = print_first_function_call_in_update_block(model) - self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))') - - def test_expression_after_magnitude_conversion_in_rhs_function_call(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'RhsFunctionCallWithDifferentButCompatibleUnits.nestml')) - printed_function_call = print_rhs_of_first_assignment_in_update_block(model) - self.assertEqual(printed_function_call, 'foo((1000.0 * (10 * V)))') - - def test_return_stmt_after_magnitude_conversion_in_function_body(self): - model = ModelParser.parse_file( - os.path.join(os.path.realpath(os.path.join(os.path.dirname(__file__), 'resources')), - 'FunctionBodyReturnStatementWithDifferentButCompatibleUnits.nestml')) - printed_return_stmt = print_first_return_statement_in_first_declared_function(model) - self.assertEqual(printed_return_stmt, '(0.001 * (bar))')