Skip to content

Commit 60d7be1

Browse files
committed
Improved nomenclature, added comments, new tests that include joint priors
1 parent 27f4ef6 commit 60d7be1

3 files changed

Lines changed: 66 additions & 26 deletions

File tree

bilby/core/prior/joint.py

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,11 @@ def __init__(self, names, bounds=None):
6363
self.requested_parameters = dict()
6464
self.reset_request()
6565

66-
# a dictionary of the rescale(d) parameters
67-
self._rescale_parameters = dict()
68-
self._rescaled_parameters = dict()
66+
# a dictionary that stores the unit-cube values of parameters for later rescaling
67+
self._current_unit_cube_parameter_values = dict()
68+
# a dictionary of arrays that are used as intermediate return values of JointPrior.rescale()
69+
# and updated in-place once all parameters have been requested
70+
self._current_rescaled_parameter_values = dict()
6971
self.reset_rescale()
7072

7173
# a list of sampled parameters
@@ -95,24 +97,24 @@ def filled_rescale(self):
9597
Check if all the rescaled parameters have been filled.
9698
"""
9799

98-
return not np.any([val is None for val in self._rescale_parameters.values()])
100+
return not np.any([val is None for val in self._current_unit_cube_parameter_values.values()])
99101

100102
def set_rescale(self, key, values):
101-
values = np.array(values)
102-
self._rescale_parameters[key] = values
103-
self._rescaled_parameters[key] = np.atleast_1d(np.ones_like(values)) * np.nan
103+
self._current_unit_cube_parameter_values[key] = np.array(values)
104+
self._current_rescaled_parameter_values[key] = np.full_like(values, np.nan, dtype=float)
104105

105106
def reset_rescale(self):
106107
"""
107108
Reset the rescaled parameters to None.
108109
"""
109-
110110
for name in self.names:
111-
self._rescale_parameters[name] = None
112-
self._rescaled_parameters[name] = None
111+
self._current_unit_cube_parameter_values[name] = None
112+
self._current_rescaled_parameter_values[name] = None
113113

114114
def get_rescaled(self, key):
115-
return self._rescaled_parameters[key]
115+
"""Return an array that will be updated in-place once the rescale-operation
116+
has been performed."""
117+
return self._current_rescaled_parameter_values[key]
116118

117119
def get_instantiation_dict(self):
118120
subclass_args = infer_args_from_method(self.__init__)
@@ -317,7 +319,7 @@ def rescale(self, value, **kwargs):
317319
If given, a 1d vector sample (one for each parameter) drawn from a uniform
318320
distribution between 0 and 1, or a 2d NxM array of samples where
319321
N is the number of samples and M is the number of parameters.
320-
If None, values previously set using BaseJointPriorDist.set_rescale() are used.
322+
If None, the values previously set using BaseJointPriorDist.set_rescale() are used.
321323
kwargs: dict
322324
All keyword args that need to be passed to _rescale method, these keyword
323325
args are called in the JointPrior rescale methods for each parameter
@@ -329,9 +331,11 @@ def rescale(self, value, **kwargs):
329331
distribution.
330332
"""
331333
if value is None:
332-
samp = np.array(list(self._rescale_parameters.values())).T
334+
samp = np.array(list(self._current_unit_cube_parameter_values.values())).T
333335
else:
334-
samp = np.array(value)
336+
for key, val in zip(self.names, value):
337+
self.set_rescale(key, val)
338+
samp = np.asarray(value)
335339

336340
if len(samp.shape) == 1:
337341
samp = samp.reshape(1, self.num_vars)
@@ -342,11 +346,12 @@ def rescale(self, value, **kwargs):
342346
raise ValueError("Array is the wrong shape")
343347

344348
samp = self._rescale(samp, **kwargs)
345-
if value is None:
346-
for i, key in enumerate(self.names):
347-
output = self.get_rescaled(key)
348-
# update in-place for proper handling in PriorDict-instances
349-
output[:] = samp[:, i]
349+
for i, key in enumerate(self.names):
350+
# get the numpy array used for indermediate outputs
351+
# prior to a full rescale-operation
352+
output = self.get_rescaled(key)
353+
# update the array in-place
354+
output[...] = samp[:, i]
350355
return np.squeeze(samp)
351356

352357
def _rescale(self, samp, **kwargs):
@@ -819,10 +824,16 @@ def rescale(self, val, **kwargs):
819824
self.dist.set_rescale(self.name, val)
820825

821826
if self.dist.filled_rescale():
827+
# If all names have been filled, perform rescale operation
822828
self.dist.rescale(value=None, **kwargs)
829+
# get the rescaled values for the requested parameter
823830
output = self.dist.get_rescaled(self.name)
831+
# reset the rescale operation
824832
self.dist.reset_rescale()
825833
else:
834+
# If not all names have been filled, return a *numpy array*
835+
# filled only with `np.nan`. Once all names have been requested,
836+
# this array is updated *in-place* with the rescaled values.
826837
output = self.dist.get_rescaled(self.name)
827838

828839
# have to return raw output to conserve in-place modifications

test/core/prior/conditional_test.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ def test_rescale_with_joint_prior(self):
334334

335335
# set multivariate Gaussian distribution
336336
names = ["mvgvar_0", "mvgvar_1"]
337-
mu = [[0.79, -0.83]]
337+
mu = [[1, 1]]
338338
cov = [[[0.03, 0.], [0., 0.04]]]
339339
mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)
340340

@@ -349,7 +349,7 @@ def test_rescale_with_joint_prior(self):
349349
)
350350
)
351351

352-
ref_variables = list(self.test_sample.values()) + [0.4, 0.1]
352+
ref_variables = list(self.test_sample.values()) + [0.5, 0.5]
353353
keys = list(self.test_sample.keys()) + names
354354
res = priordict.rescale(keys=keys, theta=ref_variables)
355355

@@ -359,9 +359,11 @@ def test_rescale_with_joint_prior(self):
359359

360360
# check conditional values are still as expected
361361
expected = [self.test_sample["var_0"]]
362+
self.assertFalse(np.any(np.isnan(res)))
362363
for ii in range(1, 4):
363364
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
364-
self.assertListEqual(expected, res[0:4])
365+
expected.extend([1, 1])
366+
self.assertListEqual(expected, res)
365367

366368
def test_cdf(self):
367369
"""

test/core/prior/dict_test.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,17 @@ def setUp(self):
3333
name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None
3434
)
3535
self.third_prior = bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m")
36+
37+
mvg = bilby.core.prior.MultivariateGaussianDist(
38+
names=["testa", "testb"],
39+
mus=[1, 1],
40+
covs=np.array([[2.0, 0.5], [0.5, 2.0]]),
41+
weights=1.0,
42+
)
43+
self.testa = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit")
44+
self.testb = bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit")
3645
self.priors = dict(
37-
mass=self.first_prior, speed=self.second_prior, length=self.third_prior
46+
mass=self.first_prior, speed=self.second_prior, length=self.third_prior, testa=self.testa, testb=self.testb
3847
)
3948
self.prior_set_from_dict = bilby.core.prior.PriorDict(dictionary=self.priors)
4049
self.default_prior_file = os.path.join(
@@ -70,7 +79,7 @@ def test_prior_set_is_dict(self):
7079
self.assertIsInstance(self.prior_set_from_dict, dict)
7180

7281
def test_prior_set_has_correct_length(self):
73-
self.assertEqual(3, len(self.prior_set_from_dict))
82+
self.assertEqual(5, len(self.prior_set_from_dict))
7483

7584
def test_prior_set_has_expected_priors(self):
7685
self.assertDictEqual(self.priors, dict(self.prior_set_from_dict))
@@ -160,6 +169,12 @@ def test_to_file(self):
160169
"unit='m/s', boundary=None)\n",
161170
"mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', "
162171
"unit='kg', boundary=None)\n",
172+
"testa_testb_mvg = MultivariateGaussianDist(names=['testa', 'testb'], nmodes=1, mus=[[1, 1]], "
173+
"sigmas=[[1.4142135623730951, 1.4142135623730951]], "
174+
"corrcoefs=[[[0.9999999999999998, 0.24999999999999994], [0.24999999999999994, 0.9999999999999998]]], "
175+
"covs=[[[2.0, 0.5], [0.5, 2.0]]], weights=[1.0], bounds={'testa': (-inf, inf), 'testb': (-inf, inf)})\n",
176+
"testa = MultivariateGaussian(dist=testa_testb_mvg, name='testa', latex_label='testa', unit='unit')\n",
177+
"testb = MultivariateGaussian(dist=testa_testb_mvg, name='testb', latex_label='testb', unit='unit')\n",
163178
]
164179
self.prior_set_from_dict.to_file(outdir="prior_files", label="to_file_test")
165180
with open("prior_files/to_file_test.prior") as f:
@@ -178,6 +193,13 @@ def test_from_dict_with_string(self):
178193
self.assertDictEqual(self.prior_set_from_dict, from_dict)
179194

180195
def test_convert_floats_to_delta_functions(self):
196+
mvg = bilby.core.prior.MultivariateGaussianDist(
197+
names=["testa", "testb"],
198+
mus=[1, 1],
199+
covs=np.array([[2.0, 0.5], [0.5, 2.0]]),
200+
weights=1.0,
201+
)
202+
181203
self.prior_set_from_dict["d"] = 5
182204
self.prior_set_from_dict["e"] = 7.3
183205
self.prior_set_from_dict["f"] = "unconvertable"
@@ -190,6 +212,8 @@ def test_convert_floats_to_delta_functions(self):
190212
name="b", alpha=3, minimum=1, maximum=2, unit="m/s", boundary=None
191213
),
192214
length=bilby.core.prior.DeltaFunction(name="c", peak=42, unit="m"),
215+
testa=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testa", unit="unit"),
216+
testb=bilby.core.prior.MultivariateGaussian(dist=mvg, name="testb", unit="unit"),
193217
d=bilby.core.prior.DeltaFunction(peak=5),
194218
e=bilby.core.prior.DeltaFunction(peak=7.3),
195219
f="unconvertable",
@@ -321,12 +345,15 @@ def test_ln_prob(self):
321345
self.assertEqual(expected, self.prior_set_from_dict.ln_prob(samples))
322346

323347
def test_rescale(self):
324-
theta = [0.5, 0.5, 0.5]
348+
theta = [0.5, 0.5, 0.5, 0.5, 0.5]
325349
expected = [
326350
self.first_prior.rescale(0.5),
327351
self.second_prior.rescale(0.5),
328352
self.third_prior.rescale(0.5),
353+
self.testa.rescale(0.5),
354+
self.testb.rescale(0.5)
329355
]
356+
assert not np.any(np.isnan(expected))
330357
self.assertListEqual(
331358
sorted(expected),
332359
sorted(
@@ -342,7 +369,7 @@ def test_cdf(self):
342369
343370
Note that the format of inputs/outputs is different between the two methods.
344371
"""
345-
sample = self.prior_set_from_dict.sample()
372+
sample = self.prior_set_from_dict.sample_subset(keys=["length", "speed", "mass"])
346373
original = np.array(list(sample.values()))
347374
new = np.array(self.prior_set_from_dict.rescale(
348375
sample.keys(),

0 commit comments

Comments
 (0)