1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_GEN_QUAD_FORM_HPP
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
// Operand storage for the reverse-mode trace_gen_quad_form(D, A, B).
//
// Keeps D, A and B as by-value members (D_, A_, B_) so that
// trace_gen_quad_form_vari can read them again when its chain() runs.
// trace_gen_quad_form_vari also calls a compute() member of this class to
// initialize its value; compute() is not visible in this listing.
//
// Template parameters give the scalar type and compile-time row/column
// counts of each operand: D is (TD, RD, CD), A is (TA, RA, CA),
// B is (TB, RB, CB).
//
// NOTE(review): derives from chainable_alloc; presumably this places the
// object in (and ties its lifetime to) the autodiff memory arena -- confirm.
19 template <
typename TD,
int RD,
int CD,
20 typename TA,
int RA,
int CA,
21 typename TB,
int RB,
int CB>
22 class trace_gen_quad_form_vari_alloc :
public chainable_alloc {
// Takes the three operands. NOTE(review): the member-initializer list and
// body are not visible in this listing; presumably D_, A_ and B_ are
// copy-initialized from D, A and B -- confirm.
24 trace_gen_quad_form_vari_alloc(
const Eigen::Matrix<TD, RD, CD>& D,
25 const Eigen::Matrix<TA, RA, CA>& A,
26 const Eigen::Matrix<TB, RB, CB>& B)
// Stored operands; trace_gen_quad_form_vari::chain() reads them back
// through its _impl pointer.
37 Eigen::Matrix<TD, RD, CD>
D_;
38 Eigen::Matrix<TA, RA, CA>
A_;
39 Eigen::Matrix<TB, RB, CB>
B_;
// Reverse-mode vari node for trace_gen_quad_form(D, A, B).
//
// The constructor takes the operand storage (trace_gen_quad_form_vari_alloc)
// and initializes the vari's value from impl->compute(). chain() propagates
// this node's adjoint into whichever of D, A, B have var scalars.
42 template <
typename TD,
int RD,
int CD,
43 typename TA,
int RA,
int CA,
44 typename TB,
int RB,
int CB>
45 class trace_gen_quad_form_vari :
public vari {
// Adds adj times the partial derivatives of the result to the adjoints of
// the var operands.
//
//   adj              -- adjoint of the result; scales every contribution.
//   D, A, B          -- double-valued operands.
//   varD, varA, varB -- var-valued operands whose entries receive adjoints.
//                       chain() passes NULL for any operand whose scalar
//                       type is not var.
//
// Assuming the forward value is trace(D * B^T * A * B) (confirm against the
// prim implementation), the partials used below are
//   d/dB = A*B*D + A^T*B*D^T,   d/dA = B*D^T*B^T,   d/dD = B^T*A^T*B.
//
// NOTE(review): the lines that assign BD and that skip NULL operands are
// not visible in this listing. Presumably BD = B*D and each accumulation
// loop below runs only when its pointer is non-NULL -- confirm, since
// dereferencing a NULL varX would be undefined behaviour.
48 computeAdjoints(
const double& adj,
49 const Eigen::Matrix<double, RD, CD>& D,
50 const Eigen::Matrix<double, RA, CA>& A,
51 const Eigen::Matrix<double, RB, CB>& B,
52 Eigen::Matrix<var, RD, CD> *varD,
53 Eigen::Matrix<var, RA, CA> *varA,
54 Eigen::Matrix<var, RB, CB> *varB) {
// Intermediate products reused by several adjoints:
// AtB feeds the B and D adjoints, and BD feeds the B and A adjoints.
55 Eigen::Matrix<double, CA, CB> AtB;
56 Eigen::Matrix<double, RA, CB> BD;
// noalias(): the destination does not alias A or B, so Eigen may evaluate
// the product directly into AtB.
60 AtB.noalias() = A.transpose()*B;
// B adjoint: adj * (A*BD + AtB*D^T), added entry by entry.
63 Eigen::Matrix<double, RB, CB> adjB(adj*(A*BD + AtB*D.transpose()));
64 for (
int j = 0; j < B.cols(); j++)
65 for (
int i = 0; i < B.rows(); i++)
66 (*varB)(i, j).vi_->adj_ += adjB(i, j);
// A adjoint: adj * B*BD^T, added entry by entry.
69 Eigen::Matrix<double, RA, CA> adjA(adj*(B*BD.transpose()));
70 for (
int j = 0; j < A.cols(); j++)
71 for (
int i = 0; i < A.rows(); i++)
72 (*varA)(i, j).vi_->adj_ += adjA(i, j);
// D adjoint: adj * B^T*AtB, added entry by entry.
75 Eigen::Matrix<double, RD, CD> adjD(adj*(B.transpose()*AtB));
76 for (
int j = 0; j < D.cols(); j++)
77 for (
int i = 0; i < D.rows(); i++)
78 (*varD)(i, j).vi_->adj_ += adjD(i, j);
// Builds the node from the operand storage. The vari's value comes from
// impl->compute(), and impl is kept for use in chain().
85 trace_gen_quad_form_vari(trace_gen_quad_form_vari_alloc
86 <TD, RD, CD, TA, RA, CA, TB, RB, CB> *impl)
87 : vari(impl->compute()),
_impl(impl) { }
89 virtual void chain() {
// Only the tail of the adjoint call is visible here (presumably a call to
// computeAdjoints -- confirm). Each operand pointer is non-NULL only when
// that operand's scalar type is var (boost::is_same). In that case the
// reinterpret_cast is an identity conversion; it exists only so the
// expression still type-checks when the scalar type is not var.
95 reinterpret_cast<Eigen::Matrix<var, RD, CD> *>
96 (boost::is_same<TD, var>::value?(&
_impl->D_):NULL),
97 reinterpret_cast<Eigen::Matrix<var, RA, CA> *>
98 (boost::is_same<TA, var>::value?(&
_impl->A_):NULL),
99 reinterpret_cast<Eigen::Matrix<var, RB, CB> *>
100 (boost::is_same<TB, var>::value?(&
_impl->B_):NULL));
// Operand storage that chain() reads through _impl. The member
// declaration continues past this listing.
103 trace_gen_quad_form_vari_alloc<TD, RD, CD, TA, RA, CA, TB, RB, CB>
// Reverse-mode trace_gen_quad_form(D, A, B) for Eigen matrix operands.
//
// Participates in overload resolution only when at least one of TD, TA, TB
// is var (boost::enable_if_c on the return type). It heap-allocates a
// trace_gen_quad_form_vari_alloc (whose constructor takes D, A, B), wraps it
// in a trace_gen_quad_form_vari, and returns a var pointing at that node.
//
// NOTE(review): this file references check_square and check_multiplicable.
// Presumably the body lines not shown in this listing validate the operand
// shapes with them before any allocation -- confirm.
// NOTE(review): the raw `new`s are presumably routed to the autodiff arena
// via chainable_alloc / vari -- confirm before replacing them with smart
// pointers.
108 template <
typename TD,
int RD,
int CD,
109 typename TA,
int RA,
int CA,
110 typename TB,
int RB,
int CB>
112 boost::enable_if_c< boost::is_same<TD, var>::value ||
113 boost::is_same<TA, var>::value ||
114 boost::is_same<TB, var>::value,
117 const Eigen::Matrix<TA, RA, CA>& A,
118 const Eigen::Matrix<TB, RB, CB>& B) {
// Operand storage (baseVari) that the vari reads back during the reverse
// pass.
128 trace_gen_quad_form_vari_alloc<TD, RD, CD, TA, RA, CA, TB, RB, CB>
130 =
new trace_gen_quad_form_vari_alloc<TD, RD, CD, TA, RA, CA, TB, RB, CB>
// The returned var's value is set from baseVari->compute() inside the
// trace_gen_quad_form_vari constructor.
133 return var(
new trace_gen_quad_form_vari
134 <TD, RD, CD, TA, RA, CA, TB, RB, CB>(baseVari));
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Independent (input) and dependent (output) variables for gradients.
fvar< T > trace_gen_quad_form(const Eigen::Matrix< fvar< T >, RD, CD > &D, const Eigen::Matrix< fvar< T >, RA, CA > &A, const Eigen::Matrix< fvar< T >, RB, CB > &B)
bool check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Return true if the matrices can be multiplied.
bool check_square(const char *function, const char *name, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y)
Return true if the specified matrix is square.