-
-
Notifications
You must be signed in to change notification settings - Fork 380
Expand file tree
/
Copy pathstandalone_gqs_test.cpp
More file actions
81 lines (73 loc) · 3 KB
/
standalone_gqs_test.cpp
File metadata and controls
81 lines (73 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include <boost/algorithm/string.hpp>
#include <gtest/gtest.h>
#include <iostream>
#include <stan/callbacks/stream_logger.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/io/json/json_data.hpp>
#include <stan/io/stan_csv_reader.hpp>
#include <stan/services/error_codes.hpp>
#include <stan/services/sample/standalone_gqs.hpp>
#include <test/test-models/good/services/bernoulli.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <test/unit/util.hpp>
#include <vector>
class ServicesStandaloneGQ : public ::testing::Test {
public:
ServicesStandaloneGQ()
: logger(logger_ss, logger_ss, logger_ss, logger_ss, logger_ss) {}
void SetUp() {
std::fstream data_stream(
"src/test/test-models/good/services/bernoulli.data.json",
std::fstream::in);
stan::json::json_data data_var_context(data_stream);
data_stream.close();
model = new stan_model(data_var_context);
}
void TearDown() { delete model; }
stan::test::unit::instrumented_interrupt interrupt;
std::stringstream logger_ss;
stan::callbacks::stream_logger logger;
stan_model *model;
};
/**
 * Parses the bernoulli fit CSV, feeds its `theta` column to
 * standalone_generate, and checks the return code and the written output.
 */
TEST_F(ServicesStandaloneGQ, genDraws_bernoulli) {
  // Index of the `theta` column in the fit CSV; verified against the header
  // below before it is used to slice the draws.
  const int theta_col = 7;

  std::stringstream out;
  std::ifstream csv_stream;
  csv_stream.open("src/test/test-models/good/services/bernoulli_fit.csv");
  // Surface a missing fixture file directly instead of via a parse failure.
  ASSERT_TRUE(csv_stream.is_open())
      << "cannot open src/test/test-models/good/services/bernoulli_fit.csv";
  stan::io::stan_csv bern_csv
      = stan::io::stan_csv_reader::parse(csv_stream, &out);
  csv_stream.close();

  // Sanity-check the parsed fixture.
  EXPECT_EQ(12345U, bern_csv.metadata.seed);
  // Unsigned literal: header.size() is unsigned, avoids a sign-compare.
  ASSERT_EQ(19U, bern_csv.header.size());
  EXPECT_EQ("theta", bern_csv.header[theta_col]);
  ASSERT_EQ(1000, bern_csv.samples.rows());
  ASSERT_EQ(19, bern_csv.samples.cols());

  std::stringstream sample_ss;
  stan::callbacks::stream_writer sample_writer(sample_ss, "");
  int return_code = stan::services::standalone_generate(
      *model, bern_csv.samples.middleCols<1>(theta_col), 12345, interrupt,
      logger, sample_writer);
  EXPECT_EQ(return_code, stan::services::error_codes::OK);

  // Take one copy of the output and reuse it for every check.
  const auto generated = sample_ss.str();
  EXPECT_EQ(count_matches("mu", generated), 1);
  EXPECT_EQ(count_matches("y_rep", generated), 10);
  EXPECT_EQ(count_matches("\n", generated), 1004);
  match_csv_columns(bern_csv.samples, generated, 1000, 1, 8);
}
/**
 * A default-constructed (empty) draws matrix must yield DATAERR, and
 * "Empty set of draws" must appear exactly once in the log.
 */
TEST_F(ServicesStandaloneGQ, genDraws_empty_draws) {
  std::stringstream gq_output;
  stan::callbacks::stream_writer gq_writer(gq_output, "");
  const Eigen::MatrixXd no_draws;

  const int status = stan::services::standalone_generate(
      *model, no_draws, 12345, interrupt, logger, gq_writer);

  EXPECT_EQ(status, stan::services::error_codes::DATAERR);
  const auto hits = count_matches("Empty set of draws", logger_ss.str());
  EXPECT_EQ(hits, 1);
}
/**
 * A 2x2 draws matrix must yield DATAERR, and "Wrong number of parameter
 * values" must appear exactly once in the log.
 */
TEST_F(ServicesStandaloneGQ, genDraws_bad) {
  // Zero-fill the matrix: the sized constructor leaves coefficients
  // uninitialized, and only the dimensions matter for this test.
  const Eigen::MatrixXd draws = Eigen::MatrixXd::Zero(2, 2);
  std::stringstream sample_ss;
  stan::callbacks::stream_writer sample_writer(sample_ss, "");
  int return_code = stan::services::standalone_generate(
      *model, draws, 12345, interrupt, logger, sample_writer);
  EXPECT_EQ(return_code, stan::services::error_codes::DATAERR);
  EXPECT_EQ(count_matches("Wrong number of parameter values", logger_ss.str()),
            1);
}