Stan Math Library  2.6.3
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Friends Macros
trace_quad_form.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_TRACE_QUAD_FORM_HPP
2 #define STAN_MATH_REV_MAT_FUN_TRACE_QUAD_FORM_HPP
3 
4 #include <boost/utility/enable_if.hpp>
5 #include <boost/type_traits.hpp>
6 #include <stan/math/rev/core.hpp>
15 
16 namespace stan {
17  namespace math {
18  namespace {
19  template <typename TA, int RA, int CA, typename TB, int RB, int CB>
20  class trace_quad_form_vari_alloc : public chainable_alloc {
21  public:
22  trace_quad_form_vari_alloc(const Eigen::Matrix<TA, RA, CA>& A,
23  const Eigen::Matrix<TB, RB, CB>& B)
24  : A_(A), B_(B)
25  { }
26 
27  double compute() {
30  value_of(B_));
31  }
32 
33  Eigen::Matrix<TA, RA, CA> A_;
34  Eigen::Matrix<TB, RB, CB> B_;
35  };
36 
37  template <typename TA, int RA, int CA, typename TB, int RB, int CB>
38  class trace_quad_form_vari : public vari {
39  protected:
40  static inline void chainA(Eigen::Matrix<double, RA, CA>& A,
41  const Eigen::Matrix<double, RB, CB>& Bd,
42  const double& adjC) {}
43  static inline void chainB(Eigen::Matrix<double, RB, CB>& B,
44  const Eigen::Matrix<double, RA, CA>& Ad,
45  const Eigen::Matrix<double, RB, CB>& Bd,
46  const double& adjC) {}
47 
48  static inline void chainA(Eigen::Matrix<var, RA, CA>& A,
49  const Eigen::Matrix<double, RB, CB>& Bd,
50  const double& adjC) {
51  Eigen::Matrix<double, RA, CA> adjA(adjC*Bd*Bd.transpose());
52  for (int j = 0; j < A.cols(); j++)
53  for (int i = 0; i < A.rows(); i++)
54  A(i, j).vi_->adj_ += adjA(i, j);
55  }
56  static inline void chainB(Eigen::Matrix<var, RB, CB>& B,
57  const Eigen::Matrix<double, RA, CA>& Ad,
58  const Eigen::Matrix<double, RB, CB>& Bd,
59  const double& adjC) {
60  Eigen::Matrix<double, RA, CA> adjB(adjC*(Ad + Ad.transpose())*Bd);
61  for (int j = 0; j < B.cols(); j++)
62  for (int i = 0; i < B.rows(); i++)
63  B(i, j).vi_->adj_ += adjB(i, j);
64  }
65 
66  inline void chainAB(Eigen::Matrix<TA, RA, CA>& A,
67  Eigen::Matrix<TB, RB, CB>& B,
68  const Eigen::Matrix<double, RA, CA>& Ad,
69  const Eigen::Matrix<double, RB, CB>& Bd,
70  const double& adjC) {
71  chainA(A, Bd, adjC);
72  chainB(B, Ad, Bd, adjC);
73  }
74 
75 
76  public:
77  explicit
78  trace_quad_form_vari
79  (trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *impl)
80  : vari(impl->compute()), _impl(impl) { }
81 
82  virtual void chain() {
84  chainAB(_impl->A_, _impl->B_,
85  value_of(_impl->A_), value_of(_impl->B_),
86  adj_);
87  }
88 
89  trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *_impl;
90  };
91  }
92 
93  template <typename TA, int RA, int CA, typename TB, int RB, int CB>
94  inline typename
95  boost::enable_if_c< boost::is_same<TA, var>::value ||
96  boost::is_same<TB, var>::value,
97  var >::type
98  trace_quad_form(const Eigen::Matrix<TA, RA, CA>& A,
99  const Eigen::Matrix<TB, RB, CB>& B) {
100  stan::math::check_square("trace_quad_form", "A", A);
101  stan::math::check_multiplicable("trace_quad_form",
102  "A", A,
103  "B", B);
104 
105  trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB> *baseVari
106  = new trace_quad_form_vari_alloc<TA, RA, CA, TB, RB, CB>(A, B);
107 
108  return var(new trace_quad_form_vari<TA, RA, CA, TB, RB, CB>(baseVari));
109  }
110  }
111 }
112 
113 #endif
T value_of(const fvar< T > &v)
Return the value of the specified variable.
Definition: value_of.hpp:16
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:32
Eigen::Matrix< TA, RA, CA > A_
stan::math::fvar< T > trace_quad_form(const Eigen::Matrix< stan::math::fvar< T >, RA, CA > &A, const Eigen::Matrix< stan::math::fvar< T >, RB, CB > &B)
bool check_multiplicable(const char *function, const char *name1, const T1 &y1, const char *name2, const T2 &y2)
Return true if the matrices can be multiplied.
Eigen::Matrix< TB, RB, CB > B_
trace_quad_form_vari_alloc< TA, RA, CA, TB, RB, CB > * _impl
bool check_square(const char *function, const char *name, const Eigen::Matrix< T_y, Eigen::Dynamic, Eigen::Dynamic > &y)
Return true if the specified matrix is square.

     [ Stan Home Page ] © 2011–2015, Stan Development Team.