Stan Math Library 2.6.3
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Friends Macros
log_softmax.hpp
Go to the documentation of this file.
1 #ifndef STAN_MATH_REV_MAT_FUN_LOG_SOFTMAX_HPP
2 #define STAN_MATH_REV_MAT_FUN_LOG_SOFTMAX_HPP
3 
4 #include <stan/math/rev/core.hpp>
9 #include <cmath>
10 #include <stdexcept>
11 #include <vector>
12 
13 namespace stan {
14  namespace math {
15 
16  namespace {
17 
18  class log_softmax_elt_vari : public vari {
19  private:
20  vari** alpha_;
21  const double* softmax_alpha_;
22  const int size_; // array sizes
23  const int idx_; // in in softmax output
24 
25  public:
26  log_softmax_elt_vari(double val,
27  vari** alpha,
28  const double* softmax_alpha,
29  int size,
30  int idx)
31  : vari(val),
32  alpha_(alpha),
33  softmax_alpha_(softmax_alpha),
34  size_(size),
35  idx_(idx) {
36  }
37  void chain() {
38  for (int m = 0; m < size_; ++m) {
39  if (m == idx_)
40  alpha_[m]->adj_ += adj_ * (1 - softmax_alpha_[m]);
41  else
42  alpha_[m]->adj_ -= adj_ * softmax_alpha_[m];
43  }
44  }
45  };
46 
47  }
48 
49 
60  inline Eigen::Matrix<var, Eigen::Dynamic, 1>
61  log_softmax(const Eigen::Matrix<var, Eigen::Dynamic, 1>& alpha) {
62  using Eigen::Matrix;
63  using Eigen::Dynamic;
64 
65  stan::math::check_nonzero_size("log_softmax", "alpha", alpha);
66 
67  if (alpha.size() == 0)
68  throw std::domain_error("arg vector to log_softmax() "
69  "must have size > 0");
70  if (alpha.size() == 0)
71  throw std::domain_error("arg vector to log_softmax() "
72  "must have size > 0");
73  if (alpha.size() == 0)
74  throw std::domain_error("arg vector to log_softmax() "
75  "must have size > 0");
76 
77  vari** alpha_vi_array
78  = reinterpret_cast<vari**>
79  (chainable::operator new(sizeof(vari*) * alpha.size()));
80  for (int i = 0; i < alpha.size(); ++i)
81  alpha_vi_array[i] = alpha(i).vi_;
82 
83 
84  Matrix<double, Dynamic, 1> alpha_d(alpha.size());
85  for (int i = 0; i < alpha_d.size(); ++i)
86  alpha_d(i) = alpha(i).val();
87 
88  // fold logic of math::softmax() and math::log_softmax()
89  // to save computations
90 
91  Matrix<double, Dynamic, 1> softmax_alpha_d(alpha_d.size());
92  Matrix<double, Dynamic, 1> log_softmax_alpha_d(alpha_d.size());
93 
94  double max_v = alpha_d.maxCoeff();
95 
96  double sum = 0.0;
97  for (int i = 0; i < alpha_d.size(); ++i) {
98  softmax_alpha_d(i) = std::exp(alpha_d(i) - max_v);
99  sum += softmax_alpha_d(i);
100  }
101 
102  for (int i = 0; i < alpha_d.size(); ++i)
103  softmax_alpha_d(i) /= sum;
104  double log_sum = std::log(sum);
105 
106  for (int i = 0; i < alpha_d.size(); ++i)
107  log_softmax_alpha_d(i) = (alpha_d(i) - max_v) - log_sum;
108 
109  // end fold
110 
111  double* softmax_alpha_d_array
112  = reinterpret_cast<double*>
113  (chainable::operator new(sizeof(double) * alpha_d.size()));
114 
115  for (int i = 0; i < alpha_d.size(); ++i)
116  softmax_alpha_d_array[i] = softmax_alpha_d(i);
117 
118  Matrix<var, Dynamic, 1> log_softmax_alpha(alpha.size());
119  for (int k = 0; k < log_softmax_alpha.size(); ++k)
120  log_softmax_alpha(k)
121  = var(new log_softmax_elt_vari(log_softmax_alpha_d[k],
122  alpha_vi_array,
123  softmax_alpha_d_array,
124  alpha.size(),
125  k));
126  return log_softmax_alpha;
127  }
128 
129 
130  }
131 }
132 
133 #endif
fvar< T > sum(const std::vector< fvar< T > > &m)
Return the sum of the entries of the specified standard vector.
Definition: sum.hpp:20
const int size_
Definition: log_softmax.hpp:22
const int idx_
Definition: log_softmax.hpp:23
const double * softmax_alpha_
Definition: log_softmax.hpp:21
fvar< T > log(const fvar< T > &x)
Definition: log.hpp:15
The variable implementation base class.
Definition: vari.hpp:28
Independent (input) and dependent (output) variables for gradients.
Definition: var.hpp:32
Eigen::Matrix< fvar< T >, Eigen::Dynamic, 1 > log_softmax(const Eigen::Matrix< fvar< T >, Eigen::Dynamic, 1 > &alpha)
Definition: log_softmax.hpp:16
bool check_nonzero_size(const char *function, const char *name, const T_y &y)
Return true if the specified matrix/vector is of non-zero size.
fvar< T > exp(const fvar< T > &x)
Definition: exp.hpp:10
void domain_error(const char *function, const char *name, const T &y, const char *msg1, const char *msg2)
Throw a domain error with a consistently formatted message.
int size(const std::vector< T > &x)
Definition: size.hpp:11
vari ** alpha_
Definition: log_softmax.hpp:20

     [ Stan Home Page ] © 2011–2015, Stan Development Team.