Skip to content

Commit 82d0b5f

Browse files
Fabian Fröhlichdweindl
andauthored
fix initial events, fixes #1760 (#1789)
* initial implementation, fixes #1760 * fixup * fixup * fixup * fix doc * fix doc * fixup swig * fixup sbml import * add sensitivity check, fixup time-independent triggers * fixup * fixup * fixup * fixup parameter event assignments * update test stats * Apply suggestions from code review Co-authored-by: Daniel Weindl <[email protected]> Co-authored-by: Daniel Weindl <[email protected]>
1 parent 32c6c42 commit 82d0b5f

File tree

13 files changed

+148
-103
lines changed

13 files changed

+148
-103
lines changed

documentation/python_interface.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ AMICI can import :term:`SBML` models via the
2626
Status of SBML support in Python-AMICI
2727
++++++++++++++++++++++++++++++++++++++
2828

29-
Python-AMICI currently **passes 1014 out of the 1821 (~56%) test cases** from
29+
Python-AMICI currently **passes 1030 out of the 1821 (~57%) test cases** from
3030
the semantic
3131
`SBML Test Suite <https://github.com/sbmlteam/sbml-test-suite/>`_
3232
(`current status <https://github.com/AMICI-dev/AMICI/actions>`_).

include/amici/forwardproblem.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,11 @@ class ForwardProblem {
290290
*
291291
* @param tlastroot pointer to the timepoint of the last event
292292
* @param seflag Secondary event flag
293+
* @param initial_event initial event flag
293294
*/
294295

295-
void handleEvent(realtype *tlastroot,bool seflag);
296+
void handleEvent(realtype *tlastroot, bool seflag,
297+
bool initial_event);
296298

297299
/**
298300
* @brief Extract output information for events

include/amici/model.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ class Model : public AbstractModel, public ModelDimensions {
9797
SimulationParameters simulation_parameters,
9898
amici::SecondOrderMode o2mode,
9999
std::vector<amici::realtype> idlist,
100-
std::vector<int> z2event, bool pythonGenerated = false,
100+
std::vector<int> z2event,
101+
bool pythonGenerated = false,
101102
int ndxdotdp_explicit = 0, int ndxdotdx_explicit = 0,
102103
int w_recursion_depth = 0);
103104

@@ -202,9 +203,11 @@ class Model : public AbstractModel, public ModelDimensions {
202203
* @param sdx Reference to time derivative of state sensitivities (DAE only)
203204
* @param computeSensitivities Flag indicating whether sensitivities are to
204205
* be computed
206+
* @param roots_found boolean indicators indicating whether roots were found at t0 by this fun
205207
*/
206208
void initialize(AmiVector &x, AmiVector &dx, AmiVectorArray &sx,
207-
AmiVectorArray &sdx, bool computeSensitivities);
209+
AmiVectorArray &sdx, bool computeSensitivities,
210+
std::vector<int> &roots_found);
208211

209212
/**
210213
* @brief Initialize model properties.
@@ -236,8 +239,10 @@ class Model : public AbstractModel, public ModelDimensions {
236239
*
237240
* @param x Reference to state variables
238241
* @param dx Reference to time derivative of states (DAE only)
242+
* @param roots_found boolean indicators indicating whether roots were found at t0 by this fun
239243
*/
240-
void initHeaviside(const AmiVector &x, const AmiVector &dx);
244+
void initEvents(const AmiVector &x, const AmiVector &dx,
245+
std::vector<int> &roots_found);
241246

242247
/**
243248
* @brief Get number of parameters wrt to which sensitivities are computed.
@@ -1864,6 +1869,11 @@ class Model : public AbstractModel, public ModelDimensions {
18641869
/** vector of bools indicating whether state variables are to be assumed to
18651870
* be positive */
18661871
std::vector<bool> state_is_non_negative_;
1872+
1873+
/** Vector of booleans indicating the initial boolean value for every event trigger function. Events at t0
1874+
* can only trigger if the initial value is set to `false`. Must be specified during model compilation by
1875+
* setting the `initialValue` attribute of an event trigger. */
1876+
std::vector<bool> root_initial_values_;
18671877

18681878
/** boolean indicating whether any entry in stateIsNonNegative is `true` */
18691879
bool any_state_non_negative_ {false};

python/amici/ode_export.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -822,7 +822,7 @@ def transform_dxdt_to_concentration(species_id, dxdt):
822822
else:
823823
args += ['value']
824824
if symbol_name == SymbolId.EVENT:
825-
args += ['state_update', 'event_observable']
825+
args += ['state_update', 'event_observable', 'initial_value']
826826
if symbol_name == SymbolId.OBSERVABLE:
827827
args += ['transformation']
828828

@@ -1589,34 +1589,43 @@ def _compute_equation(self, name: str) -> None:
15891589
elif name == 'stau':
15901590
self._eqs[name] = [
15911591
-self.eq('sroot')[ie, :] / self.eq('drootdt_total')[ie]
1592+
if not self.eq('drootdt_total')[ie].is_zero else
1593+
sp.zeros(*self.eq('sroot')[ie, :].shape)
15921594
for ie in range(self.num_events())
15931595
]
15941596

15951597
elif name == 'deltasx':
15961598
event_eqs = []
15971599
for ie, event in enumerate(self._events):
1598-
if event._state_update is not None:
1599-
# ====== chain rule for the state variables ===============
1600-
# get xdot with expressions back-substituted
1601-
tmp_eq = smart_multiply(
1600+
1601+
tmp_eq = sp.zeros(self.num_states_solver(), self.num_par())
1602+
1603+
# only add stau part if trigger is time-dependent
1604+
if not self.eq('drootdt_total')[ie].is_zero:
1605+
tmp_eq += smart_multiply(
16021606
(self.sym('xdot_old') - self.sym('xdot')),
16031607
self.eq('stau')[ie])
1604-
# construct an enhanced state sensitivity, which accounts
1605-
# for the time point sensitivity as well
1608+
1609+
# only add deltax part if there is state update
1610+
if event._state_update is not None:
1611+
# partial derivative for the parameters
1612+
tmp_eq += self.eq('ddeltaxdp')[ie]
1613+
1614+
# initial part of chain rule state variables
16061615
tmp_dxdp = self.sym('sx') * sp.ones(1, self.num_par())
1607-
tmp_dxdp += smart_multiply(self.sym('xdot'),
1608-
self.eq('stau')[ie])
1616+
1617+
# only add stau part if trigger is time-dependent
1618+
if not self.eq('drootdt_total')[ie].is_zero:
1619+
# chain rule for the time point
1620+
tmp_eq += smart_multiply(self.eq('ddeltaxdt')[ie],
1621+
self.eq('stau')[ie])
1622+
1623+
# additional part of chain rule state variables
1624+
tmp_dxdp += smart_multiply(self.sym('xdot'),
1625+
self.eq('stau')[ie])
1626+
# finish chain rule for the state variables
16091627
tmp_eq += smart_multiply(self.eq('ddeltaxdx')[ie],
16101628
tmp_dxdp)
1611-
# ====== chain rule for the time point ====================
1612-
tmp_eq += smart_multiply(self.eq('ddeltaxdt')[ie],
1613-
self.eq('stau')[ie])
1614-
# ====== partial derivative for the parameters ============
1615-
tmp_eq += self.eq('ddeltaxdp')[ie]
1616-
else:
1617-
tmp_eq = smart_multiply(
1618-
(self.eq('xdot_old') - self.eq('xdot')),
1619-
self.eq('stau')[ie])
16201629

16211630
event_eqs.append(tmp_eq)
16221631

@@ -2937,6 +2946,11 @@ def _write_model_header_cpp(self) -> None:
29372946
'W_RECURSION_DEPTH': self.model._w_recursion_depth,
29382947
'QUADRATIC_LLH': 'true'
29392948
if self.model._has_quadratic_nllh else 'false',
2949+
'ROOT_INITIAL_VALUES':
2950+
', '.join([
2951+
'true' if event.get_initial_value() else 'false'
2952+
for event in self.model._events
2953+
])
29402954
}
29412955

29422956
for func_name, func_info in self.functions.items():

python/amici/ode_model.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,50 +2,27 @@
22

33

44
import sympy as sp
5-
import numpy as np
6-
import re
7-
import shutil
8-
import subprocess
9-
import sys
10-
import os
11-
import copy
125
import numbers
13-
import logging
14-
import itertools
15-
import contextlib
166

177
try:
188
import pysb
199
except ImportError:
2010
pysb = None
2111

2212
from typing import (
23-
Callable, Optional, Union, List, Dict, Tuple, SupportsFloat, Sequence,
24-
Set, Any
13+
Optional, Union, Dict, SupportsFloat, Set
2514
)
26-
from dataclasses import dataclass
27-
from string import Template
28-
from sympy.matrices.immutable import ImmutableDenseMatrix
29-
from sympy.matrices.dense import MutableDenseMatrix
30-
from sympy.logic.boolalg import BooleanAtom
31-
from itertools import chain
32-
from .cxxcodeprinter import AmiciCxxCodePrinter, get_switch_statement
33-
34-
from . import (
35-
amiciSwigPath, amiciSrcPath, amiciModulePath, __version__, __commit__,
36-
sbml_import
37-
)
38-
from .logging import get_logger, log_execution_time, set_log_level
39-
from .constants import SymbolId
40-
from .import_utils import smart_subs_dict, toposort_symbols, \
41-
ObservableTransformation, generate_measurement_symbol, RESERVED_SYMBOLS
15+
16+
from .import_utils import ObservableTransformation, \
17+
generate_measurement_symbol, RESERVED_SYMBOLS
4218
from .import_utils import cast_to_sym
4319

4420
__all__ = [
4521
'ConservationLaw', 'Constant', 'Event', 'Expression', 'LogLikelihood',
4622
'ModelQuantity', 'Observable', 'Parameter', 'SigmaY', 'State'
4723
]
4824

25+
4926
class ModelQuantity:
5027
"""
5128
Base class for model components
@@ -166,14 +143,6 @@ def __init__(self,
166143
self._ncoeff: sp.Expr = coefficients[state_id]
167144
super(ConservationLaw, self).__init__(identifier, name, value)
168145

169-
def get_state(self) -> sp.Symbol:
170-
"""
171-
Get the identifier of the state that this conservation law replaces
172-
173-
:return: identifier of the state
174-
"""
175-
return self._state_id
176-
177146
def get_ncoeff(self, state_id) -> Union[sp.Expr, int, float]:
178147
"""
179148
Computes the normalized coefficient a_i/a_j where i is the index of
@@ -211,10 +180,6 @@ class State(ModelQuantity):
211180
algebraic formula that defines the temporal derivative of this state
212181
213182
"""
214-
215-
_dt: Union[sp.Expr, None] = None
216-
_conservation_law: Union[sp.Expr, None] = None
217-
218183
def __init__(self,
219184
identifier: sp.Symbol,
220185
name: str,
@@ -276,7 +241,7 @@ def get_dt(self) -> sp.Expr:
276241
"""
277242
return self._dt
278243

279-
def get_free_symbols(self) -> Set[sp.Symbol]:
244+
def get_free_symbols(self) -> Set[sp.Basic]:
280245
"""
281246
Gets the set of free symbols in time derivative and initial conditions
282247
@@ -307,8 +272,9 @@ def get_x_rdata(self):
307272

308273
def get_dx_rdata_dx_solver(self, state_id):
309274
"""
310-
Returns the expression that allows computation of ``dx_rdata_dx_solver`` for this
311-
state, accounting for conservation laws.
275+
Returns the expression that allows computation of
276+
``dx_rdata_dx_solver`` for this state, accounting for conservation
277+
laws.
312278
313279
:return: dx_rdata_dx_solver expression
314280
"""
@@ -514,7 +480,8 @@ def __init__(self,
514480
name: str,
515481
value: sp.Expr,
516482
state_update: Union[sp.Expr, None],
517-
event_observable: Union[sp.Expr, None]):
483+
event_observable: Union[sp.Expr, None],
484+
initial_value: Optional[bool] = True):
518485
"""
519486
Create a new Event instance.
520487
@@ -534,15 +501,30 @@ def __init__(self,
534501
:param event_observable:
535502
formula a potential observable linked to the event
536503
(None for Heaviside functions, empty events without observable)
504+
505+
:param initial_value:
506+
initial boolean value of the trigger function at t0. If set to
507+
`False`, events may trigger at ``t==t0``, otherwise not.
537508
"""
538509
super(Event, self).__init__(identifier, name, value)
539510
# add the Event specific components
540511
self._state_update = state_update
541512
self._observable = event_observable
513+
self._initial_value = initial_value
514+
515+
def get_initial_value(self) -> bool:
516+
"""
517+
Return the initial value for the root function.
518+
519+
:return:
520+
initial value formula
521+
"""
522+
return self._initial_value
542523

543524
def __eq__(self, other):
544525
"""
545526
Check equality of events at the level of trigger/root functions, as we
546527
need to collect unique root functions for ``roots.cpp``
547528
"""
548-
return self.get_val() == other.get_val()
529+
return self.get_val() == other.get_val() and \
530+
(self.get_initial_value() == other.get_initial_value())

python/amici/sbml_import.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -484,13 +484,6 @@ def check_event_support(self) -> None:
484484
if trigger_sbml.getMath() is None:
485485
logger.warning(f'Event {event_id} trigger has no trigger '
486486
'expression, so a dummy trigger will be set.')
487-
if not trigger_sbml.getInitialValue():
488-
# True: event not executed if triggered at time == 0
489-
# (corresponding to AMICI default). Raise if set to False.
490-
raise SBMLException(
491-
f'Event {event_id} has a trigger that has an initial '
492-
'value of False. This is currently not supported in AMICI.'
493-
)
494487

495488
if not trigger_sbml.getPersistent():
496489
raise SBMLException(
@@ -1001,10 +994,16 @@ def _convert_event_assignment_parameter_targets_to_species(self):
1001994
parameter_def = \
1002995
self.symbols[symbol_id].pop(parameter_target)
1003996
if parameter_def is None:
1004-
raise AssertionError(
1005-
'Unexpected error. The parameter target of an event '
1006-
'assignment could not be found.'
997+
# this happens for parameters that have initial assignments
998+
# or are assignment rule targets
999+
par = self.sbml.getElementBySId(str(parameter_target))
1000+
ia_init = self._get_element_initial_assignment(
1001+
par.getId()
10071002
)
1003+
parameter_def = {
1004+
'name': par.getName() if par.isSetName() else par.getId(),
1005+
'value': par.getValue() if ia_init is None else ia_init
1006+
}
10081007
# Fixed parameters are added as species such that they can be
10091008
# targets of events.
10101009
self.symbols[SymbolId.SPECIES][parameter_target] = {
@@ -1140,6 +1139,9 @@ def get_empty_bolus_value() -> sp.Float:
11401139
'value': trigger,
11411140
'state_update': sp.MutableDenseMatrix(bolus),
11421141
'event_observable': None,
1142+
'initial_value':
1143+
trigger_sbml.getInitialValue() if trigger_sbml is not None
1144+
else True,
11431145
}
11441146

11451147
@log_execution_time('processing SBML observables', logger)

0 commit comments

Comments
 (0)