Skip to content

Add numeric solver to synapses #1208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,23 @@ def setup_printers(self):
self._gsl_variable_printer = GSLVariablePrinter(None)
if self.option_exists("nest_version") and (self.get_option("nest_version").startswith("2") or self.get_option("nest_version").startswith("v2")):
self._gsl_function_call_printer = NEST2GSLFunctionCallPrinter(None)
self._gsl_function_call_printer_no_origin = NEST2GSLFunctionCallPrinter(None)
else:
self._gsl_function_call_printer = NESTGSLFunctionCallPrinter(None)
self._gsl_function_call_printer_no_origin = NEST2GSLFunctionCallPrinter(None)

self._gsl_printer = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer,
constant_printer=self._constant_printer,
function_call_printer=self._gsl_function_call_printer))
self._gsl_function_call_printer._expression_printer = self._gsl_printer

self._gsl_variable_printer_no_origin = GSLVariablePrinter(None, with_origin=False)
self._gsl_printer_no_origin = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer_no_origin,
constant_printer=self._constant_printer,
function_call_printer=self._gsl_function_call_printer))
self._gsl_variable_printer_no_origin._expression_printer = self._gsl_printer_no_origin
self._gsl_function_call_printer_no_origin._expression_printer = self._gsl_printer_no_origin

# ODE-toolbox printers
self._ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None)
self._ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None)
Expand Down Expand Up @@ -518,6 +527,7 @@ def _get_model_namespace(self, astnode: ASTModel) -> Dict:
namespace["printer"] = self._nest_printer
namespace["printer_no_origin"] = self._printer_no_origin
namespace["gsl_printer"] = self._gsl_printer
namespace["gsl_printer_no_origin"] = self._gsl_printer_no_origin
namespace["nestml_printer"] = NESTMLPrinter()
namespace["type_symbol_printer"] = self._type_symbol_printer

Expand Down Expand Up @@ -664,6 +674,9 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict:
expr_ast.accept(ASTSymbolTableVisitor())
namespace["numeric_update_expressions"][sym] = expr_ast

ASTUtils.assign_numeric_non_numeric_state_variables(synapse, namespace["numeric_state_variables"],
namespace["numeric_update_expressions"] if "numeric_update_expressions" in namespace.keys() else None, namespace["update_expressions"] if "update_expressions" in namespace.keys() else None)

namespace["spike_updates"] = synapse.spike_updates

# special case for NEST delay variable (state or parameter)
Expand Down
53 changes: 28 additions & 25 deletions pynestml/codegeneration/printers/gsl_variable_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils
from pynestml.codegeneration.nest_unit_converter import NESTUnitConverter
from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter
from pynestml.codegeneration.printers.expression_printer import ExpressionPrinter
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.predefined_units import PredefinedUnits
from pynestml.symbols.symbol import SymbolKind
Expand All @@ -33,46 +35,39 @@ class GSLVariablePrinter(CppVariablePrinter):
Variable printer for C++ syntax and using the GSL (GNU Scientific Library) API from inside the ``extern "C"`` stepping function.
"""

def print_variable(self, node: ASTVariable) -> str:
def __init__(self, expression_printer: ExpressionPrinter, with_origin: bool = True, ):
super().__init__(expression_printer)
self.with_origin = with_origin

def print_variable(self, variable: ASTVariable) -> str:
"""
Converts a single name reference to a gsl processable format.
:param node: a single variable
:param variable: a single variable
:return: a gsl processable format of the variable
"""
assert isinstance(node, ASTVariable)
symbol = node.get_scope().resolve_to_symbol(node.get_complete_name(), SymbolKind.VARIABLE)
assert isinstance(variable, ASTVariable)
symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE)

if symbol is None:
# test if variable name can be resolved to a type
if PredefinedUnits.is_unit(node.get_complete_name()):
return str(NESTUnitConverter.get_factor(PredefinedUnits.get_unit(node.get_complete_name()).get_unit()))
if PredefinedUnits.is_unit(variable.get_complete_name()):
return str(
NESTUnitConverter.get_factor(PredefinedUnits.get_unit(variable.get_complete_name()).get_unit()))

code, message = Messages.get_could_not_resolve(node.get_name())
code, message = Messages.get_could_not_resolve(variable.get_name())
Logger.log_message(log_level=LoggingLevel.ERROR, code=code, message=message,
error_position=node.get_source_position())
error_position=variable.get_source_position())
return ""

if node.is_delay_variable():
return self._print_delay_variable(node)
if variable.is_delay_variable():
return self._print_delay_variable(variable)

if symbol.is_state() and not symbol.is_inline_expression:
if "_is_numeric" in dir(node) and node._is_numeric:
if "_is_numeric" in dir(variable) and variable._is_numeric:
# ode_state[] here is---and must be---the state vector supplied by the integrator, not the state vector in the node, node.S_.ode_state[].
return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(node.get_complete_name()) + "]"

# non-ODE state symbol
return "node.S_." + CppVariablePrinter._print_cpp_name(node.get_complete_name())

if symbol.is_parameters():
return "node.P_." + super().print_variable(node)

if symbol.is_internals():
return "node.V_." + super().print_variable(node)
return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(variable.get_complete_name()) + "]"

if symbol.is_input():
return "node.B_." + self._print_buffer_value(node)

raise Exception("Unknown node type")
return self._print(variable, symbol, with_origin=self.with_origin)

def _print_delay_variable(self, variable: ASTVariable) -> str:
"""
Expand Down Expand Up @@ -104,3 +99,11 @@ def _print_buffer_value(self, variable: ASTVariable) -> str:
return "spike_inputs_grid_sum_[node." + var_name + " - node.MIN_SPIKE_RECEPTOR]"

return variable_symbol.get_symbol_name() + '_grid_sum_'

def _print(self, variable, symbol, with_origin: bool = True):
variable_name = CppVariablePrinter._print_cpp_name(variable.get_complete_name())

if with_origin:
return "node." + NESTCodeGeneratorUtils.print_symbol_origin(symbol, variable) % variable_name

return "node." + variable_name
Loading
Loading