-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSherpaSpectralModel.py
More file actions
86 lines (73 loc) · 3.25 KB
/
Copy pathSherpaSpectralModel.py
File metadata and controls
86 lines (73 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from gammapy.modeling.models import SpectralModel, SkyModel, SPECTRAL_MODEL_REGISTRY
from gammapy.modeling import Parameter, Parameters
from astropy import units as u
import numpy as np
class SherpaSpectralModel(SpectralModel):
    """A wrapper for Sherpa spectral models.

    Parameters
    ----------
    sherpa_model :
        An instance of the models defined in `~sherpa.models` or `~sherpa.astro.xspec`.
    integrated : bool
        If True, the value returned by the Sherpa model for each narrow
        evaluation bin is divided by the bin width before the output unit is
        attached. If False, the returned value is used as is. Default is True.
    default_units : tuple
        Units of the input energy array and output model evaluation (find them in the sherpa/xspec docs!)
    """

    tag = ["SherpaSpectralModel", "sherpa", "xspec"]

    def __init__(
        self, sherpa_model, integrated=True, default_units=(u.keV, 1 / (u.keV * u.cm ** 2 * u.s))
    ):
        self.sherpa_model = sherpa_model
        self.default_units = default_units
        # Kept ahead of ``super().__init__()`` as in the original ordering;
        # presumably the base class reads ``default_parameters`` -- TODO confirm.
        self.default_parameters = self._wrap_parameters()
        self.integrated = integrated
        super().__init__()

    def _wrap_parameters(self):
        """Build gammapy parameters mirroring the Sherpa model parameters.

        Returns
        -------
        parameters : `Parameters`
            One `Parameter` per entry of ``sherpa_model.pars``, copying its
            name, value, frozen flag, minimum and maximum.
        """
        # TODO: set unit?
        # TODO: flag normalisation parameters ("ampl", "norm", "K", "norm2")
        #       once ``is_norm`` can be passed to ``Parameter``.
        return Parameters(
            [
                Parameter(
                    name=par.name,
                    value=par.val,
                    frozen=par.frozen,
                    min=par.min,
                    max=par.max,
                )
                for par in self.sherpa_model.pars
            ]
        )

    def _update_sherpa_parameters(self, *args, **kwargs):
        """Update sherpa model parameters.

        Parameters
        ----------
        *args
            Values assigned by position to ``sherpa_model.pars``. Only used
            when no keyword arguments are given.
        **kwargs
            Values assigned to the Sherpa parameter with the matching name.

        Raises
        ------
        ValueError
            If a keyword does not match any Sherpa parameter name.
        IndexError
            If more positional values than Sherpa parameters are given.
        """
        pars = self.sherpa_model.pars
        if kwargs:
            # Build the name lookup once; ``setdefault`` keeps the first
            # parameter for a repeated name, like ``list.index`` would.
            by_name = {}
            for par in pars:
                by_name.setdefault(par.name, par)
            for name, value in kwargs.items():
                if name not in by_name:
                    raise ValueError(
                        "Sherpa model has no parameter named {!r}".format(name)
                    )
                by_name[name].val = value
        else:
            for idx, value in enumerate(args):
                pars[idx].val = value

    def evaluate(self, energy, *args, **kwargs):
        """Evaluate the wrapped Sherpa model at the given energies.

        Parameters
        ----------
        energy : `~astropy.units.Quantity`
            Energies at which to evaluate the model; converted to
            ``default_units[0]``.
        *args
            Parameter values by position (used only without keyword arguments).
        **kwargs
            Parameter values by Sherpa parameter name.

        Returns
        -------
        flux : `~astropy.units.Quantity`
            Model values in ``default_units[1]``, with the shape of ``energy``.

        Raises
        ------
        ValueError
            If ``energy`` is not a Quantity, or if no finite, non-zero spacing
            can be derived from the energies.
        """
        if not isinstance(energy, u.Quantity):
            raise ValueError("The energy must be a Quantity object.")

        values = np.array(energy.to_value(self.default_units[0]), dtype=float)
        shape = values.shape
        values = values.flatten()

        if kwargs:
            self._update_sherpa_parameters(**kwargs)
        else:
            self._update_sherpa_parameters(*args)

        if values.size == 0:
            # Nothing to evaluate: return an empty result of the right shape.
            return np.zeros(shape) * self.default_units[1]

        # Trickeries due to the sherpa model evaluation scheme
        # (https://sherpa.readthedocs.io/en/4.14.1/evaluation/index.html)
        # A sentinel point is appended so that even a single energy yields a
        # spacing; it is dropped again after the model call. ``last * 2``
        # equals ``last`` when ``last`` is 0, so fall back to ``last + 1``.
        last = values[-1]
        sentinel = last * 2 if last > 0 else last + 1.0
        grid = np.append(values, sentinel)

        # Remove duplicate energies by adding a negligible value
        # (otherwise there are problems with some models, e.g. wabs)
        # Only the real energies are compared, so ``idx + 2`` stays in bounds.
        for idx in range(grid.size - 2):
            if grid[idx] == grid[idx + 1]:
                grid[idx + 1] += (grid[idx + 2] - grid[idx + 1]) / 100

        # Width of the narrow bin [E, E + epsilon] handed to Sherpa: a small
        # fraction of the smallest non-zero spacing. Using the absolute value
        # keeps epsilon positive for unsorted (e.g. flattened 2D) grids, and
        # skipping zero spacings avoids dividing by zero below.
        # NOTE(review): epsilon is absolute, so for energies much larger than
        # the smallest spacing ``E + epsilon`` may round to ``E`` -- confirm.
        steps = np.abs(np.diff(grid))
        steps = steps[np.isfinite(steps) & (steps > 0)]
        if steps.size == 0:
            raise ValueError(
                "Cannot derive a finite bin width from the given energies."
            )
        epsilon = steps.min() * 1e-5

        flux = np.asarray(self.sherpa_model(grid, grid + epsilon), dtype=float)[:-1]
        if self.integrated:
            # Not in place: ``flux`` is a view of the array the model returned.
            flux = flux / epsilon
        return (flux * self.default_units[1]).reshape(shape)
# Append the wrapper class to gammapy's SPECTRAL_MODEL_REGISTRY when this
# module is imported (presumably so the model can be looked up through one of
# its ``tag`` entries -- TODO confirm against the gammapy registry API).
SPECTRAL_MODEL_REGISTRY.append(SherpaSpectralModel)