Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/scipy/sparse/linalg/isolve/lsmr.py : 4%

1"""
2Copyright (C) 2010 David Fong and Michael Saunders
4LSMR uses an iterative method.
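Each LSMR iteration extends a Golub-Kahan bidiagonalization of A (the
recurrences ``beta*u = A*v - alpha*u`` and ``alpha*v = A'*u - beta*v`` in
the main loop of `lsmr` below). A minimal NumPy sketch of that process,
assuming dense real arrays, no breakdown (every alpha and beta nonzero)
and no reorthogonalization; the helper ``golub_kahan`` is hypothetical
and not part of SciPy:

```python
import numpy as np


def golub_kahan(A, b, k):
    # Hypothetical helper: k steps of Golub-Kahan bidiagonalization.
    # Assumes no breakdown (all alpha, beta > 0); no reorthogonalization.
    m, n = A.shape
    U = np.zeros((m, k + 1))
    V = np.zeros((n, k))
    B = np.zeros((k + 1, k))  # lower bidiagonal: alphas on diag, betas below

    beta1 = np.linalg.norm(b)
    U[:, 0] = b / beta1              # beta_1 u_1 = b
    v = A.T @ U[:, 0]
    alpha = np.linalg.norm(v)
    V[:, 0] = v / alpha              # alpha_1 v_1 = A^T u_1

    for j in range(k):
        B[j, j] = alpha
        # beta_{j+2} u_{j+2} = A v_{j+1} - alpha_{j+1} u_{j+1}
        u = A @ V[:, j] - alpha * U[:, j]
        beta = np.linalg.norm(u)
        U[:, j + 1] = u / beta
        B[j + 1, j] = beta
        if j + 1 < k:
            # alpha_{j+2} v_{j+2} = A^T u_{j+2} - beta_{j+2} v_{j+1}
            v = A.T @ U[:, j + 1] - beta * V[:, j]
            alpha = np.linalg.norm(v)
            V[:, j + 1] = v / alpha
    return U, V, B, beta1


rng = np.random.default_rng(0)
A = rng.standard_normal((8, 5))
b = rng.standard_normal(8)
U, V, B, beta1 = golub_kahan(A, b, 3)

# By construction A V_k = U_{k+1} B_k and b = beta_1 u_1.
print(np.allclose(A @ V, U @ B))
print(np.allclose(beta1 * U[:, 0], b))
```

`lsmr` itself keeps only the latest ``u`` and ``v`` and updates them in
place, so it never stores U, V or B explicitly.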
07 Jun 2010: Documentation updated
03 Jun 2010: First release version in Python

David Chin-lung Fong clfong@stanford.edu
Institute for Computational and Mathematical Engineering
Stanford University

Michael Saunders saunders@stanford.edu
Systems Optimization Laboratory
Dept of MS&E, Stanford University.

"""

__all__ = ['lsmr']

# numpy.infty was removed in NumPy 2.0; numpy.inf exists in every version,
# so import it under the name used below.
from numpy import zeros, inf as infty, atleast_1d, result_type
from numpy.linalg import norm
from math import sqrt
from scipy.sparse.linalg.interface import aslinearoperator

from .lsqr import _sym_ortho


def lsmr(A, b, damp=0.0, atol=1e-6, btol=1e-6, conlim=1e8,
         maxiter=None, show=False, x0=None):
    """Iterative solver for least-squares problems.

    `lsmr` solves the system of linear equations ``Ax = b``. If the system
    is inconsistent, it solves the least-squares problem ``min ||b - Ax||_2``.
    A is a rectangular matrix of dimension m-by-n, where all cases are
    allowed: m = n, m > n, or m < n. b is a vector of length m.
    The matrix A may be dense or sparse (usually sparse).

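    For an inconsistent system the minimizer of ``||b - Ax||_2`` is
    characterized by the normal equations ``A^H (b - Ax) = 0``, which is
    why `lsmr` reports ``normar = norm(A^H r)`` and drives it towards zero.
    The following NumPy sketch checks this on a small dense system; it
    uses `numpy.linalg.lstsq` as a stand-in for `lsmr` and assumes real
    dense arrays.

```python
import numpy as np

# Dense stand-in for the sparse matrix used in the Examples section.
A = np.array([[1., 0.], [1., 1.], [0., 1.]])
b = np.array([1., 0.01, -1.])  # inconsistent right-hand side

# Least-squares solution computed with NumPy instead of lsmr.
x, *_ = np.linalg.lstsq(A, b, rcond=None)
r = b - A @ x

print(np.linalg.norm(r))        # nonzero: Ax = b has no exact solution
print(np.linalg.norm(A.T @ r))  # ~0: the normal equations hold
```

    For a compatible system both printed norms would be close to zero.
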
    Parameters
    ----------
    A : {matrix, sparse matrix, ndarray, LinearOperator}
        Matrix A in the linear system.
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` and ``A^H x`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : array_like, shape (m,)
        Vector b in the linear system.
    damp : float, optional
        Damping factor for regularized least-squares. `lsmr` solves
        the regularized least-squares problem::

            min ||(b) - (  A   )x||
                ||(0)   (damp*I) ||_2

        where damp is a scalar. If damp is None or 0, the system
        is solved without regularization.
    atol, btol : float, optional
        Stopping tolerances. `lsmr` continues iterations until a
        certain backward error estimate is smaller than some quantity
        depending on atol and btol. Let ``r = b - Ax`` be the
        residual vector for the current approximate solution ``x``.
        If ``Ax = b`` seems to be consistent, `lsmr` terminates
        when ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``.
        Otherwise, `lsmr` terminates when ``norm(A^H r) <=
        atol * norm(A) * norm(r)``. If both tolerances are 1.0e-6 (say),
        the final ``norm(r)`` should be accurate to about 6
        digits. (The final ``x`` will usually have fewer correct digits,
        depending on ``cond(A)`` and the size of `damp`.) If `atol`
        or `btol` is None, a default value of 1.0e-6 will be used.
        Ideally, they should be estimates of the relative error in the
        entries of `A` and `b`, respectively. For example, if the entries
        of `A` have 7 correct digits, set ``atol = 1e-7``. This prevents
        the algorithm from doing unnecessary work beyond the
        uncertainty of the input data.
    conlim : float, optional
        `lsmr` terminates if an estimate of ``cond(A)`` exceeds
        `conlim`. For compatible systems ``Ax = b``, conlim could be
        as large as 1.0e+12 (say). For least-squares problems,
        `conlim` should be less than 1.0e+8. If `conlim` is None, the
        default value is 1e+8. Maximum precision can be obtained by
        setting ``atol = btol = conlim = 0``, but the number of
        iterations may then be excessive.
    maxiter : int, optional
        `lsmr` terminates if the number of iterations reaches
        `maxiter`. The default is ``maxiter = min(m, n)``. For
        ill-conditioned systems, a larger value of `maxiter` may be
        needed.
    show : bool, optional
        Print iteration logs if ``show=True``.
    x0 : array_like, shape (n,), optional
        Initial guess of x; if None, zeros are used.

        .. versionadded:: 1.0.0

    Returns
    -------
    x : ndarray of float
        Least-squares solution returned.
    istop : int
        istop gives the reason for stopping::

            istop = 0 means x=0 is a solution. If x0 was given, then x=x0 is a
                      solution.
                  = 1 means x is an approximate solution to A*x = b,
                      according to atol and btol.
                  = 2 means x approximately solves the least-squares problem
                      according to atol.
                  = 3 means cond(A) seems to be greater than conlim.
                  = 4 is the same as 1 with atol = btol = eps (machine
                      precision).
                  = 5 is the same as 2 with atol = eps.
                  = 6 is the same as 3 with conlim = 1/eps.
                  = 7 means itn reached maxiter before the other stopping
                      conditions were satisfied.

    itn : int
        Number of iterations used.
    normr : float
        ``norm(b-Ax)``
    normar : float
        ``norm(A^H (b - Ax))``
    norma : float
        Estimate of ``norm(A)``.
    conda : float
        Estimate of the condition number of A.
    normx : float
        ``norm(x)``

    Notes
    -----

    .. versionadded:: 0.11.0

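    The regularized problem described under `damp` is an ordinary
    least-squares problem for the stacked matrix ``[A; damp*I]`` and
    right-hand side ``[b; 0]``; equivalently, its solution satisfies
    ``(A^H A + damp**2 I) x = A^H b``. A minimal NumPy check of that
    equivalence, assuming dense real arrays and an arbitrary ``damp``
    value (it does not call `lsmr`):

```python
import numpy as np

rng = np.random.default_rng(1)
m, n, damp = 6, 4, 0.5          # arbitrary sizes and damping for the sketch
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)

# Solve min ||[b; 0] - [A; damp*I] x||_2 as a plain least-squares problem.
A_aug = np.vstack([A, damp * np.eye(n)])
b_aug = np.concatenate([b, np.zeros(n)])
x_aug, *_ = np.linalg.lstsq(A_aug, b_aug, rcond=None)

# Same minimizer from the regularized normal equations.
x_normal = np.linalg.solve(A.T @ A + damp**2 * np.eye(n), A.T @ b)

print(np.allclose(x_aug, x_normal))
```

    `lsmr` never forms the stacked matrix; `damp` enters only through the
    extra rotation built by ``_sym_ortho(alphabar, damp)`` in the main loop.
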
    References
    ----------
    .. [1] D. C.-L. Fong and M. A. Saunders,
           "LSMR: An iterative algorithm for sparse least-squares problems",
           SIAM J. Sci. Comput., vol. 33, pp. 2950-2971, 2011.
           https://arxiv.org/abs/1006.0758
    .. [2] LSMR Software, https://web.stanford.edu/group/SOL/software/lsmr/

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import lsmr
    >>> A = csc_matrix([[1., 0.], [1., 1.], [0., 1.]], dtype=float)

    The first example has the trivial solution `[0, 0]`:

    >>> b = np.array([0., 0., 0.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    0
    >>> x
    array([0., 0.])

    The returned stopping code `istop=0` indicates that a vector of zeros
    was found as a solution. The returned solution `x` indeed contains
    `[0., 0.]`. The next example has a non-trivial solution:

    >>> b = np.array([1., 0., -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    1
    >>> x
    array([ 1., -1.])
    >>> itn
    1
    >>> normr
    4.440892098500627e-16

    As indicated by `istop=1`, `lsmr` found a solution obeying the tolerance
    limits. The given solution `[1., -1.]` obviously solves the equation. The
    remaining return values report the number of iterations (`itn=1`) and the
    residual norm `normr`, i.e. the remaining difference between the left-
    and right-hand sides of the solved equation.

    The final example demonstrates the behavior in the case where there is no
    solution for the equation:

    >>> b = np.array([1., 0.01, -1.], dtype=float)
    >>> x, istop, itn, normr = lsmr(A, b)[:4]
    >>> istop
    2
    >>> x
    array([ 1.00333333, -0.99666667])
    >>> A.dot(x)-b
    array([ 0.00333333, -0.00333333,  0.00333333])
    >>> normr
    0.005773502691896255

    `istop=2` indicates that the system is inconsistent, so `x` is instead an
    approximate solution to the corresponding least-squares problem. `normr`
    contains the minimal residual norm that was found.

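    The two stopping rules described under `atol, btol` can be evaluated
    directly for this final example. The sketch below is not SciPy's
    implementation: it takes the exact least-squares solution from
    `numpy.linalg.lstsq` and uses the Frobenius norm of the dense matrix
    where `lsmr` uses its running estimate of ``norm(A)``; the helper
    ``stopping_tests`` is hypothetical.

```python
import numpy as np


def stopping_tests(A, b, x, atol=1e-6, btol=1e-6):
    # Hypothetical helper mirroring the documented rules; norm(A) is taken
    # as the Frobenius norm here (lsmr uses its own estimate instead).
    r = b - A @ x
    normA = np.linalg.norm(A)  # Frobenius norm for a 2-D array
    consistent = (np.linalg.norm(r)
                  <= atol * normA * np.linalg.norm(x) + btol * np.linalg.norm(b))
    least_squares = np.linalg.norm(A.T @ r) <= atol * normA * np.linalg.norm(r)
    return bool(consistent), bool(least_squares)


A = np.array([[1., 0.], [1., 1.], [0., 1.]])
b = np.array([1., 0.01, -1.])
x, *_ = np.linalg.lstsq(A, b, rcond=None)
print(stopping_tests(A, b, x))
```

    The result `(False, True)` says that the least-squares test passes while
    the compatible-system test fails, consistent with the `istop=2` reported
    above.
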
193 """
195 A = aslinearoperator(A)
196 b = atleast_1d(b)
197 if b.ndim > 1:
198 b = b.squeeze()
    msg = ('The exact solution is x = 0, or x = x0, if x0 was given ',
           'Ax - b is small enough, given atol, btol ',
           'The least-squares solution is good enough, given atol ',
           'The estimate of cond(Abar) has exceeded conlim ',
           'Ax - b is small enough for this machine ',
           'The least-squares solution is good enough for this machine',
           'Cond(Abar) seems to be too large for this machine ',
           'The iteration limit has been reached ')

    hdg1 = ' itn x(1) norm r norm Ar'
    hdg2 = ' compatible LS norm A cond A'
    pfreq = 20  # print frequency (for repeating the heading)
    pcount = 0  # print counter

    m, n = A.shape

    # Number of singular values of A.
    minDim = min([m, n])

    if maxiter is None:
        maxiter = minDim

    # Fall back to the documented defaults when a parameter is passed as None.
    if damp is None:
        damp = 0.0
    if atol is None:
        atol = 1e-6
    if btol is None:
        btol = 1e-6
    if conlim is None:
        conlim = 1e8

    if x0 is None:
        dtype = result_type(A, b, float)
    else:
        dtype = result_type(A, b, x0, float)

    if show:
        print(' ')
        print('LSMR Least-squares solution of Ax = b\n')
        print(f'The matrix A has {m} rows and {n} columns')
        print('damp = %20.14e\n' % (damp))
        print('atol = %8.2e conlim = %8.2e\n' % (atol, conlim))
        print('btol = %8.2e maxiter = %8g\n' % (btol, maxiter))

    u = b
    normb = norm(b)
    if x0 is None:
        x = zeros(n, dtype)
        beta = normb.copy()
    else:
        # Copy (and cast) x0 so that the in-place updates of x in the main
        # loop do not modify the caller's array.
        x = atleast_1d(x0).astype(dtype)
        u = u - A.matvec(x)
        beta = norm(u)

    if beta > 0:
        u = (1 / beta) * u
        v = A.rmatvec(u)
        alpha = norm(v)
    else:
        v = zeros(n, dtype)
        alpha = 0

    if alpha > 0:
        v = (1 / alpha) * v

    # Initialize variables for 1st iteration.

    itn = 0
    zetabar = alpha * beta
    alphabar = alpha
    rho = 1
    rhobar = 1
    cbar = 1
    sbar = 0

    h = v.copy()
    hbar = zeros(n, dtype)

    # Initialize variables for estimation of ||r||.

    betadd = beta
    betad = 0
    rhodold = 1
    tautildeold = 0
    thetatilde = 0
    zeta = 0
    d = 0

    # Initialize variables for estimation of ||A|| and cond(A)

    normA2 = alpha * alpha
    maxrbar = 0
    minrbar = 1e+100
    normA = sqrt(normA2)
    condA = 1
    normx = 0

    # Items for use in stopping rules, normb set earlier
    istop = 0
    ctol = 0
    if conlim > 0:
        ctol = 1 / conlim
    normr = beta

    # Reverse the order here from the original MATLAB code because
    # there was an error on return when arnorm == 0.
    normar = alpha * beta
    if normar == 0:
        if show:
            print(msg[0])
        return x, istop, itn, normr, normar, normA, condA, normx

    if show:
        print(' ')
        print(hdg1, hdg2)
        test1 = 1
        test2 = alpha / beta
        str1 = '%6g %12.5e' % (itn, x[0])
        str2 = ' %10.3e %10.3e' % (normr, normar)
        str3 = ' %8.1e %8.1e' % (test1, test2)
        print(''.join([str1, str2, str3]))

    # Main iteration loop.
    while itn < maxiter:
        itn = itn + 1

        # Perform the next step of the bidiagonalization to obtain the
        # next beta, u, alpha, v. These satisfy the relations
        #      beta*u = A*v  - alpha*u,
        #     alpha*v = A'*u - beta*v.

        u *= -alpha
        u += A.matvec(v)
        beta = norm(u)

        if beta > 0:
            u *= (1 / beta)
            v *= -beta
            v += A.rmatvec(u)
            alpha = norm(v)
            if alpha > 0:
                v *= (1 / alpha)

        # At this point, beta = beta_{k+1}, alpha = alpha_{k+1}.

        # Construct rotation Qhat_{k,2k+1}.

        chat, shat, alphahat = _sym_ortho(alphabar, damp)

        # Use a plane rotation (Q_i) to turn B_i to R_i

        rhoold = rho
        c, s, rho = _sym_ortho(alphahat, beta)
        thetanew = s*alpha
        alphabar = c*alpha

        # Use a plane rotation (Qbar_i) to turn R_i^T to R_i^bar

        rhobarold = rhobar
        zetaold = zeta
        thetabar = sbar * rho
        rhotemp = cbar * rho
        cbar, sbar, rhobar = _sym_ortho(cbar * rho, thetanew)
        zeta = cbar * zetabar
        zetabar = - sbar * zetabar

        # Update h, h_hat, x.

        hbar *= - (thetabar * rho / (rhoold * rhobarold))
        hbar += h
        x += (zeta / (rho * rhobar)) * hbar
        h *= - (thetanew / rho)
        h += v

        # Estimate of ||r||.

        # Apply rotation Qhat_{k,2k+1}.
        betaacute = chat * betadd
        betacheck = -shat * betadd

        # Apply rotation Q_{k,k+1}.
        betahat = c * betaacute
        betadd = -s * betaacute

        # Apply rotation Qtilde_{k-1}.
        # betad = betad_{k-1} here.

        thetatildeold = thetatilde
        ctildeold, stildeold, rhotildeold = _sym_ortho(rhodold, thetabar)
        thetatilde = stildeold * rhobar
        rhodold = ctildeold * rhobar
        betad = - stildeold * betad + ctildeold * betahat

        # betad   = betad_k here.
        # rhodold = rhod_k  here.

        tautildeold = (zetaold - thetatildeold * tautildeold) / rhotildeold
        taud = (zeta - thetatilde * tautildeold) / rhodold
        d = d + betacheck * betacheck
        normr = sqrt(d + (betad - taud)**2 + betadd * betadd)

        # Estimate ||A||.
        normA2 = normA2 + beta * beta
        normA = sqrt(normA2)
        normA2 = normA2 + alpha * alpha

        # Estimate cond(A).
        maxrbar = max(maxrbar, rhobarold)
        if itn > 1:
            minrbar = min(minrbar, rhobarold)
        condA = max(maxrbar, rhotemp) / min(minrbar, rhotemp)

        # Test for convergence.

        # Compute norms for convergence testing.
        normar = abs(zetabar)
        normx = norm(x)

        # Now use these norms to estimate certain other quantities,
        # some of which will be small near a solution.

        test1 = normr / normb
        if (normA * normr) != 0:
            test2 = normar / (normA * normr)
        else:
            test2 = infty
        test3 = 1 / condA
        t1 = test1 / (1 + normA * normx / normb)
        rtol = btol + atol * normA * normx / normb

        # The following tests guard against extremely small values of
        # atol, btol or ctol. (The user may have set any or all of
        # the parameters atol, btol, conlim to 0.)
        # The effect is equivalent to the normal tests using
        # atol = eps, btol = eps, conlim = 1/eps.

        if itn >= maxiter:
            istop = 7
        if 1 + test3 <= 1:
            istop = 6
        if 1 + test2 <= 1:
            istop = 5
        if 1 + t1 <= 1:
            istop = 4

        # Allow for tolerances set by the user.

        if test3 <= ctol:
            istop = 3
        if test2 <= atol:
            istop = 2
        if test1 <= rtol:
            istop = 1

        # See if it is time to print something.

        if show:
            if (n <= 40) or (itn <= 10) or (itn >= maxiter - 10) or \
               (itn % 10 == 0) or (test3 <= 1.1 * ctol) or \
               (test2 <= 1.1 * atol) or (test1 <= 1.1 * rtol) or \
               (istop != 0):

                if pcount >= pfreq:
                    pcount = 0
                    print(' ')
                    print(hdg1, hdg2)
                pcount = pcount + 1
                str1 = '%6g %12.5e' % (itn, x[0])
                str2 = ' %10.3e %10.3e' % (normr, normar)
                str3 = ' %8.1e %8.1e' % (test1, test2)
                str4 = ' %8.1e %8.1e' % (normA, condA)
                print(''.join([str1, str2, str3, str4]))

        if istop > 0:
            break

    # Print the stopping condition.

    if show:
        print(' ')
        print('LSMR finished')
        print(msg[istop])
        print('istop =%8g normr =%8.1e' % (istop, normr))
        print(' normA =%8.1e normAr =%8.1e' % (normA, normar))
        print('itn =%8g condA =%8.1e' % (itn, condA))
        print(' normx =%8.1e' % (normx))
        print(str1, str2)
        print(str3, str4)

    return x, istop, itn, normr, normar, normA, condA, normx