Skip to content

Commit 0d8d0d0

Browse files
committed
PEtab v2 import via amici
Support for [PEtab v2](https://petab.readthedocs.io/en/latest/v2/documentation_data_format.html) problems via amici. **WIP**
1 parent 8aac10c commit 0d8d0d0

8 files changed

Lines changed: 941 additions & 53 deletions

File tree

pypesto/objective/amici/amici.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
from pathlib import Path
1010
from typing import TYPE_CHECKING, Union
1111

12+
import amici.importers.petab
1213
import numpy as np
14+
import pandas as pd
1315

1416
from ...C import (
1517
FVAL,
@@ -718,3 +720,67 @@ def update_from_problem(
718720
) in condition_mapping.map_preeq_fix.items():
719721
if (val := id_to_val.get(mapped_to_par)) is not None:
720722
condition_mapping.map_preeq_fix[model_par] = val
723+
724+
725+
class AmiciPetabV2Objective(AmiciObjective):
726+
"""An AMICI objective constructed from a PEtab v2 problem."""
727+
728+
def __init__(
729+
self,
730+
petab_importer: amici.importers.petab.PetabImporter,
731+
**kwargs,
732+
) -> None:
733+
from .amici_calculator import AmiciCalculatorPetabV2
734+
735+
self._petab_simulator: amici.petab.petab_importer.PetabSimulator = (
736+
petab_importer.create_simulator()
737+
)
738+
self.petab_problem = petab_importer.petab_problem
739+
amici_model = self._petab_simulator.model
740+
amici_solver = self._petab_simulator.solver
741+
edatas = self._petab_simulator.exp_man.create_edatas()
742+
743+
super().__init__(
744+
amici_model=amici_model,
745+
amici_solver=amici_solver,
746+
edatas=edatas,
747+
calculator=AmiciCalculatorPetabV2(self._petab_simulator),
748+
**kwargs,
749+
)
750+
751+
def __deepcopy__(self, memo=None):
752+
"""Override AmiciObjective.__deepcopy__."""
753+
if memo is None:
754+
memo = {}
755+
cls = self.__class__
756+
result = cls.__new__(cls)
757+
memo[id(self)] = result
758+
for k, v in self.__dict__.items():
759+
setattr(result, k, copy.deepcopy(v, memo))
760+
return result
761+
762+
def __getstate__(self) -> dict:
763+
"""Use Python's default pickling semantics (shallow copy of instance dict)."""
764+
return dict(self.__dict__)
765+
766+
def __setstate__(self, state: dict) -> None:
767+
"""Restore state using the instance dict (default unpickling behaviour)."""
768+
self.__dict__.update(state)
769+
770+
def rdatas_to_simulation_df(
771+
self,
772+
rdatas: Sequence[amici.ReturnData],
773+
) -> pd.DataFrame:
774+
"""
775+
See :meth:`rdatas_to_measurement_df`.
776+
777+
Except a petab simulation dataframe is created, i.e. the measurement
778+
column label is adjusted.
779+
"""
780+
from amici.importers.petab import rdatas_to_simulation_df
781+
782+
return rdatas_to_simulation_df(
783+
rdatas,
784+
self._petab_simulator._model,
785+
self._petab_simulator._petab_problem,
786+
)

pypesto/objective/amici/amici_calculator.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,163 @@ def __call__(
154154
)
155155

156156

157+
class AmiciCalculatorPetabV2(AmiciCalculator):
158+
"""Class to perform the AMICI call and obtain objective function values."""
159+
160+
def __init__(
161+
self,
162+
petab_simulator: amici.petab.petab_importer.PetabSimulator,
163+
**kwargs,
164+
):
165+
super().__init__(**kwargs)
166+
self.petab_simulator = petab_simulator
167+
168+
def __call__(
169+
self,
170+
x_dct: dict,
171+
sensi_orders: tuple[int],
172+
mode: ModeType,
173+
amici_model: AmiciModel,
174+
amici_solver: AmiciSolver,
175+
edatas: list[amici.ExpData],
176+
n_threads: int,
177+
x_ids: Sequence[str],
178+
parameter_mapping: ParameterMapping,
179+
fim_for_hess: bool,
180+
):
181+
"""Perform the actual AMICI call.
182+
183+
Called within the :func:`AmiciObjective.__call__` method.
184+
185+
Parameters
186+
----------
187+
x_dct:
188+
Parameters for which to compute function value and derivatives.
189+
sensi_orders:
190+
Tuple of requested sensitivity orders.
191+
mode:
192+
Call mode (function value or residual based).
193+
amici_model:
194+
The AMICI model.
195+
amici_solver:
196+
The AMICI solver.
197+
edatas:
198+
The experimental data.
199+
n_threads:
200+
Number of threads for AMICI call.
201+
x_ids:
202+
Ids of optimization parameters.
203+
parameter_mapping:
204+
Mapping of optimization to simulation parameters.
205+
fim_for_hess:
206+
Whether to use the FIM (if available) instead of the Hessian (if
207+
requested).
208+
"""
209+
amici_solver = self.petab_simulator._solver
210+
211+
if mode != MODE_FUN:
212+
raise NotImplementedError(
213+
"Only function value mode is currently supported for "
214+
f"PEtab v2. Got mode {mode}."
215+
)
216+
217+
# TODO: -> method
218+
# set order in solver
219+
sensi_order = 0
220+
if sensi_orders:
221+
sensi_order = max(sensi_orders)
222+
223+
if sensi_order == 2 and fim_for_hess:
224+
# we use the FIM
225+
amici_solver.set_sensitivity_order(sensi_order - 1)
226+
else:
227+
amici_solver.set_sensitivity_order(sensi_order)
228+
229+
dim = len(x_ids)
230+
231+
# run amici simulation
232+
result = self.petab_simulator.simulate(x_dct)
233+
rdatas = result.rdatas
234+
235+
# check if the simulation failed
236+
if any(rdata["status"] < 0.0 for rdata in rdatas):
237+
return get_error_output(
238+
amici_model, edatas, rdatas, sensi_orders, mode, dim
239+
)
240+
241+
nllh, snllh, s2nllh, chi2, res, sres = init_return_values(
242+
sensi_orders, mode, dim
243+
)
244+
nllh = -result.llh
245+
246+
if (
247+
not self._known_least_squares_safe
248+
and mode == MODE_RES
249+
and 1 in sensi_orders
250+
):
251+
if not amici_model.get_add_sigma_residuals() and any(
252+
(
253+
(r["ssigmay"] is not None and np.any(r["ssigmay"]))
254+
or (r["ssigmaz"] is not None and np.any(r["ssigmaz"]))
255+
)
256+
for r in rdatas
257+
):
258+
raise RuntimeError(
259+
"Cannot use least squares solver with"
260+
"parameter dependent sigma! Support can be "
261+
"enabled via "
262+
"amici_model.setAddSigmaResiduals()."
263+
)
264+
self._known_least_squares_safe = True # don't check this again
265+
266+
# TODO: compute res, sres
267+
268+
if 1 in sensi_orders:
269+
if result.sllh is None and np.isnan(result["llh"]):
270+
# TODO: to amici -- set sllh even if llh is nan?
271+
snllh = np.full(len(x_ids), np.nan)
272+
else:
273+
try:
274+
# llh to nllh, dict to array
275+
snllh = -np.array(
276+
[
277+
result.sllh[
278+
x_id
279+
] # if x_id in res["sllh"] else 0.0
280+
for x_id in x_ids
281+
if x_id in x_dct.keys()
282+
]
283+
)
284+
except KeyError as e:
285+
# A requested sensitivity is missing.
286+
# Probably the affected parameter is a fixed parameter
287+
# in amici instead of a sensitivity parameter
288+
# (non_estimated_parameters_as_constants=True ?).
289+
# In this case, only max(sensi_orders) == 0 is supported
290+
# unless this parameter is fixed in the pypesto problem.
291+
raise ValueError(
292+
f"Cannot compute gradient, missing entry for "
293+
f"{set(result.sllh) - set(x_dct.keys())}."
294+
) from e
295+
if 2 in sensi_orders:
296+
if result.s2llh is None and np.isnan(result.llh):
297+
# TODO: to amici -- set s2llh even if llh is nan?
298+
s2nllh = np.full((len(x_ids), len(x_ids)), np.nan)
299+
else:
300+
s2nllh = -result.s2llh
301+
302+
ret = {
303+
FVAL: nllh,
304+
GRAD: snllh,
305+
HESS: s2nllh,
306+
RES: res,
307+
SRES: sres,
308+
RDATAS: rdatas,
309+
}
310+
311+
return filter_return_dict(ret)
312+
313+
157314
def calculate_function_values(
158315
rdatas,
159316
sensi_orders: tuple[int, ...],

0 commit comments

Comments
 (0)