diff --git a/src/stan/services/util/initialize.hpp b/src/stan/services/util/initialize.hpp
index a6467ae394d..c52c331b449 100644
--- a/src/stan/services/util/initialize.hpp
+++ b/src/stan/services/util/initialize.hpp
@@ -36,8 +36,8 @@ namespace util {
*
* When at least some of the initialization is random, it will
* randomly initialize until it finds a set of unconstrained
- * parameters that are valid or it hits MAX_INIT_TRIES =
- * 100 (hard-coded).
+ * parameters that are valid or it hits max_tries,
+ * which defaults to 100.
*
* Valid initialization is defined as a finite, non-NaN value for the
* evaluation of the log probability density function and all its
@@ -58,6 +58,8 @@ namespace util {
* be printed to the logger
* @param[in,out] logger logger for messages
* @param[in,out] init_writer init writer (on the unconstrained scale)
+ * @param[in] max_tries The maximum number of times a random initialization
+ * will be re-tried to achieve a finite log density and gradient. Default 100.
* @throws exception passed through from the model if the model has a
* fatal error (not a std::domain_error)
* @throws std::domain_error if the model can not be initialized and
@@ -70,7 +72,8 @@ template initialize(Model& model, const InitContext& init, RNG& rng,
double init_radius, bool print_timing,
stan::callbacks::logger& logger,
- stan::callbacks::writer& init_writer) {
+ stan::callbacks::writer& init_writer,
+ int max_tries = 100) {
std::vector unconstrained;
std::vector disc_vector;
@@ -86,7 +89,7 @@ std::vector initialize(Model& model, const InitContext& init, RNG& rng,
bool is_initialized_with_zero = init_radius == 0.0;
int MAX_INIT_TRIES
- = is_fully_initialized || is_initialized_with_zero ? 1 : 100;
+ = is_fully_initialized || is_initialized_with_zero ? 1 : max_tries;
int num_init_tries = 0;
for (; num_init_tries < MAX_INIT_TRIES; num_init_tries++) {
std::stringstream msg;
diff --git a/src/test/unit/services/util/initialize_test.cpp b/src/test/unit/services/util/initialize_test.cpp
index e268d9c20bc..39b1def24f6 100644
--- a/src/test/unit/services/util/initialize_test.cpp
+++ b/src/test/unit/services/util/initialize_test.cpp
@@ -5,11 +5,14 @@
#include
#include
#include
+#include
+#include
#include
#include
#include
#include
#include
+#include
class ServicesUtilInitialize : public testing::Test {
public:
@@ -577,3 +580,100 @@ TEST_F(ServicesUtilInitialize, model_throws_in_write_array__full_init) {
EXPECT_EQ(2, logger.call_count_error());
EXPECT_EQ(100, logger.find_warn("throwing within write_array"));
}
+
+namespace test {
+// Mock model that returns non-finite lp
+class mock_inf_model : public stan::model::prob_grad {
+ public:
+ mock_inf_model()
+ : stan::model::prob_grad(1),
+ templated_log_prob_calls(0),
+ transform_inits_calls(0),
+ write_array_calls(0),
+ log_prob_return_value(-std::numeric_limits::infinity()) {}
+
+ void reset() {
+ templated_log_prob_calls = 0;
+ transform_inits_calls = 0;
+ write_array_calls = 0;
+ log_prob_return_value = 0.0;
+ }
+
+ template
+ T__ log_prob(std::vector& params_r__, std::vector& params_i__,
+ std::ostream* pstream__ = 0) const {
+ ++templated_log_prob_calls;
+ return log_prob_return_value;
+ }
+
+ void transform_inits(const stan::io::var_context& context__,
+ std::vector& params_i__,
+ std::vector& params_r__,
+ std::ostream* pstream__) const {
+ ++transform_inits_calls;
+ for (size_t n = 0; n < params_r__.size(); ++n) {
+ params_r__[n] = n;
+ }
+ }
+
+ void get_dims(std::vector >& dimss__,
+ bool include_tparams = true, bool include_gqs = true) const {
+ dimss__.resize(0);
+ std::vector scalar_dim;
+ dimss__.push_back(scalar_dim);
+ }
+
+ void constrained_param_names(std::vector& param_names__,
+ bool include_tparams__ = true,
+ bool include_gqs__ = true) const {
+ param_names__.push_back("theta");
+ }
+
+ void get_param_names(std::vector& names,
+ bool include_tparams = true,
+ bool include_gqs = true) const {
+ constrained_param_names(names);
+ }
+
+ void unconstrained_param_names(std::vector& param_names__,
+ bool include_tparams__ = true,
+ bool include_gqs__ = true) const {
+ param_names__.clear();
+ for (size_t n = 0; n < num_params_r__; ++n) {
+ std::stringstream param_name;
+ param_name << "param_" << n;
+ param_names__.push_back(param_name.str());
+ }
+ }
+ template
+ void write_array(RNG& base_rng__, std::vector& params_r__,
+ std::vector& params_i__, std::vector& vars__,
+ bool include_tparams__ = true, bool include_gqs__ = true,
+ std::ostream* pstream__ = 0) const {
+ ++write_array_calls;
+ vars__.resize(0);
+ for (size_t i = 0; i < params_r__.size(); ++i)
+ vars__.push_back(params_r__[i]);
+ }
+
+ mutable int templated_log_prob_calls;
+ mutable int transform_inits_calls;
+ mutable int write_array_calls;
+ double log_prob_return_value;
+};
+} // namespace test
+
+TEST_F(ServicesUtilInitialize, model_errors_retries) {
+ test::mock_inf_model error_model;
+
+ double init_radius = 1.2;
+ bool print_timing = false;
+ EXPECT_THROW_MSG(stan::services::util::initialize(
+ error_model, empty_context, rng, init_radius,
+ print_timing, logger, init, 23),
+ std::domain_error, "Initialization failed.");
+ EXPECT_EQ(23, error_model.templated_log_prob_calls);
+ EXPECT_EQ(72, logger.call_count());
+ EXPECT_EQ(2, logger.call_count_error());
+ EXPECT_EQ(1, logger.find_error("after 23 attempts"));
+}