Skip to content

Commit 107a347

Browse files
authored
Merge pull request #688 from stan-dev/feature/multiple-inits
Allow multiple dicts in inits, fix multichain
2 parents 7cf6483 + e27e387 commit 107a347

7 files changed

Lines changed: 266 additions & 151 deletions

File tree

cmdstanpy/cmdstan_args.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
adapt_metric_window: Optional[int] = None,
5656
adapt_step_size: Optional[int] = None,
5757
fixed_param: bool = False,
58+
num_chains: int = 1,
5859
) -> None:
5960
"""Initialize object."""
6061
self.iter_warmup = iter_warmup
@@ -73,6 +74,7 @@ def __init__(
7374
self.adapt_step_size = adapt_step_size
7475
self.fixed_param = fixed_param
7576
self.diagnostic_file = None
77+
self.num_chains = num_chains
7678

7779
def validate(self, chains: Optional[int]) -> None:
7880
"""
@@ -316,6 +318,10 @@ def validate(self, chains: Optional[int]) -> None:
316318
'Argument "adapt_step_size" must be a non-negative integer,'
317319
'found {}'.format(self.adapt_step_size)
318320
)
321+
if self.num_chains < 1 or not isinstance(
322+
self.num_chains, (int, np.integer)
323+
):
324+
raise ValueError("num_chains must be positive")
319325

320326
if self.fixed_param and (
321327
self.max_treedepth is not None
@@ -378,6 +384,8 @@ def compose(self, idx: int, cmd: List[str]) -> List[str]:
378384
cmd.append('window={}'.format(self.adapt_metric_window))
379385
if self.adapt_step_size is not None:
380386
cmd.append('term_buffer={}'.format(self.adapt_step_size))
387+
if self.num_chains > 1:
388+
cmd.append('num_chains={}'.format(self.num_chains))
381389

382390
return cmd
383391

@@ -921,8 +929,12 @@ def validate(self) -> None:
921929
)
922930
)
923931
elif isinstance(self.inits, str):
924-
if not os.path.exists(self.inits):
925-
raise ValueError('no such file {}'.format(self.inits))
932+
if not (
933+
isinstance(self.method_args, SamplerArgs)
934+
and self.method_args.num_chains > 1
935+
):
936+
if not os.path.exists(self.inits):
937+
raise ValueError('no such file {}'.format(self.inits))
926938
elif isinstance(self.inits, list):
927939
if self.chain_ids is None:
928940
raise ValueError(
@@ -948,7 +960,6 @@ def compose_command(
948960
*,
949961
diagnostic_file: Optional[str] = None,
950962
profile_file: Optional[str] = None,
951-
num_chains: Optional[int] = None,
952963
) -> List[str]:
953964
"""
954965
Compose CmdStan command for non-default arguments.
@@ -992,6 +1003,4 @@ def compose_command(
9921003
if self.sig_figs is not None:
9931004
cmd.append('sig_figs={}'.format(self.sig_figs))
9941005
cmd = self.method_args.compose(idx, cmd)
995-
if num_chains:
996-
cmd.append('num_chains={}'.format(num_chains))
9971006
return cmd

cmdstanpy/model.py

Lines changed: 105 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
)
5959
from cmdstanpy.utils import (
6060
EXTENSION,
61-
MaybeDictToFilePath,
6261
SanitizedOrTmpFilePath,
6362
cmdstan_path,
6463
cmdstan_version,
@@ -67,6 +66,7 @@
6766
get_logger,
6867
returncode_msg,
6968
)
69+
from cmdstanpy.utils.filesystem import temp_inits, temp_single_json
7070

7171
from . import progress as progbar
7272

@@ -573,7 +573,7 @@ def optimize(
573573
self,
574574
data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
575575
seed: Optional[int] = None,
576-
inits: Union[Dict[str, float], float, str, os.PathLike, None] = None,
576+
inits: Union[Mapping[str, Any], float, str, os.PathLike, None] = None,
577577
output_dir: OptionalPath = None,
578578
sig_figs: Optional[int] = None,
579579
save_profile: bool = False,
@@ -722,7 +722,9 @@ def optimize(
722722
"in CmdStan 2.32 and above."
723723
)
724724

725-
with MaybeDictToFilePath(data, inits) as (_data, _inits):
725+
with temp_single_json(data) as _data, temp_inits(
726+
inits, allow_multiple=False
727+
) as _inits:
726728
args = CmdStanArgs(
727729
self._name,
728730
self._exe_file,
@@ -766,7 +768,14 @@ def sample(
766768
threads_per_chain: Optional[int] = None,
767769
seed: Union[int, List[int], None] = None,
768770
chain_ids: Union[int, List[int], None] = None,
769-
inits: Union[Dict[str, float], float, str, List[str], None] = None,
771+
inits: Union[
772+
Mapping[str, Any],
773+
float,
774+
str,
775+
List[str],
776+
List[Mapping[str, Any]],
777+
None,
778+
] = None,
770779
iter_warmup: Optional[int] = None,
771780
iter_sampling: Optional[int] = None,
772781
save_warmup: bool = False,
@@ -1006,6 +1015,69 @@ def sample(
10061015
chains
10071016
)
10081017
)
1018+
1019+
if parallel_chains is None:
1020+
parallel_chains = max(min(cpu_count(), chains), 1)
1021+
elif parallel_chains > chains:
1022+
get_logger().info(
1023+
'Requested %u parallel_chains but only %u required, '
1024+
'will run all chains in parallel.',
1025+
parallel_chains,
1026+
chains,
1027+
)
1028+
parallel_chains = chains
1029+
elif parallel_chains < 1:
1030+
raise ValueError(
1031+
'Argument parallel_chains must be a positive integer, '
1032+
'found {}.'.format(parallel_chains)
1033+
)
1034+
if threads_per_chain is None:
1035+
threads_per_chain = 1
1036+
if threads_per_chain < 1:
1037+
raise ValueError(
1038+
'Argument threads_per_chain must be a positive integer, '
1039+
'found {}.'.format(threads_per_chain)
1040+
)
1041+
1042+
parallel_procs = parallel_chains
1043+
num_threads = threads_per_chain
1044+
one_process_per_chain = True
1045+
info_dict = self.exe_info()
1046+
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
1047+
# run multi-chain sampler unless algo is fixed_param or 1 chain
1048+
if fixed_param or (chains == 1):
1049+
force_one_process_per_chain = True
1050+
1051+
if (
1052+
force_one_process_per_chain is None
1053+
and not cmdstan_version_before(2, 28, info_dict)
1054+
and stan_threads == 'true'
1055+
):
1056+
one_process_per_chain = False
1057+
num_threads = parallel_chains * num_threads
1058+
parallel_procs = 1
1059+
if force_one_process_per_chain is False:
1060+
if not cmdstan_version_before(2, 28, info_dict):
1061+
one_process_per_chain = False
1062+
num_threads = parallel_chains * num_threads
1063+
parallel_procs = 1
1064+
if stan_threads == 'false':
1065+
get_logger().warning(
1066+
'Stan program not compiled for threading, '
1067+
'process will run chains sequentially. '
1068+
'For multi-chain parallelization, recompile '
1069+
'the model with argument '
1070+
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
1071+
)
1072+
else:
1073+
get_logger().warning(
1074+
'Installed version of CmdStan cannot multi-process '
1075+
'chains, will run %d processes. '
1076+
'Run "install_cmdstan" to upgrade to latest version.',
1077+
chains,
1078+
)
1079+
os.environ['STAN_NUM_THREADS'] = str(num_threads)
1080+
10091081
if chain_ids is None:
10101082
chain_ids = [i + 1 for i in range(chains)]
10111083
else:
@@ -1017,6 +1089,13 @@ def sample(
10171089
)
10181090
chain_ids = [i + chain_ids for i in range(chains)]
10191091
else:
1092+
if not one_process_per_chain:
1093+
for i, j in zip(chain_ids, chain_ids[1:]):
1094+
if i != j - 1:
1095+
raise ValueError(
1096+
'chain_ids must be sequential list of integers,'
1097+
' found {}.'.format(chain_ids)
1098+
)
10201099
if not len(chain_ids) == chains:
10211100
raise ValueError(
10221101
'Chain_ids must correspond to number of chains'
@@ -1032,6 +1111,7 @@ def sample(
10321111
)
10331112

10341113
sampler_args = SamplerArgs(
1114+
num_chains=1 if one_process_per_chain else chains,
10351115
iter_warmup=iter_warmup,
10361116
iter_sampling=iter_sampling,
10371117
save_warmup=save_warmup,
@@ -1046,14 +1126,25 @@ def sample(
10461126
adapt_step_size=adapt_step_size,
10471127
fixed_param=fixed_param,
10481128
)
1049-
with MaybeDictToFilePath(data, inits) as (_data, _inits):
1129+
1130+
with temp_single_json(data) as _data, temp_inits(
1131+
inits, id=chain_ids[0]
1132+
) as _inits:
1133+
cmdstan_inits: Union[str, List[str], int, float, None]
1134+
if one_process_per_chain and isinstance(inits, list): # legacy
1135+
cmdstan_inits = [
1136+
f"{_inits[:-5]}_{i}.json" for i in chain_ids # type: ignore
1137+
]
1138+
else:
1139+
cmdstan_inits = _inits
1140+
10501141
args = CmdStanArgs(
10511142
self._name,
10521143
self._exe_file,
10531144
chain_ids=chain_ids,
10541145
data=_data,
10551146
seed=seed,
1056-
inits=_inits,
1147+
inits=cmdstan_inits,
10571148
output_dir=output_dir,
10581149
sig_figs=sig_figs,
10591150
save_latent_dynamics=save_latent_dynamics,
@@ -1062,68 +1153,6 @@ def sample(
10621153
refresh=refresh,
10631154
)
10641155

1065-
if parallel_chains is None:
1066-
parallel_chains = max(min(cpu_count(), chains), 1)
1067-
elif parallel_chains > chains:
1068-
get_logger().info(
1069-
'Requested %u parallel_chains but only %u required, '
1070-
'will run all chains in parallel.',
1071-
parallel_chains,
1072-
chains,
1073-
)
1074-
parallel_chains = chains
1075-
elif parallel_chains < 1:
1076-
raise ValueError(
1077-
'Argument parallel_chains must be a positive integer, '
1078-
'found {}.'.format(parallel_chains)
1079-
)
1080-
if threads_per_chain is None:
1081-
threads_per_chain = 1
1082-
if threads_per_chain < 1:
1083-
raise ValueError(
1084-
'Argument threads_per_chain must be a positive integer, '
1085-
'found {}.'.format(threads_per_chain)
1086-
)
1087-
1088-
parallel_procs = parallel_chains
1089-
num_threads = threads_per_chain
1090-
one_process_per_chain = True
1091-
info_dict = self.exe_info()
1092-
stan_threads = info_dict.get('STAN_THREADS', 'false').lower()
1093-
# run multi-chain sampler unless algo is fixed_param or 1 chain
1094-
if fixed_param or (chains == 1):
1095-
force_one_process_per_chain = True
1096-
1097-
if (
1098-
force_one_process_per_chain is None
1099-
and not cmdstan_version_before(2, 28, info_dict)
1100-
and stan_threads == 'true'
1101-
):
1102-
one_process_per_chain = False
1103-
num_threads = parallel_chains * num_threads
1104-
parallel_procs = 1
1105-
if force_one_process_per_chain is False:
1106-
if not cmdstan_version_before(2, 28, info_dict):
1107-
one_process_per_chain = False
1108-
num_threads = parallel_chains * num_threads
1109-
parallel_procs = 1
1110-
if stan_threads == 'false':
1111-
get_logger().warning(
1112-
'Stan program not compiled for threading, '
1113-
'process will run chains sequentially. '
1114-
'For multi-chain parallelization, recompile '
1115-
'the model with argument '
1116-
'"cpp_options={\'STAN_THREADS\':\'TRUE\'}.'
1117-
)
1118-
else:
1119-
get_logger().warning(
1120-
'Installed version of CmdStan cannot multi-process '
1121-
'chains, will run %d processes. '
1122-
'Run "install_cmdstan" to upgrade to latest version.',
1123-
chains,
1124-
)
1125-
os.environ['STAN_NUM_THREADS'] = str(num_threads)
1126-
11271156
if show_console:
11281157
show_progress = False
11291158
else:
@@ -1376,7 +1405,7 @@ def generate_quantities(
13761405
csv_files=fit_csv_files
13771406
)
13781407
generate_quantities_args.validate(chains)
1379-
with MaybeDictToFilePath(data, None) as (_data, _inits):
1408+
with temp_single_json(data) as _data:
13801409
args = CmdStanArgs(
13811410
self._name,
13821411
self._exe_file,
@@ -1551,7 +1580,9 @@ def variational(
15511580
output_samples=output_samples,
15521581
)
15531582

1554-
with MaybeDictToFilePath(data, inits) as (_data, _inits):
1583+
with temp_single_json(data) as _data, temp_inits(
1584+
inits, allow_multiple=False
1585+
) as _inits:
15551586
args = CmdStanArgs(
15561587
self._name,
15571588
self._exe_file,
@@ -1658,7 +1689,9 @@ def log_prob(
16581689
"Method 'log_prob' not available for CmdStan versions "
16591690
"before 2.31"
16601691
)
1661-
with MaybeDictToFilePath(data, params) as (_data, _params):
1692+
with temp_single_json(data) as _data, temp_single_json(
1693+
params
1694+
) as _params:
16621695
cmd = [
16631696
str(self.exe_file),
16641697
"log_prob",
@@ -1766,7 +1799,7 @@ def laplace_sample(
17661799
cmdstan_mode.runset.csv_files[0], draws, jacobian
17671800
)
17681801

1769-
with MaybeDictToFilePath(data) as (_data,):
1802+
with temp_single_json(data) as _data:
17701803
args = CmdStanArgs(
17711804
self._name,
17721805
self._exe_file,

cmdstanpy/stanfit/runset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def cmd(self, idx: int) -> List[str]:
179179
profile_file=self.file_path(".csv", extra="-profile")
180180
if self._args.save_profile
181181
else None,
182-
num_chains=self._chains,
183182
)
184183

185184
@property

cmdstanpy/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from .command import do_command, returncode_msg
2323
from .data_munging import build_xarray_data, flatten_chains
2424
from .filesystem import (
25-
MaybeDictToFilePath,
2625
SanitizedOrTmpFilePath,
2726
create_named_text_file,
2827
pushd,
@@ -116,7 +115,6 @@ def show_versions(output: bool = True) -> str:
116115
__all__ = [
117116
'BaseType',
118117
'EXTENSION',
119-
'MaybeDictToFilePath',
120118
'SanitizedOrTmpFilePath',
121119
'build_xarray_data',
122120
'check_sampler_csv',

0 commit comments

Comments
 (0)