Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1import numpy as np 

2from scipy._lib._util import _asarray_validated 

3 

4__all__ = ["logsumexp", "softmax", "log_softmax"] 

5 

6 

def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
    """Compute the log of the sum of exponentials of input elements.

    Evaluates ``log(sum(exp(a)))`` (or ``log(sum(b * exp(a)))`` when `b`
    is given) in a numerically stable way, by shifting out the maximum
    before exponentiating.

    Parameters
    ----------
    a : array_like
        Input array.
    axis : None or int or tuple of ints, optional
        Axis or axes over which the sum is taken. By default `axis` is
        None, and all elements are summed.

        .. versionadded:: 0.11.0
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the original array.

        .. versionadded:: 0.15.0
    b : array-like, optional
        Scaling factor for ``exp(a)``; must be of the same shape as `a`
        or broadcastable to `a`. Values may be negative in order to
        implement subtraction.

        .. versionadded:: 0.12.0
    return_sign : bool, optional
        If this is set to True, the result will be a pair containing sign
        information; if False, results that are negative will be returned
        as NaN. Default is False (no sign information).

        .. versionadded:: 0.16.0

    Returns
    -------
    res : ndarray
        ``np.log(np.sum(np.exp(a)))`` computed stably; with `b`,
        ``np.log(np.sum(b*np.exp(a)))``.
    sgn : ndarray
        Only when ``return_sign`` is True: an array matching ``res``
        holding +1, 0, or -1 depending on the sign of the summed result.

    See Also
    --------
    numpy.logaddexp, numpy.logaddexp2

    Notes
    -----
    ``numpy.logaddexp.reduce`` is similar but may be less stable, and
    `numpy.logaddexp` itself only handles two arguments.

    Examples
    --------
    >>> from scipy.special import logsumexp
    >>> a = np.arange(10)
    >>> np.log(np.sum(np.exp(a)))
    9.4586297444267107
    >>> logsumexp(a)
    9.4586297444267107

    With weights

    >>> b = np.arange(10, 0, -1)
    >>> logsumexp(a, b=b)
    9.9170178533034665

    Returning a sign flag

    >>> logsumexp([1,2], b=[1,-1], return_sign=True)
    (1.5413248546129181, -1.0)

    `logsumexp` does not directly support masked arrays; convert the
    mask into zero weights instead:

    >>> a = np.ma.array([np.log(2), 2, np.log(3)],
    ...                 mask=[False, True, False])
    >>> b = (~a.mask).astype(int)
    >>> logsumexp(a.data, b=b), np.log(5)
    1.6094379124341005, 1.6094379124341005

    """
    a = _asarray_validated(a, check_finite=False)

    if b is not None:
        a, b = np.broadcast_arrays(a, b)
        zero_weight = (b == 0)
        if zero_weight.any():
            # Copy (and promote to at least float) so the -inf write
            # below cannot clobber the caller's array; a zero weight
            # makes exp(a) irrelevant, so force the term to exp(-inf)=0.
            a = a + 0.
            a[zero_weight] = -np.inf

    # Shift by the per-axis maximum before exponentiating.  Non-finite
    # maxima (all -inf, nan, or empty slices) become a zero shift so the
    # subtraction below is a no-op for them.
    shift = np.amax(a, axis=axis, keepdims=True)
    if shift.ndim > 0:
        shift[~np.isfinite(shift)] = 0
    elif not np.isfinite(shift):
        shift = 0

    scaled = np.exp(a - shift)
    if b is not None:
        scaled = np.asarray(b) * scaled

    # log(0) legitimately yields -inf here; silence only that warning.
    with np.errstate(divide='ignore'):
        total = np.sum(scaled, axis=axis, keepdims=keepdims)
        if return_sign:
            sgn = np.sign(total)
            # Multiplying (not dividing) maps zero -> zero as required.
            total = total * sgn
        out = np.log(total)

    if not keepdims:
        shift = np.squeeze(shift, axis=axis)
    out += shift

    return (out, sgn) if return_sign else out

128 

129 

def softmax(x, axis=None):
    r"""
    Softmax function

    Transforms each element of a collection into the exponential of that
    element divided by the sum of the exponentials of all elements. For a
    one-dimensional numpy array `x`::

        softmax(x) = np.exp(x)/sum(np.exp(x))

    Parameters
    ----------
    x : array_like
        Input array.
    axis : int or tuple of ints, optional
        Axis to compute values along. Default is None and softmax will be
        computed over the entire array `x`.

    Returns
    -------
    s : ndarray
        An array the same shape as `x`. The result will sum to 1 along the
        specified axis.

    Notes
    -----
    The formula for the softmax function :math:`\sigma(x)` for a vector
    :math:`x = \{x_0, x_1, ..., x_{n-1}\}` is

    .. math:: \sigma(x)_j = \frac{e^{x_j}}{\sum_k e^{x_k}}

    The `softmax` function is the gradient of `logsumexp`.

    .. versionadded:: 1.2.0

    Examples
    --------
    >>> from scipy.special import softmax
    >>> np.set_printoptions(precision=5)

    >>> x = np.array([[1, 0.5, 0.2, 3],
    ...               [1,  -1,   7, 3],
    ...               [2,  12,  13, 3]])

    Over the entire array:

    >>> m = softmax(x)
    >>> m.sum()
    1.0000000000000002

    Along the first axis (columns):

    >>> softmax(x, axis=0).sum(axis=0)
    array([ 1.,  1.,  1.,  1.])

    Along the second axis (rows):

    >>> softmax(x, axis=1).sum(axis=1)
    array([ 1.,  1.,  1.])

    """
    # Normalize in log space: exp(x - logsumexp(x)) equals
    # exp(x)/sum(exp(x)) but never overflows the exponentials.
    log_normalizer = logsumexp(x, axis=axis, keepdims=True)
    return np.exp(x - log_normalizer)

215 

216 

def log_softmax(x, axis=None):
    r"""
    Logarithm of softmax function::

        log_softmax(x) = log(softmax(x))

    Parameters
    ----------
    x : array_like
        Input array.
    axis : int or tuple of ints, optional
        Axis to compute values along. Default is None and softmax will be
        computed over the entire array `x`.

    Returns
    -------
    s : ndarray or scalar
        An array with the same shape as `x`. Exponential of the result
        will sum to 1 along the specified axis. If `x` is a scalar, a
        scalar is returned.

    Notes
    -----
    `log_softmax` is more accurate than ``np.log(softmax(x))`` with inputs
    that make `softmax` saturate (see examples below).

    .. versionadded:: 1.5.0

    Examples
    --------
    >>> from scipy.special import log_softmax
    >>> from scipy.special import softmax
    >>> np.set_printoptions(precision=5)

    >>> x = np.array([1000.0, 1.0])

    >>> log_softmax(x)
    array([   0., -999.])

    >>> with np.errstate(divide='ignore'):
    ...     y = np.log(softmax(x))
    ...
    >>> y
    array([  0., -inf])

    """
    x = _asarray_validated(x, check_finite=False)

    # Shift by the maximum along the reduction axis so exp() cannot
    # overflow; a non-finite maximum is treated as a zero shift.
    x_max = np.amax(x, axis=axis, keepdims=True)
    if x_max.ndim > 0:
        x_max[~np.isfinite(x_max)] = 0
    elif not np.isfinite(x_max):
        x_max = 0

    shifted = x - x_max

    # log(0) legitimately yields -inf here; silence only that warning.
    with np.errstate(divide='ignore'):
        log_normalizer = np.log(np.sum(np.exp(shifted),
                                       axis=axis, keepdims=True))

    return shifted - log_normalizer