Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/stan/services/util/initialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ namespace util {
*
* When at least some of the initialization is random, it will
* randomly initialize until it finds a set of unconstrained
* parameters that are valid or it hits <code>MAX_INIT_TRIES =
* 100</code> (hard-coded).
* parameters that are valid or it hits <code>max_tries</code>,
* which defaults to 100.
*
* Valid initialization is defined as a finite, non-NaN value for the
* evaluation of the log probability density function and all its
Expand All @@ -58,6 +58,8 @@ namespace util {
* be printed to the logger
* @param[in,out] logger logger for messages
* @param[in,out] init_writer init writer (on the unconstrained scale)
* @param[in] max_tries The maximum number of times a random initialization
* will be re-tried to achieve a finite log density and gradient. Default 100.
* @throws exception passed through from the model if the model has a
* fatal error (not a std::domain_error)
* @throws std::domain_error if the model can not be initialized and
Expand All @@ -70,7 +72,8 @@ template <bool Jacobian = true, typename Model, typename InitContext,
std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
double init_radius, bool print_timing,
stan::callbacks::logger& logger,
stan::callbacks::writer& init_writer) {
stan::callbacks::writer& init_writer,
int max_tries = 100) {
std::vector<double> unconstrained;
std::vector<int> disc_vector;

Expand All @@ -86,7 +89,7 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
bool is_initialized_with_zero = init_radius == 0.0;

int MAX_INIT_TRIES
= is_fully_initialized || is_initialized_with_zero ? 1 : 100;
= is_fully_initialized || is_initialized_with_zero ? 1 : max_tries;
int num_init_tries = 0;
for (; num_init_tries < MAX_INIT_TRIES; num_init_tries++) {
std::stringstream msg;
Expand Down
100 changes: 100 additions & 0 deletions src/test/unit/services/util/initialize_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
#include <stan/services/util/create_rng.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/callbacks/stream_logger.hpp>
#include <stan/model/prob_grad.hpp>
#include <stdexcept>
#include <test/test-models/good/services/test_lp.hpp>
#include <test/unit/util.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <gtest/gtest.h>
#include <sstream>
#include <limits>

class ServicesUtilInitialize : public testing::Test {
public:
Expand Down Expand Up @@ -577,3 +580,100 @@ TEST_F(ServicesUtilInitialize, model_throws_in_write_array__full_init) {
EXPECT_EQ(2, logger.call_count_error());
EXPECT_EQ(100, logger.find_warn("throwing within write_array"));
}

namespace test {
// Mock model that returns non-finite lp
class mock_inf_model : public stan::model::prob_grad {
public:
mock_inf_model()
: stan::model::prob_grad(1),
templated_log_prob_calls(0),
transform_inits_calls(0),
write_array_calls(0),
log_prob_return_value(-std::numeric_limits<double>::infinity()) {}

void reset() {
templated_log_prob_calls = 0;
transform_inits_calls = 0;
write_array_calls = 0;
log_prob_return_value = 0.0;
}

template <bool propto__, bool jacobian__, typename T__>
T__ log_prob(std::vector<T__>& params_r__, std::vector<int>& params_i__,
std::ostream* pstream__ = 0) const {
++templated_log_prob_calls;
return log_prob_return_value;
}

void transform_inits(const stan::io::var_context& context__,
std::vector<int>& params_i__,
std::vector<double>& params_r__,
std::ostream* pstream__) const {
++transform_inits_calls;
for (size_t n = 0; n < params_r__.size(); ++n) {
params_r__[n] = n;
}
}

void get_dims(std::vector<std::vector<size_t> >& dimss__,
bool include_tparams = true, bool include_gqs = true) const {
dimss__.resize(0);
std::vector<size_t> scalar_dim;
dimss__.push_back(scalar_dim);
}

void constrained_param_names(std::vector<std::string>& param_names__,
bool include_tparams__ = true,
bool include_gqs__ = true) const {
param_names__.push_back("theta");
}

void get_param_names(std::vector<std::string>& names,
bool include_tparams = true,
bool include_gqs = true) const {
constrained_param_names(names);
}

void unconstrained_param_names(std::vector<std::string>& param_names__,
bool include_tparams__ = true,
bool include_gqs__ = true) const {
param_names__.clear();
for (size_t n = 0; n < num_params_r__; ++n) {
std::stringstream param_name;
param_name << "param_" << n;
param_names__.push_back(param_name.str());
}
}
template <typename RNG>
void write_array(RNG& base_rng__, std::vector<double>& params_r__,
std::vector<int>& params_i__, std::vector<double>& vars__,
bool include_tparams__ = true, bool include_gqs__ = true,
std::ostream* pstream__ = 0) const {
++write_array_calls;
vars__.resize(0);
for (size_t i = 0; i < params_r__.size(); ++i)
vars__.push_back(params_r__[i]);
}

mutable int templated_log_prob_calls;
mutable int transform_inits_calls;
mutable int write_array_calls;
double log_prob_return_value;
};
} // namespace test

TEST_F(ServicesUtilInitialize, model_errors_retries) {
test::mock_inf_model error_model;

double init_radius = 1.2;
bool print_timing = false;
EXPECT_THROW_MSG(stan::services::util::initialize(
error_model, empty_context, rng, init_radius,
print_timing, logger, init, 23),
std::domain_error, "Initialization failed.");
EXPECT_EQ(23, error_model.templated_log_prob_calls);
EXPECT_EQ(72, logger.call_count());
EXPECT_EQ(2, logger.call_count_error());
EXPECT_EQ(1, logger.find_error("after 23 attempts"));
}
Loading