1 #ifndef STAN_MATH_REV_ARR_FUNCTOR_COUPLED_ODE_SYSTEM_HPP
2 #define STAN_MATH_REV_ARR_FUNCTOR_COUPLED_ODE_SYSTEM_HPP
38 std::vector<std::vector<stan::math::var> >& y) {
39 for (
size_t n = 0; n < y.size(); n++)
40 for (
size_t m = 0; m < y0.size(); m++)
69 const std::vector<stan::math::var>&
theta_;
71 const std::vector<double>&
x_;
91 const std::vector<double>& y0,
92 const std::vector<stan::math::var>& theta,
93 const std::vector<double>& x,
94 const std::vector<int>& x_int,
99 theta_dbl_(theta.
size(), 0.0),
106 for (
size_t m = 0; m <
M_; m++)
128 std::vector<double>& dz_dt,
133 vector<double> y(z.begin(), z.begin() +
N_);
134 dz_dt = f_(t, y, theta_dbl_, x_, x_int_, msgs_);
136 "dz_dt", dz_dt.size(),
N_);
138 vector<double> coupled_sys(N_ * M_);
139 vector<double>
grad(N_ + M_);
145 z_vars.reserve(N_ + M_);
147 vector<var> y_vars(y.begin(), y.end());
148 z_vars.insert(z_vars.end(), y_vars.begin(), y_vars.end());
150 vector<var> theta_vars(theta_dbl_.begin(), theta_dbl_.end());
151 z_vars.insert(z_vars.end(), theta_vars.begin(), theta_vars.end());
153 vector<var> dy_dt_vars = f_(t, y_vars, theta_vars, x_, x_int_, msgs_);
155 for (
size_t i = 0; i <
N_; i++) {
157 dy_dt_vars[i].grad(z_vars, grad);
159 for (
size_t j = 0; j <
M_; j++) {
163 double temp_deriv = grad[N_ + j];
164 for (
size_t k = 0; k <
N_; k++)
165 temp_deriv += z[N_ + N_ * j + k] * grad[k];
167 coupled_sys[i + j *
N_] = temp_deriv;
170 }
catch (
const std::exception&
e) {
176 dz_dt.insert(dz_dt.end(), coupled_sys.begin(), coupled_sys.end());
202 std::vector<double> state(size_, 0.0);
203 for (
size_t n = 0; n <
N_; n++)
204 state[n] = y0_dbl_[n];
214 std::vector<std::vector<stan::math::var> >
217 std::vector<stan::math::var> temp_vars;
218 std::vector<double> temp_gradients;
219 std::vector<std::vector<stan::math::var> > y_return(y.size());
221 for (
size_t i = 0; i < y.size(); i++) {
225 for (
size_t j = 0; j <
N_; j++) {
226 temp_gradients.clear();
229 for (
size_t k = 0; k <
M_; k++)
230 temp_gradients.push_back(y[i][y0_dbl_.size()
231 + y0_dbl_.size() * k + j]);
237 y_return[i] = temp_vars;
270 template <
typename F>
273 const std::vector<stan::math::var>&
y0_;
276 const std::vector<double>&
x_;
297 const std::vector<stan::math::var>& y0,
298 const std::vector<double>& theta,
299 const std::vector<double>& x,
300 const std::vector<int>& x_int,
304 y0_dbl_(y0.
size(), 0.0),
311 size_(N_ + N_ * N_) {
312 for (
size_t n = 0; n <
N_; n++)
333 std::vector<double>& dz_dt,
338 std::vector<double> y(z.begin(), z.begin() +
N_);
339 for (
size_t n = 0; n <
N_; n++)
342 dz_dt = f_(t, y, theta_dbl_, x_, x_int_, msgs_);
344 "dz_dt", dz_dt.size(),
N_);
346 std::vector<double> coupled_sys(N_ * N_);
347 std::vector<double>
grad(N_);
355 vector<var> y_vars(y.begin(), y.end());
356 z_vars.insert(z_vars.end(), y_vars.begin(), y_vars.end());
358 vector<var> dy_dt_vars = f_(t, y_vars, theta_dbl_, x_, x_int_, msgs_);
360 for (
size_t i = 0; i <
N_; i++) {
362 dy_dt_vars[i].grad(z_vars, grad);
364 for (
size_t j = 0; j <
N_; j++) {
368 double temp_deriv = grad[j];
369 for (
size_t k = 0; k <
N_; k++)
370 temp_deriv += z[N_ + N_ * j + k] * grad[k];
372 coupled_sys[i + j *
N_] = temp_deriv;
375 }
catch (
const std::exception&
e) {
381 dz_dt.insert(dz_dt.end(), coupled_sys.begin(), coupled_sys.end());
408 return std::vector<double>(
size_, 0.0);
418 std::vector<std::vector<stan::math::var> >
424 vector<var> temp_vars;
425 vector<double> temp_gradients;
426 vector<vector<var> > y_return(y.size());
428 for (
size_t i = 0; i < y.size(); i++) {
432 for (
size_t j = 0; j <
N_; j++) {
433 temp_gradients.clear();
436 for (
size_t k = 0; k <
N_; k++)
437 temp_gradients.push_back(y[i][y0_.size() + y0_.size() * k + j]);
440 y0_, temp_gradients));
443 y_return[i] = temp_vars;
487 template <
typename F>
490 const std::vector<stan::math::var>&
y0_;
492 const std::vector<stan::math::var>&
theta_;
494 const std::vector<double>&
x_;
515 const std::vector<stan::math::var>& y0,
516 const std::vector<stan::math::var>& theta,
517 const std::vector<double>& x,
518 const std::vector<int>& x_int,
522 y0_dbl_(y0.
size(), 0.0),
524 theta_dbl_(theta.
size(), 0.0),
529 size_(N_ + N_ * (N_ + M_)),
531 for (
size_t n = 0; n <
N_; n++)
534 for (
size_t m = 0; m <
M_; m++)
555 std::vector<double>& dz_dt,
560 vector<double> y(z.begin(), z.begin() +
N_);
561 for (
size_t n = 0; n <
N_; n++)
564 dz_dt = f_(t, y, theta_dbl_, x_, x_int_, msgs_);
566 "dz_dt", dz_dt.size(),
N_);
568 vector<double> coupled_sys(N_ * (N_ + M_));
569 vector<double>
grad(N_ + M_);
575 z_vars.reserve(N_ + M_);
577 vector<var> y_vars(y.begin(), y.end());
578 z_vars.insert(z_vars.end(), y_vars.begin(), y_vars.end());
580 vector<var> theta_vars(theta_dbl_.begin(), theta_dbl_.end());
581 z_vars.insert(z_vars.end(), theta_vars.begin(), theta_vars.end());
583 vector<var> dy_dt_vars = f_(t, y_vars, theta_vars, x_, x_int_, msgs_);
585 for (
size_t i = 0; i <
N_; i++) {
587 dy_dt_vars[i].grad(z_vars, grad);
589 for (
size_t j = 0; j < N_ +
M_; j++) {
593 double temp_deriv = grad[j];
594 for (
size_t k = 0; k <
N_; k++)
595 temp_deriv += z[N_ + N_ * j + k] * grad[k];
597 coupled_sys[i + j *
N_] = temp_deriv;
600 }
catch (
const std::exception&
e) {
606 dz_dt.insert(dz_dt.end(), coupled_sys.begin(), coupled_sys.end());
630 return std::vector<double>(
size_, 0.0);
640 std::vector<std::vector<stan::math::var> >
646 vector<var> vars = y0_;
647 vars.insert(vars.end(), theta_.begin(), theta_.end());
649 vector<var> temp_vars;
650 vector<double> temp_gradients;
651 vector<vector<var> > y_return(y.size());
653 for (
size_t i = 0; i < y.size(); i++) {
657 for (
size_t j = 0; j <
N_; j++) {
658 temp_gradients.clear();
661 for (
size_t k = 0; k < N_ +
M_; k++)
662 temp_gradients.push_back(y[i][N_ + N_ * k + j]);
665 vars, temp_gradients));
667 y_return[i] = temp_vars;
var precomputed_gradients(const double value, const std::vector< var > &operands, const std::vector< double > &gradients)
This function returns a var for an expression that has the specified value, vector of operands...
std::vector< std::vector< stan::math::var > > decouple_states(const std::vector< std::vector< double > > &y)
Returns the base ODE system state corresponding to the specified coupled system state.
const std::vector< stan::math::var > & theta_
std::vector< double > y0_dbl_
std::vector< std::vector< stan::math::var > > decouple_states(const std::vector< std::vector< double > > &y)
Return the solutions to the basic ODE system, including appropriate autodiff partial derivatives...
coupled_ode_system(const F &f, const std::vector< double > &y0, const std::vector< stan::math::var > &theta, const std::vector< double > &x, const std::vector< int > &x_int, std::ostream *msgs)
Construct a coupled ODE system with the specified base ODE system, base initial state, parameters, data, and a message stream.
std::vector< double > theta_dbl_
T value_of(const fvar< T > &v)
Return the value of the specified variable.
static void set_zero_all_adjoints_nested()
Reset all adjoint values in the top nested portion of the stack to zero.
std::vector< double > initial_state()
Returns the initial state of the coupled system.
std::vector< std::vector< stan::math::var > > decouple_states(const std::vector< std::vector< double > > &y)
Return the basic ODE solutions given the specified coupled system solutions, including the partials v...
Independent (input) and dependent (output) variables for gradients.
static void grad(vari *vi)
Compute the gradient for all variables starting from the specified root variable implementation.
size_t size() const
Returns the size of the coupled system.
std::vector< double > theta_dbl_
const std::vector< double > & theta_dbl_
void operator()(const std::vector< double > &z, std::vector< double > &dz_dt, double t)
Calculates the derivative of the coupled ode system with respect to the state y at time t...
void operator()(const std::vector< double > &z, std::vector< double > &dz_dt, double t)
Assign the derivative vector with the system derivatives at the specified state and time...
std::vector< double > initial_state()
Returns the initial state of the coupled system.
coupled_ode_system(const F &f, const std::vector< stan::math::var > &y0, const std::vector< double > &theta, const std::vector< double > &x, const std::vector< int > &x_int, std::ostream *msgs)
Construct a coupled ODE system for an unknown initial state and known parameters given the specified ...
const std::vector< int > & x_int_
bool check_equal(const char *function, const char *name, const T_y &y, const T_eq &eq)
Return true if y is equal to eq.
const std::vector< double > & y0_dbl_
void add_initial_values(const std::vector< stan::math::var > &y0, std::vector< std::vector< stan::math::var > > &y)
Increment the state derived from the coupled system with the original initial state...
const std::vector< double > & x_
const std::vector< int > & x_int_
double e()
Return the base of the natural logarithm.
const std::vector< stan::math::var > & y0_
std::vector< double > initial_state()
Returns the initial state of the coupled system.
int size(const std::vector< T > &x)
Return the size of the specified standard vector.
const std::vector< stan::math::var > & theta_
coupled_ode_system(const F &f, const std::vector< stan::math::var > &y0, const std::vector< stan::math::var > &theta, const std::vector< double > &x, const std::vector< int > &x_int, std::ostream *msgs)
Construct a coupled ODE system with unknown initial value and unknown parameters, given the base ODE sy...
const std::vector< double > & x_
void operator()(const std::vector< double > &z, std::vector< double > &dz_dt, double t)
Populates the derivative vector with derivatives of the coupled ODE system state with respect to time...
const std::vector< double > & x_
Base template class for a coupled ordinary differential equation system, which adds sensitivities to ...
static void recover_memory_nested()
Recover only the memory used for the top nested call.
const std::vector< stan::math::var > & y0_
std::vector< double > y0_dbl_
size_t size() const
Returns the size of the coupled system.
size_t size() const
Returns the size of the coupled system.
const std::vector< int > & x_int_
static void start_nested()
Record the current position so that recover_memory_nested() can find it.