@@ -33,8 +33,17 @@ def setUp(self):
3333 name = "b" , alpha = 3 , minimum = 1 , maximum = 2 , unit = "m/s" , boundary = None
3434 )
3535 self .third_prior = bilby .core .prior .DeltaFunction (name = "c" , peak = 42 , unit = "m" )
36+
37+ mvg = bilby .core .prior .MultivariateGaussianDist (
38+ names = ["testa" , "testb" ],
39+ mus = [1 , 1 ],
40+ covs = np .array ([[2.0 , 0.5 ], [0.5 , 2.0 ]]),
41+ weights = 1.0 ,
42+ )
43+ self .testa = bilby .core .prior .MultivariateGaussian (dist = mvg , name = "testa" , unit = "unit" )
44+ self .testb = bilby .core .prior .MultivariateGaussian (dist = mvg , name = "testb" , unit = "unit" )
3645 self .priors = dict (
37- mass = self .first_prior , speed = self .second_prior , length = self .third_prior
46+ mass = self .first_prior , speed = self .second_prior , length = self .third_prior , testa = self . testa , testb = self . testb
3847 )
3948 self .prior_set_from_dict = bilby .core .prior .PriorDict (dictionary = self .priors )
4049 self .default_prior_file = os .path .join (
@@ -70,7 +79,7 @@ def test_prior_set_is_dict(self):
7079 self .assertIsInstance (self .prior_set_from_dict , dict )
7180
7281 def test_prior_set_has_correct_length (self ):
73- self .assertEqual (3 , len (self .prior_set_from_dict ))
82+ self .assertEqual (5 , len (self .prior_set_from_dict ))
7483
7584 def test_prior_set_has_expected_priors (self ):
7685 self .assertDictEqual (self .priors , dict (self .prior_set_from_dict ))
@@ -160,6 +169,12 @@ def test_to_file(self):
160169 "unit='m/s', boundary=None)\n " ,
161170 "mass = Uniform(minimum=0, maximum=1, name='a', latex_label='a', "
162171 "unit='kg', boundary=None)\n " ,
172+ "testa_testb_mvg = MultivariateGaussianDist(names=['testa', 'testb'], nmodes=1, mus=[[1, 1]], "
173+ "sigmas=[[1.4142135623730951, 1.4142135623730951]], "
174+ "corrcoefs=[[[0.9999999999999998, 0.24999999999999994], [0.24999999999999994, 0.9999999999999998]]], "
175+ "covs=[[[2.0, 0.5], [0.5, 2.0]]], weights=[1.0], bounds={'testa': (-inf, inf), 'testb': (-inf, inf)})\n " ,
176+ "testa = MultivariateGaussian(dist=testa_testb_mvg, name='testa', latex_label='testa', unit='unit')\n " ,
177+ "testb = MultivariateGaussian(dist=testa_testb_mvg, name='testb', latex_label='testb', unit='unit')\n " ,
163178 ]
164179 self .prior_set_from_dict .to_file (outdir = "prior_files" , label = "to_file_test" )
165180 with open ("prior_files/to_file_test.prior" ) as f :
@@ -178,6 +193,13 @@ def test_from_dict_with_string(self):
178193 self .assertDictEqual (self .prior_set_from_dict , from_dict )
179194
180195 def test_convert_floats_to_delta_functions (self ):
196+ mvg = bilby .core .prior .MultivariateGaussianDist (
197+ names = ["testa" , "testb" ],
198+ mus = [1 , 1 ],
199+ covs = np .array ([[2.0 , 0.5 ], [0.5 , 2.0 ]]),
200+ weights = 1.0 ,
201+ )
202+
181203 self .prior_set_from_dict ["d" ] = 5
182204 self .prior_set_from_dict ["e" ] = 7.3
183205 self .prior_set_from_dict ["f" ] = "unconvertable"
@@ -190,6 +212,8 @@ def test_convert_floats_to_delta_functions(self):
190212 name = "b" , alpha = 3 , minimum = 1 , maximum = 2 , unit = "m/s" , boundary = None
191213 ),
192214 length = bilby .core .prior .DeltaFunction (name = "c" , peak = 42 , unit = "m" ),
215+ testa = bilby .core .prior .MultivariateGaussian (dist = mvg , name = "testa" , unit = "unit" ),
216+ testb = bilby .core .prior .MultivariateGaussian (dist = mvg , name = "testb" , unit = "unit" ),
193217 d = bilby .core .prior .DeltaFunction (peak = 5 ),
194218 e = bilby .core .prior .DeltaFunction (peak = 7.3 ),
195219 f = "unconvertable" ,
@@ -321,12 +345,15 @@ def test_ln_prob(self):
321345 self .assertEqual (expected , self .prior_set_from_dict .ln_prob (samples ))
322346
323347 def test_rescale (self ):
324- theta = [0.5 , 0.5 , 0.5 ]
348+ theta = [0.5 , 0.5 , 0.5 , 0.5 , 0.5 ]
325349 expected = [
326350 self .first_prior .rescale (0.5 ),
327351 self .second_prior .rescale (0.5 ),
328352 self .third_prior .rescale (0.5 ),
353+ self .testa .rescale (0.5 ),
354+ self .testb .rescale (0.5 )
329355 ]
356+ assert not np .any (np .isnan (expected ))
330357 self .assertListEqual (
331358 sorted (expected ),
332359 sorted (
@@ -342,7 +369,7 @@ def test_cdf(self):
342369
343370 Note that the format of inputs/outputs is different between the two methods.
344371 """
345- sample = self .prior_set_from_dict .sample ( )
372+ sample = self .prior_set_from_dict .sample_subset ( keys = [ "length" , "speed" , "mass" ] )
346373 original = np .array (list (sample .values ()))
347374 new = np .array (self .prior_set_from_dict .rescale (
348375 sample .keys (),
0 commit comments