1 #ifndef STAN_MATH_REV_SCAL_FUN_FMA_HPP
2 #define STAN_MATH_REV_SCAL_FUN_FMA_HPP
7 #include <boost/math/special_functions/fpclassify.hpp>
14 T
fma(T x, T y, T z) {
23 class fma_vvv_vari :
public op_vvv_vari {
25 fma_vvv_vari(vari* avi, vari* bvi, vari* cvi) :
26 op_vvv_vari(::
fma(avi->val_, bvi->val_, cvi->val_),
33 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
34 bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
35 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
37 avi_->adj_ += adj_ * bvi_->val_;
38 bvi_->adj_ += adj_ * avi_->val_;
44 class fma_vvd_vari :
public op_vvd_vari {
46 fma_vvd_vari(vari* avi, vari* bvi,
double c) :
47 op_vvd_vari(::
fma(avi->val_, bvi->val_, c),
54 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
55 bvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
57 avi_->adj_ += adj_ * bvi_->val_;
58 bvi_->adj_ += adj_ * avi_->val_;
63 class fma_vdv_vari :
public op_vdv_vari {
65 fma_vdv_vari(vari* avi,
double b, vari* cvi) :
66 op_vdv_vari(::
fma(avi->val_ , b, cvi->val_),
73 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
74 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
76 avi_->adj_ += adj_ * bd_;
82 class fma_vdd_vari :
public op_vdd_vari {
84 fma_vdd_vari(vari* avi,
double b,
double c) :
85 op_vdd_vari(::
fma(avi->val_ , b, c),
92 avi_->adj_ = std::numeric_limits<double>::quiet_NaN();
94 avi_->adj_ += adj_ * bd_;
98 class fma_ddv_vari :
public op_ddv_vari {
100 fma_ddv_vari(
double a,
double b, vari* cvi) :
101 op_ddv_vari(::
fma(a, b, cvi->val_),
108 cvi_->adj_ = std::numeric_limits<double>::quiet_NaN();
164 return var(
new fma_vvd_vari(a.
vi_, b.
vi_, c));
189 return var(
new fma_vdv_vari(a.
vi_, b, c.
vi_));
212 return var(
new fma_vdd_vari(a.
vi_, b, c));
235 return var(
new fma_vdd_vari(b.
vi_, a, c));
258 return var(
new fma_ddv_vari(a, b, c.
vi_));
283 return var(
new fma_vdv_vari(b.
vi_, a, c.
vi_));
Independent (input) and dependent (output) variables for gradients.
bool isnan(const stan::math::var &v)
Checks if the given number is NaN.
fvar< typename stan::return_type< T1, T2, T3 >::type > fma(const fvar< T1 > &x1, const fvar< T2 > &x2, const fvar< T3 > &x3)
The fused multiply-add operation (C99).
vari * vi_
Pointer to the implementation of this variable.
var fma(const double &a, const stan::math::var &b, const stan::math::var &c)
The fused multiply-add function for a value and two variables (C99).