1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_INV_QUAD_FORM_LDLT_HPP
8 #include <boost/utility/enable_if.hpp>
15 template <
typename T2,
int R2,
int C2,
typename T3,
int R3,
int C3>
16 class trace_inv_quad_form_ldlt_impl :
public chainable_alloc {
18 inline void initializeB(
const Eigen::Matrix<var, R3, C3> &B,
20 Eigen::Matrix<double, R3, C3> Bd(B.rows(), B.cols());
21 _variB.resize(B.rows(), B.cols());
22 for (
int j = 0; j < B.cols(); j++) {
23 for (
int i = 0; i < B.rows(); i++) {
24 _variB(i, j) = B(i, j).vi_;
25 Bd(i, j) = B(i, j).val();
30 C_.noalias() = Bd.transpose()*
AinvB_;
34 inline void initializeB(
const Eigen::Matrix<double, R3, C3> &B,
43 template<
int R1,
int C1>
44 inline void initializeD(
const Eigen::Matrix<var, R1, C1> &D) {
45 D_.resize(D.rows(), D.cols());
46 _variD.resize(D.rows(), D.cols());
47 for (
int j = 0; j < D.cols(); j++) {
48 for (
int i = 0; i < D.rows(); i++) {
49 _variD(i, j) = D(i, j).vi_;
50 D_(i, j) = D(i, j).val();
54 template<
int R1,
int C1>
55 inline void initializeD(
const Eigen::Matrix<double, R1, C1> &D) {
60 template<
typename T1,
int R1,
int C1>
61 trace_inv_quad_form_ldlt_impl(
const Eigen::Matrix<T1, R1, C1> &D,
64 const Eigen::Matrix<T3, R3, C3> &B)
75 const Eigen::Matrix<T3, R3, C3> &B)
78 initializeB(B,
false);
83 Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>
D_;
84 Eigen::Matrix<vari*, Eigen::Dynamic, Eigen::Dynamic>
_variD;
86 Eigen::Matrix<double, R3, C3>
AinvB_;
87 Eigen::Matrix<double, C3, C3>
C_;
91 template <
typename T2,
int R2,
int C2,
typename T3,
int R3,
int C3>
92 class trace_inv_quad_form_ldlt_vari :
public vari {
96 chainA(
const double &adj,
97 trace_inv_quad_form_ldlt_impl<double, R2, C2, T3, R3, C3>
102 chainB(
const double &adj,
103 trace_inv_quad_form_ldlt_impl<T2, R2, C2, double, R3, C3>
109 chainA(
const double &adj,
110 trace_inv_quad_form_ldlt_impl<var, R2, C2, T3, R3, C3> *impl) {
111 Eigen::Matrix<double, R2, C2> aA;
113 if (impl->Dtype_ != 2)
114 aA.noalias() = -adj * (impl->AinvB_ * impl->D_.transpose()
115 * impl->AinvB_.transpose());
117 aA.noalias() = -adj*(impl->AinvB_ * impl->AinvB_.transpose());
119 for (
int j = 0; j < aA.cols(); j++)
120 for (
int i = 0; i < aA.rows(); i++)
121 impl->_ldlt._alloc->_variA(i, j)->adj_ += aA(i, j);
125 chainB(
const double &adj,
126 trace_inv_quad_form_ldlt_impl<T2, R2, C2, var, R3, C3> *impl) {
127 Eigen::Matrix<double, R3, C3> aB;
129 if (impl->Dtype_ != 2)
130 aB.noalias() = adj*impl->AinvB_*(impl->D_ + impl->D_.transpose());
132 aB.noalias() = 2*adj*impl->AinvB_;
134 for (
int j = 0; j < aB.cols(); j++)
135 for (
int i = 0; i < aB.rows(); i++)
136 impl->_variB(i, j)->adj_ += aB(i, j);
140 explicit trace_inv_quad_form_ldlt_vari
141 (trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *impl)
145 virtual void chain() {
154 if (
_impl->Dtype_ == 1) {
155 for (
int j = 0; j <
_impl->_variD.cols(); j++)
156 for (
int i = 0; i <
_impl->_variD.rows(); i++)
157 _impl->_variD(i, j)->adj_ += adj_*
_impl->C_(i, j);
161 trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *
_impl;
172 template <
typename T2,
int R2,
int C2,
typename T3,
int R3,
int C3>
174 boost::enable_if_c<stan::is_var<T2>::value ||
178 const Eigen::Matrix<T3, R3, C3> &B) {
183 trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3> *
_impl
184 =
new trace_inv_quad_form_ldlt_impl<T2, R2, C2, T3, R3, C3>(A, B);
186 return var(
new trace_inv_quad_form_ldlt_vari<T2, R2, C2, T3, R3, C3>
boost::enable_if_c<!stan::is_var< T1 >::value &&!stan::is_var< T2 >::value, typename boost::math::tools::promote_args< T1, T2 >::type >::type trace_inv_quad_form_ldlt(const stan::math::LDLT_factor< T1, R2, C2 > &A, const Eigen::Matrix< T2, R3, C3 > &B)
Independent (input) and dependent (output) variables for gradients.
bool check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Return true if the matrices can be multiplied.
T trace(const Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > &m)
Returns the trace of the specified matrix.