5858)
5959from cmdstanpy .utils import (
6060 EXTENSION ,
61- MaybeDictToFilePath ,
6261 SanitizedOrTmpFilePath ,
6362 cmdstan_path ,
6463 cmdstan_version ,
6766 get_logger ,
6867 returncode_msg ,
6968)
69+ from cmdstanpy .utils .filesystem import temp_inits , temp_single_json
7070
7171from . import progress as progbar
7272
@@ -573,7 +573,7 @@ def optimize(
573573 self ,
574574 data : Union [Mapping [str , Any ], str , os .PathLike , None ] = None ,
575575 seed : Optional [int ] = None ,
576- inits : Union [Dict [str , float ], float , str , os .PathLike , None ] = None ,
576+ inits : Union [Mapping [str , Any ], float , str , os .PathLike , None ] = None ,
577577 output_dir : OptionalPath = None ,
578578 sig_figs : Optional [int ] = None ,
579579 save_profile : bool = False ,
@@ -722,7 +722,9 @@ def optimize(
722722 "in CmdStan 2.32 and above."
723723 )
724724
725- with MaybeDictToFilePath (data , inits ) as (_data , _inits ):
725+ with temp_single_json (data ) as _data , temp_inits (
726+ inits , allow_multiple = False
727+ ) as _inits :
726728 args = CmdStanArgs (
727729 self ._name ,
728730 self ._exe_file ,
@@ -766,7 +768,14 @@ def sample(
766768 threads_per_chain : Optional [int ] = None ,
767769 seed : Union [int , List [int ], None ] = None ,
768770 chain_ids : Union [int , List [int ], None ] = None ,
769- inits : Union [Dict [str , float ], float , str , List [str ], None ] = None ,
771+ inits : Union [
772+ Mapping [str , Any ],
773+ float ,
774+ str ,
775+ List [str ],
776+ List [Mapping [str , Any ]],
777+ None ,
778+ ] = None ,
770779 iter_warmup : Optional [int ] = None ,
771780 iter_sampling : Optional [int ] = None ,
772781 save_warmup : bool = False ,
@@ -1006,6 +1015,69 @@ def sample(
10061015 chains
10071016 )
10081017 )
1018+
1019+ if parallel_chains is None :
1020+ parallel_chains = max (min (cpu_count (), chains ), 1 )
1021+ elif parallel_chains > chains :
1022+ get_logger ().info (
1023+ 'Requested %u parallel_chains but only %u required, '
1024+ 'will run all chains in parallel.' ,
1025+ parallel_chains ,
1026+ chains ,
1027+ )
1028+ parallel_chains = chains
1029+ elif parallel_chains < 1 :
1030+ raise ValueError (
1031+ 'Argument parallel_chains must be a positive integer, '
1032+ 'found {}.' .format (parallel_chains )
1033+ )
1034+ if threads_per_chain is None :
1035+ threads_per_chain = 1
1036+ if threads_per_chain < 1 :
1037+ raise ValueError (
1038+ 'Argument threads_per_chain must be a positive integer, '
1039+ 'found {}.' .format (threads_per_chain )
1040+ )
1041+
1042+ parallel_procs = parallel_chains
1043+ num_threads = threads_per_chain
1044+ one_process_per_chain = True
1045+ info_dict = self .exe_info ()
1046+ stan_threads = info_dict .get ('STAN_THREADS' , 'false' ).lower ()
1047+ # run multi-chain sampler unless algo is fixed_param or 1 chain
1048+ if fixed_param or (chains == 1 ):
1049+ force_one_process_per_chain = True
1050+
1051+ if (
1052+ force_one_process_per_chain is None
1053+ and not cmdstan_version_before (2 , 28 , info_dict )
1054+ and stan_threads == 'true'
1055+ ):
1056+ one_process_per_chain = False
1057+ num_threads = parallel_chains * num_threads
1058+ parallel_procs = 1
1059+ if force_one_process_per_chain is False :
1060+ if not cmdstan_version_before (2 , 28 , info_dict ):
1061+ one_process_per_chain = False
1062+ num_threads = parallel_chains * num_threads
1063+ parallel_procs = 1
1064+ if stan_threads == 'false' :
1065+ get_logger ().warning (
1066+ 'Stan program not compiled for threading, '
1067+ 'process will run chains sequentially. '
1068+ 'For multi-chain parallelization, recompile '
1069+ 'the model with argument '
1070+ '"cpp_options={\' STAN_THREADS\' :\' TRUE\' }.'
1071+ )
1072+ else :
1073+ get_logger ().warning (
1074+ 'Installed version of CmdStan cannot multi-process '
1075+ 'chains, will run %d processes. '
1076+ 'Run "install_cmdstan" to upgrade to latest version.' ,
1077+ chains ,
1078+ )
1079+ os .environ ['STAN_NUM_THREADS' ] = str (num_threads )
1080+
10091081 if chain_ids is None :
10101082 chain_ids = [i + 1 for i in range (chains )]
10111083 else :
@@ -1017,6 +1089,13 @@ def sample(
10171089 )
10181090 chain_ids = [i + chain_ids for i in range (chains )]
10191091 else :
1092+ if not one_process_per_chain :
1093+ for i , j in zip (chain_ids , chain_ids [1 :]):
1094+ if i != j - 1 :
1095+ raise ValueError (
1096+ 'chain_ids must be sequential list of integers,'
1097+ ' found {}.' .format (chain_ids )
1098+ )
10201099 if not len (chain_ids ) == chains :
10211100 raise ValueError (
10221101 'Chain_ids must correspond to number of chains'
@@ -1032,6 +1111,7 @@ def sample(
10321111 )
10331112
10341113 sampler_args = SamplerArgs (
1114+ num_chains = 1 if one_process_per_chain else chains ,
10351115 iter_warmup = iter_warmup ,
10361116 iter_sampling = iter_sampling ,
10371117 save_warmup = save_warmup ,
@@ -1046,14 +1126,25 @@ def sample(
10461126 adapt_step_size = adapt_step_size ,
10471127 fixed_param = fixed_param ,
10481128 )
1049- with MaybeDictToFilePath (data , inits ) as (_data , _inits ):
1129+
1130+ with temp_single_json (data ) as _data , temp_inits (
1131+ inits , id = chain_ids [0 ]
1132+ ) as _inits :
1133+ cmdstan_inits : Union [str , List [str ], int , float , None ]
1134+ if one_process_per_chain and isinstance (inits , list ): # legacy
1135+ cmdstan_inits = [
1136+ f"{ _inits [:- 5 ]} _{ i } .json" for i in chain_ids # type: ignore
1137+ ]
1138+ else :
1139+ cmdstan_inits = _inits
1140+
10501141 args = CmdStanArgs (
10511142 self ._name ,
10521143 self ._exe_file ,
10531144 chain_ids = chain_ids ,
10541145 data = _data ,
10551146 seed = seed ,
1056- inits = _inits ,
1147+ inits = cmdstan_inits ,
10571148 output_dir = output_dir ,
10581149 sig_figs = sig_figs ,
10591150 save_latent_dynamics = save_latent_dynamics ,
@@ -1062,68 +1153,6 @@ def sample(
10621153 refresh = refresh ,
10631154 )
10641155
1065- if parallel_chains is None :
1066- parallel_chains = max (min (cpu_count (), chains ), 1 )
1067- elif parallel_chains > chains :
1068- get_logger ().info (
1069- 'Requested %u parallel_chains but only %u required, '
1070- 'will run all chains in parallel.' ,
1071- parallel_chains ,
1072- chains ,
1073- )
1074- parallel_chains = chains
1075- elif parallel_chains < 1 :
1076- raise ValueError (
1077- 'Argument parallel_chains must be a positive integer, '
1078- 'found {}.' .format (parallel_chains )
1079- )
1080- if threads_per_chain is None :
1081- threads_per_chain = 1
1082- if threads_per_chain < 1 :
1083- raise ValueError (
1084- 'Argument threads_per_chain must be a positive integer, '
1085- 'found {}.' .format (threads_per_chain )
1086- )
1087-
1088- parallel_procs = parallel_chains
1089- num_threads = threads_per_chain
1090- one_process_per_chain = True
1091- info_dict = self .exe_info ()
1092- stan_threads = info_dict .get ('STAN_THREADS' , 'false' ).lower ()
1093- # run multi-chain sampler unless algo is fixed_param or 1 chain
1094- if fixed_param or (chains == 1 ):
1095- force_one_process_per_chain = True
1096-
1097- if (
1098- force_one_process_per_chain is None
1099- and not cmdstan_version_before (2 , 28 , info_dict )
1100- and stan_threads == 'true'
1101- ):
1102- one_process_per_chain = False
1103- num_threads = parallel_chains * num_threads
1104- parallel_procs = 1
1105- if force_one_process_per_chain is False :
1106- if not cmdstan_version_before (2 , 28 , info_dict ):
1107- one_process_per_chain = False
1108- num_threads = parallel_chains * num_threads
1109- parallel_procs = 1
1110- if stan_threads == 'false' :
1111- get_logger ().warning (
1112- 'Stan program not compiled for threading, '
1113- 'process will run chains sequentially. '
1114- 'For multi-chain parallelization, recompile '
1115- 'the model with argument '
1116- '"cpp_options={\' STAN_THREADS\' :\' TRUE\' }.'
1117- )
1118- else :
1119- get_logger ().warning (
1120- 'Installed version of CmdStan cannot multi-process '
1121- 'chains, will run %d processes. '
1122- 'Run "install_cmdstan" to upgrade to latest version.' ,
1123- chains ,
1124- )
1125- os .environ ['STAN_NUM_THREADS' ] = str (num_threads )
1126-
11271156 if show_console :
11281157 show_progress = False
11291158 else :
@@ -1376,7 +1405,7 @@ def generate_quantities(
13761405 csv_files = fit_csv_files
13771406 )
13781407 generate_quantities_args .validate (chains )
1379- with MaybeDictToFilePath (data , None ) as ( _data , _inits ) :
1408+ with temp_single_json (data ) as _data :
13801409 args = CmdStanArgs (
13811410 self ._name ,
13821411 self ._exe_file ,
@@ -1551,7 +1580,9 @@ def variational(
15511580 output_samples = output_samples ,
15521581 )
15531582
1554- with MaybeDictToFilePath (data , inits ) as (_data , _inits ):
1583+ with temp_single_json (data ) as _data , temp_inits (
1584+ inits , allow_multiple = False
1585+ ) as _inits :
15551586 args = CmdStanArgs (
15561587 self ._name ,
15571588 self ._exe_file ,
@@ -1658,7 +1689,9 @@ def log_prob(
16581689 "Method 'log_prob' not available for CmdStan versions "
16591690 "before 2.31"
16601691 )
1661- with MaybeDictToFilePath (data , params ) as (_data , _params ):
1692+ with temp_single_json (data ) as _data , temp_single_json (
1693+ params
1694+ ) as _params :
16621695 cmd = [
16631696 str (self .exe_file ),
16641697 "log_prob" ,
@@ -1766,7 +1799,7 @@ def laplace_sample(
17661799 cmdstan_mode .runset .csv_files [0 ], draws , jacobian
17671800 )
17681801
1769- with MaybeDictToFilePath (data ) as ( _data ,) :
1802+ with temp_single_json (data ) as _data :
17701803 args = CmdStanArgs (
17711804 self ._name ,
17721805 self ._exe_file ,
0 commit comments