Stan Math Library  2.9.0
reverse mode automatic differentiation
OperandsAndPartials.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_PRIM_SCAL_META_OPERANDSANDPARTIALS_HPP
2 #define STAN_MATH_PRIM_SCAL_META_OPERANDSANDPARTIALS_HPP
3 
4 #include <stan/math/fwd/core.hpp>
21 #include <stan/math/rev/core.hpp>
24 
25 namespace stan {
26  namespace math {
27 
28  class partials_vari : public vari {
29  private:
30  const size_t N_;
31  vari** operands_;
32  double* partials_;
33  public:
34  partials_vari(double value,
35  size_t N,
36  vari** operands, double* partials)
37  : vari(value),
38  N_(N),
39  operands_(operands),
40  partials_(partials) { }
41  void chain() {
42  for (size_t n = 0; n < N_; ++n)
43  operands_[n]->adj_ += adj_ * partials_[n];
44  }
45  };
46 
47  namespace {
48  template<typename T1, typename T2, typename T3,
49  bool is_vec = is_vector<T2>::value,
50  bool is_const = is_constant_struct<T2>::value>
51  struct incr_deriv {
52  inline T3 incr(T1 d_x, const T2& x_d ) {
53  return 0;
54  }
55  };
56  template<typename T1, typename T2, typename T3>
57  struct incr_deriv<T1, T2, T3, false, false> {
58  inline T3 incr(T1 d_x, const T2& x_d) {
59  return d_x[0]*x_d.d_;
60  }
61  };
62  template<typename T1, typename T2, typename T3>
63  struct incr_deriv<T1, T2, T3, true, false> {
64  inline T3 incr(T1 d_x, const T2& x_d) {
65  T3 temp = 0;
66  for (size_t n = 0; n < length(x_d); n++)
67  temp += d_x[n] * x_d[n].d_;
68  return temp;
69  }
70  };
71 
// partials_to_var: turns the value computed by OperandsAndPartials into
// the function's return type. This primary template ignores every
// argument except logp and returns it converted to T_return_type. It is
// the fallback when neither specialization below applies (presumably the
// constant / plain-double return case; confirm against upstream).
// NOTE(review): doc-lines 75-76 (the trailing template parameters) are
// missing from this extraction. The specializations below pass two
// extra bool arguments (`false, false` and `true, false`), so the primary
// template must declare two defaulted bool parameters here. Restore them
// from upstream before compiling.
72  template<typename T_return_type, typename T_partials_return,
73  typename T1, typename T2, typename T3, typename T4,
74  typename T5, typename T6,
77  struct partials_to_var {
78  inline
79  T_return_type to_var(double logp, size_t /* nvaris */,
80  vari** /* all_varis */,
81  T_partials_return* /* all_partials */,
82  const T1& x1, const T2& x2, const T3& x3,
83  const T4& x4, const T5& x5, const T6& x6,
84  VectorView<T_partials_return,
85  is_vector<T1>::value,
86  is_constant_struct<T1>::value> d_x1,
87  VectorView<T_partials_return,
88  is_vector<T2>::value,
89  is_constant_struct<T2>::value> d_x2,
90  VectorView<T_partials_return,
91  is_vector<T3>::value,
92  is_constant_struct<T3>::value> d_x3,
93  VectorView<T_partials_return,
94  is_vector<T4>::value,
95  is_constant_struct<T4>::value> d_x4,
96  VectorView<T_partials_return,
97  is_vector<T5>::value,
98  is_constant_struct<T5>::value> d_x5,
99  VectorView<T_partials_return,
100  is_vector<T6>::value,
101  is_constant_struct<T6>::value> d_x6) {
// Nothing to differentiate: the value is returned as-is.
102  return logp;
103  }
104  };
105 
106  template<typename T_return_type, typename T_partials_return,
107  typename T1, typename T2, typename T3, typename T4,
108  typename T5, typename T6>
109  struct partials_to_var<T_return_type, T_partials_return,
110  T1, T2, T3, T4, T5, T6,
111  false, false> {
112  inline T_return_type to_var(T_partials_return logp, size_t nvaris,
113  vari** all_varis,
114  T_partials_return* all_partials,
115  const T1& x1, const T2& x2, const T3& x3,
116  const T4& x4, const T5& x5, const T6& x6,
117  VectorView<T_partials_return,
118  is_vector<T1>::value,
119  is_constant_struct<T1>::value> d_x1,
120  VectorView<T_partials_return,
121  is_vector<T2>::value,
122  is_constant_struct<T2>::value> d_x2,
123  VectorView<T_partials_return,
124  is_vector<T3>::value,
125  is_constant_struct<T3>::value> d_x3,
126  VectorView<T_partials_return,
127  is_vector<T4>::value,
128  is_constant_struct<T4>::value> d_x4,
129  VectorView<T_partials_return,
130  is_vector<T5>::value,
131  is_constant_struct<T5>::value> d_x5,
132  VectorView<T_partials_return,
133  is_vector<T6>::value,
134  is_constant_struct<T6>::value> d_x6) {
135  return var(new partials_vari(logp, nvaris, all_varis,
136  all_partials));
137  }
138  };
139 
140  template<typename T_return_type, typename T_partials_return,
141  typename T1, typename T2, typename T3, typename T4,
142  typename T5, typename T6>
143  struct partials_to_var<T_return_type, T_partials_return,
144  T1, T2, T3, T4, T5, T6,
145  true, false> {
146  inline T_return_type to_var(T_partials_return logp, size_t nvaris,
147  vari** all_varis,
148  T_partials_return* all_partials,
149  const T1& x1, const T2& x2, const T3& x3,
150  const T4& x4, const T5& x5, const T6& x6,
151  VectorView<T_partials_return,
152  is_vector<T1>::value,
153  is_constant_struct<T1>::value> d_x1,
154  VectorView<T_partials_return,
155  is_vector<T2>::value,
156  is_constant_struct<T2>::value> d_x2,
157  VectorView<T_partials_return,
158  is_vector<T3>::value,
159  is_constant_struct<T3>::value> d_x3,
160  VectorView<T_partials_return,
161  is_vector<T4>::value,
162  is_constant_struct<T4>::value> d_x4,
163  VectorView<T_partials_return,
164  is_vector<T5>::value,
165  is_constant_struct<T5>::value> d_x5,
166  VectorView<T_partials_return,
167  is_vector<T6>::value,
168  is_constant_struct<T6>::value> d_x6) {
169  T_partials_return temp_deriv = 0;
170  temp_deriv += incr_deriv<VectorView<T_partials_return,
171  is_vector<T1>::value,
172  is_constant_struct<T1>::value>,
173  T1, T_partials_return>().incr(d_x1, x1);
174  temp_deriv += incr_deriv<VectorView<T_partials_return,
175  is_vector<T2>::value,
176  is_constant_struct<T2>::value>,
177  T2, T_partials_return>().incr(d_x2, x2);
178  temp_deriv += incr_deriv<VectorView<T_partials_return,
179  is_vector<T3>::value,
180  is_constant_struct<T3>::value>,
181  T3, T_partials_return>().incr(d_x3, x3);
182  temp_deriv += incr_deriv<VectorView<T_partials_return,
183  is_vector<T4>::value,
184  is_constant_struct<T4>::value>,
185  T4, T_partials_return>().incr(d_x4, x4);
186  temp_deriv += incr_deriv<VectorView<T_partials_return,
187  is_vector<T5>::value,
188  is_constant_struct<T5>::value>,
189  T5, T_partials_return>().incr(d_x5, x5);
190  temp_deriv += incr_deriv<VectorView<T_partials_return,
191  is_vector<T6>::value,
192  is_constant_struct<T6>::value>,
193  T6, T_partials_return>().incr(d_x6, x6);
194  return stan::math::fvar<T_partials_return>(logp, temp_deriv);
195  }
196  };
197 
198  template<typename T,
199  bool is_vec = is_vector<T>::value,
200  bool is_const = is_constant_struct<T>::value,
201  bool contain_fvar = contains_fvar<T>::value>
202  struct set_varis {
203  inline size_t set(vari** /*varis*/, const T& /*x*/) {
204  return 0U;
205  }
206  };
207  template<typename T>
208  struct set_varis<T, true, false, false> {
209  inline size_t set(vari** varis, const T& x) {
210  for (size_t n = 0; n < length(x); n++)
211  varis[n] = x[n].vi_;
212  return length(x);
213  }
214  };
215  template<typename T>
216  struct set_varis<T, true, false, true> {
217  inline size_t set(vari** varis, const T& x) {
218  for (size_t n = 0; n < length(x); n++)
219  varis[n] = 0;
220  return length(x);
221  }
222  };
223  template<>
224  struct set_varis<var, false, false, false> {
225  inline size_t set(vari** varis, const var& x) {
226  varis[0] = x.vi_;
227  return (1);
228  }
229  };
230  }
231 
// OperandsAndPartials: bookkeeping for a function of up to six operands
// x1..x6 (each type defaults to double). For every element of every
// non-constant operand it reserves one vari* slot (all_varis) and one
// partial-derivative slot (all_partials). It also exposes per-operand
// views d_x1..d_x6 into all_partials, and converts a computed value into
// the function's return type via to_var().
// Per the generated cross-references, T_return_type is
// stan::return_type<T1, ..., T6>::type and T_partials_return is
// stan::partials_return_type<T1, ..., T6>::type.
// NOTE(review): this extraction dropped several lines of the class: the
// class head (doc-line 238), the typedef bodies (doc-lines 240-241 and
// 244), presumably the `vari** all_varis` member (doc-line 248), and the
// declarations of d_x1..d_x6 (somewhere in doc-lines 251-268). Restore
// them from upstream before compiling.
236  template<typename T1 = double, typename T2 = double, typename T3 = double,
237  typename T4 = double, typename T5 = double, typename T6 = double>
239  typedef
242 
243  typedef
245 
247  size_t nvaris;
249  T_partials_return* all_partials;
250 
269 
// Constructor. nvaris is the summed length() of the non-constant
// operands; a constant operand (is_constant_struct) contributes 0.
// all_varis and all_partials each get nvaris entries from
// vari::operator new. Nothing here frees them (presumably they live in
// the autodiff memory arena; confirm).
// Each d_xk is constructed from all_partials offset by the entries of
// the non-constant operands before it, so operand k's partials start
// right after those of operands 1..k-1.
// The body records each operand's vari pointers through set_varis,
// advancing `base` by the slot count each call reports, and then zeroes
// every partial.
270  OperandsAndPartials(const T1& x1 = 0, const T2& x2 = 0, const T3& x3 = 0,
271  const T4& x4 = 0, const T5& x5 = 0, const T6& x6 = 0)
272  : nvaris(!is_constant_struct<T1>::value * length(x1) +
273  !is_constant_struct<T2>::value * length(x2) +
274  !is_constant_struct<T3>::value * length(x3) +
275  !is_constant_struct<T4>::value * length(x4) +
276  !is_constant_struct<T5>::value * length(x5) +
277  !is_constant_struct<T6>::value * length(x6)),
278  // TODO(carpenter): replace with array allocation fun
279  all_varis(static_cast<vari**>
280  (vari::operator new
281  (sizeof(vari*) * nvaris))),
282  all_partials(static_cast<T_partials_return*>
283  (vari::operator new
284  (sizeof(T_partials_return) * nvaris))),
285  d_x1(all_partials),
286  d_x2(all_partials
287  + (!is_constant_struct<T1>::value) * length(x1)),
288  d_x3(all_partials
289  + (!is_constant_struct<T1>::value) * length(x1)
290  + (!is_constant_struct<T2>::value) * length(x2)),
291  d_x4(all_partials
292  + (!is_constant_struct<T1>::value) * length(x1)
293  + (!is_constant_struct<T2>::value) * length(x2)
294  + (!is_constant_struct<T3>::value) * length(x3)),
295  d_x5(all_partials
296  + (!is_constant_struct<T1>::value) * length(x1)
297  + (!is_constant_struct<T2>::value) * length(x2)
298  + (!is_constant_struct<T3>::value) * length(x3)
299  + (!is_constant_struct<T4>::value) * length(x4)),
300  d_x6(all_partials
301  + (!is_constant_struct<T1>::value) * length(x1)
302  + (!is_constant_struct<T2>::value) * length(x2)
303  + (!is_constant_struct<T3>::value) * length(x3)
304  + (!is_constant_struct<T4>::value) * length(x4)
305  + (!is_constant_struct<T5>::value) * length(x5)) {
306  size_t base = 0;
// NOTE(review): doc-lines 307, 309, 311, 313, 315 and 317 are missing
// from this extraction. Check upstream for what preceded each set_varis
// call below.
308  base += set_varis<T1>().set(&all_varis[base], x1);
310  base += set_varis<T2>().set(&all_varis[base], x2);
312  base += set_varis<T3>().set(&all_varis[base], x3);
314  base += set_varis<T4>().set(&all_varis[base], x4);
316  base += set_varis<T5>().set(&all_varis[base], x5);
318  set_varis<T6>().set(&all_varis[base], x6);
319  std::fill(all_partials, all_partials+nvaris, 0);
320  }
321 
// Converts the function value logp into T_return_type. It forwards the
// slot arrays, the operands and the partial views to partials_to_var,
// whose specializations build a var (partials_vari), an fvar (value plus
// accumulated tangent), or return logp unchanged.
// NOTE(review): doc-line 327 (the leading template arguments of
// partials_to_var) is missing from this extraction.
322  T_return_type
323  to_var(T_partials_return logp,
324  const T1& x1 = 0, const T2& x2 = 0, const T3& x3 = 0,
325  const T4& x4 = 0, const T5& x5 = 0, const T6& x6 = 0) {
326  return partials_to_var
328  T2, T3, T4, T5, T6>().to_var(logp, nvaris, all_varis,
329  all_partials,
330  x1, x2, x3, x4, x5, x6, d_x1, d_x2,
331  d_x3, d_x4, d_x5, d_x6);
332  }
333  };
334 
335 
336  }
337 }
338 
339 
340 #endif
VectorView< T_partials_return, is_vector< T6 >::value, is_constant_struct< T6 >::value > d_x6
stan::return_type< T1, T2, T3, T4, T5, T6 >::type T_return_type
Metaprogramming struct to detect whether a given type is constant in the mathematical sense (not the ...
Definition: is_constant.hpp:22
The variable implementation base class.
Definition: vari.hpp:30
size_t length(const std::vector< T > &x)
Definition: length.hpp:10
void chain()
Apply the chain rule to this variable based on the variables on which it depends. ...
stan::partials_return_type< T1, T2, T3, T4, T5, T6 >::type T_partials_return
T_return_type to_var(T_partials_return logp, const T1 &x1=0, const T2 &x2=0, const T3 &x3=0, const T4 &x4=0, const T5 &x5=0, const T6 &x6=0)
var to_var(const double &x)
Converts argument to an automatic differentiation variable.
Definition: to_var.hpp:21
boost::math::tools::promote_args< typename scalar_type< T1 >::type, typename scalar_type< T2 >::type, typename scalar_type< T3 >::type, typename scalar_type< T4 >::type, typename scalar_type< T5 >::type, typename scalar_type< T6 >::type >::type type
Definition: return_type.hpp:27
VectorView< T_partials_return, is_vector< T1 >::value, is_constant_struct< T1 >::value > d_x1
Metaprogram to determine if a type has a base scalar type that can be assigned to type double...
VectorView< T_partials_return, is_vector< T3 >::value, is_constant_struct< T3 >::value > d_x3
VectorView< T_partials_return, is_vector< T4 >::value, is_constant_struct< T4 >::value > d_x4
void fill(T &x, const S &y)
Fill the specified container with the specified value.
Definition: fill.hpp:22
A variable implementation that stores operands and derivatives with respect to the variable...
partials_vari(double value, size_t N, vari **operands, double *partials)
VectorView< T_partials_return, is_vector< T2 >::value, is_constant_struct< T2 >::value > d_x2
OperandsAndPartials(const T1 &x1=0, const T2 &x2=0, const T3 &x3=0, const T4 &x4=0, const T5 &x5=0, const T6 &x6=0)
double adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the ro...
Definition: vari.hpp:44
VectorView< T_partials_return, is_vector< T5 >::value, is_constant_struct< T5 >::value > d_x5
Metaprogram to calculate the base scalar return type resulting from promoting all the scalar types of...
VectorView is a template metaprogram that takes its argument and allows it to be used like a vector...
Definition: VectorView.hpp:41
boost::math::tools::promote_args< typename partials_type< typename scalar_type< T1 >::type >::type, typename partials_type< typename scalar_type< T2 >::type >::type, typename partials_type< typename scalar_type< T3 >::type >::type, typename partials_type< typename scalar_type< T4 >::type >::type, typename partials_type< typename scalar_type< T5 >::type >::type, typename partials_type< typename scalar_type< T6 >::type >::type >::type type

     [ Stan Home Page ] © 2011–2015, Stan Development Team.