Skip to content

Commit 195c3b5

Browse files
committed
rhat now checks that there are two chains. Fixes #533.
1 parent 03fabcc commit 195c3b5

2 files changed

Lines changed: 18 additions & 12 deletions

File tree

pints/_diagnostics.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -201,28 +201,29 @@ def rhat(chains, warm_up=0.0):
201201
"""
202202
if not (chains.ndim == 2 or chains.ndim == 3):
203203
raise ValueError(
204-
'Dimension of chains is %d. ' % chains.ndim
205-
+ 'Method computes Rhat for one '
206-
'or multiple parameters and therefore only accepts 2 or 3 '
207-
'dimensional arrays.')
204+
f'Dimension of chains is {chains.ndim}. This method computes Rhat'
205+
' for one or multiple parameters and therefore only accepts 2 or 3'
206+
' dimensional arrays.')
208207
if warm_up > 1 or warm_up < 0:
209208
raise ValueError(
210-
'`warm_up` is set to %f. `warm_up` only takes values in [0,1].' %
211-
warm_up)
209+
f'`warm_up` is set to {warm_up}. `warm_up` only takes values in'
210+
' [0,1].')
211+
if chains.shape[0] < 2:
212+
raise ValueError('Number of chains needs to be 2 or higher')
212213

213214
# Get number of samples
214215
n = chains.shape[1]
216+
if n < 2:
217+
raise ValueError(
218+
'Number of samples per chain after warm-up and chain splitting is '
219+
f'{n}. Method needs at least 2 samples per chain.')
220+
215221

216222
# Exclude warm-up
217223
chains = chains[:, int(n * warm_up):]
218-
n = chains.shape[1]
219224

220225
# Split chains in half
221226
n = n // 2 # new length of chains
222-
if n < 1:
223-
raise ValueError(
224-
'Number of samples per chain after warm-up and chain splitting is '
225-
'%d. Method needs at least 1 sample per chain.' % n)
226227
chains = np.vstack([chains[:, :n], chains[:, -n:]])
227228

228229
# Compute mean within-chain variance

pints/tests/test_diagnostics.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,12 @@ def test_bad_rhat_inputs(self):
138138
# Pass only a single chain
139139
chains = np.empty(shape=(1, 5))
140140
self.assertRaisesRegex(
141-
ValueError, 'only accepts 2 or 3 dimensional', pints.rhat, chains)
141+
ValueError, '2 or higher', pints.rhat, chains)
142+
143+
# Pass only a single sample
144+
chains = np.empty(shape=(5, 1))
145+
self.assertRaisesRegex(
146+
ValueError, 'at least 2 samples', pints.rhat, chains)
142147

143148
# Pass bad warm-up arguments
144149
chains = np.empty(shape=(2, 4))

0 commit comments

Comments
 (0)