-
-
Notifications
You must be signed in to change notification settings - Fork 68
Expand file tree
/
Copy pathhessian.cpp
More file actions
41 lines (34 loc) · 1.25 KB
/
hessian.cpp
File metadata and controls
41 lines (34 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include <rcpp_eigen_interop.hpp>
#include <stan/model/hessian.hpp>
// Adapts a model's templated log_prob member to the unary functor shape
// consumed by stan::math::hessian: a const call operator taking a column
// vector of scalar type T and returning a T.
//
// The model is held by reference only, so the referenced model must outlive
// this wrapper.
//
// The first log_prob template flag is always `true`. The second flag is
// chosen at run time from `jac_adjust`; presumably these are Stan's
// propto / jacobian switches (confirm against the model's log_prob
// declaration). `o` is forwarded unchanged as log_prob's stream argument.
template <class M>
struct hessian_wrapper {
  const M& model;
  bool jac_adjust;
  std::ostream* o;

  hessian_wrapper(const M& m, bool adj, std::ostream* out)
      : model(m), jac_adjust(adj), o(out) {}

  // Evaluates the model's log_prob at `x`. jac_adjust only selects which
  // compile-time instantiation of log_prob is used.
  template <typename T>
  T operator()(const Eigen::Matrix<T, Eigen::Dynamic, 1>& x) const {
    return jac_adjust ? eval_log_prob<true>(x) : eval_log_prob<false>(x);
  }

 private:
  // Single call site for both instantiations.
  template <bool Jacobian, typename T>
  T eval_log_prob(const Eigen::Matrix<T, Eigen::Dynamic, 1>& x) const {
    // log_prob() takes its parameter vector by non-const reference, although
    // it does not modify it, so dropping const here is safe.
    auto& params = const_cast<Eigen::Matrix<T, Eigen::Dynamic, 1>&>(x);
    return model.template log_prob<true, Jacobian, T>(params, o);
  }
};
// Computes, in a single stan::math::hessian call, the model's log_prob
// value, its gradient, and its Hessian at the parameter vector `upars`.
//
// Parameters:
//   ext_model_ptr - R external pointer wrapping a `stan_model` instance.
//   upars         - point of evaluation (the name suggests unconstrained
//                   parameter values; confirm with callers).
//   jac_adjust    - forwarded to hessian_wrapper; it selects the
//                   log_prob<true, true> or log_prob<true, false>
//                   instantiation.
//
// Returns an R list with elements "log_prob" (double), "grad_log_prob"
// (vector), and "hessian" (matrix).
//
// Throws Rcpp::exception if the external pointer no longer refers to a model.
// `*ptr` goes through XPtr::checked_get(), whereas `*ptr.get()` would
// dereference a null pointer.
// [[Rcpp::export]]
Rcpp::List hessian(SEXP ext_model_ptr, Eigen::VectorXd& upars, bool jac_adjust) {
  Rcpp::XPtr<stan_model> ptr(ext_model_ptr);

  // Name the model type explicitly. decltype(*ptr.get()) is `stan_model&`.
  // Instantiating hessian_wrapper with a reference type collapses its
  // `const M&` member to a plain `stan_model&`, which silently drops const.
  // No output stream is supplied (nullptr).
  const hessian_wrapper<stan_model> log_prob_fn(*ptr, jac_adjust, nullptr);

  double log_prob = 0.0;
  Eigen::VectorXd grad;
  // Named `hess` so it does not shadow this function's own name.
  Eigen::MatrixXd hess;
  stan::math::hessian(log_prob_fn, upars, log_prob, grad, hess);

  return Rcpp::List::create(
      Rcpp::Named("log_prob") = log_prob,
      Rcpp::Named("grad_log_prob") = grad,
      Rcpp::Named("hessian") = hess);
}