Coverage for /home/martinb/.local/share/virtualenvs/camcops/lib/python3.6/site-packages/scipy/linalg/_matfuncs_sqrtm.py : 12%

Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""
2Matrix square root for general matrices and for upper triangular matrices.
4This module exists to avoid cyclic imports.
6"""
7__all__ = ['sqrtm']
9import numpy as np
11from scipy._lib._util import _asarray_validated
14# Local imports
15from .misc import norm
16from .lapack import ztrsyl, dtrsyl
17from .decomp_schur import schur, rsf2csf
class SqrtmError(np.linalg.LinAlgError):
    """Signal that a matrix square root could not be computed."""
def _sqrtm_triu(T, blocksize=64):
    """
    Matrix square root of an upper triangular matrix.

    This is a helper function for `sqrtm` and `logm`.

    Parameters
    ----------
    T : (N, N) array_like upper triangular
        Matrix whose square root to evaluate
    blocksize : int, optional
        If the blocksize is not degenerate with respect to the
        size of the input array, then use a blocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `T`.  The result is complex when `T`
        is complex or has a negative diagonal entry.  A 0x0 input yields a
        0x0 result.

    Raises
    ------
    SqrtmError
        If, within a diagonal block, ``R[i, i] + R[j, j]`` is zero while the
        corresponding numerator is nonzero.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    """
    # Accept any array_like; the code below relies on ndarray attributes.
    T = np.asarray(T)
    T_diag = np.diag(T)

    # Stay in real arithmetic only when every diagonal entry has a real
    # square root.  np.min raises on an empty array, so an empty diagonal
    # is treated as trivially real.
    keep_it_real = np.isrealobj(T) and (
        T_diag.size == 0 or np.min(T_diag) >= 0)
    if not keep_it_real:
        T_diag = T_diag.astype(complex)
    R = np.diag(np.sqrt(T_diag))

    # Compute the number of blocks to use; use at least one block.
    _, n = T.shape
    nblocks = max(n // blocksize, 1)

    # Compute the smaller of the two sizes of blocks that
    # we will actually use, and compute the number of large blocks.
    bsmall, nlarge = divmod(n, nblocks)
    blarge = bsmall + 1
    nsmall = nblocks - nlarge
    if nsmall * bsmall + nlarge * blarge != n:
        raise Exception('internal inconsistency')

    # Define the index range covered by each block.
    start_stop_pairs = []
    start = 0
    for count, size in ((nsmall, bsmall), (nlarge, blarge)):
        for _ in range(count):
            start_stop_pairs.append((start, start + size))
            start += size

    # Within-block interactions: fill each diagonal block column by column,
    # walking upwards from the diagonal.
    for start, stop in start_stop_pairs:
        for j in range(start, stop):
            for i in range(j-1, start-1, -1):
                s = 0
                if j - i > 1:
                    s = R[i, i+1:j].dot(R[i+1:j, j])
                denom = R[i, i] + R[j, j]
                num = T[i, j] - s
                if denom != 0:
                    R[i, j] = num / denom
                elif num == 0:
                    # 0 / 0: the entry is left at zero.
                    R[i, j] = 0
                else:
                    raise SqrtmError('failed to find the matrix square root')

    # Between-block interactions.
    for j in range(nblocks):
        jstart, jstop = start_stop_pairs[j]
        for i in range(j-1, -1, -1):
            istart, istop = start_stop_pairs[i]
            S = T[istart:istop, jstart:jstop]
            if j - i > 1:
                S = S - R[istart:istop, istop:jstart].dot(
                    R[istop:jstart, jstart:jstop])

            # Invoke LAPACK.
            # For more details, see the solve_sylvester implementation
            # and the fortran dtrsyl and ztrsyl docs.
            Rii = R[istart:istop, istart:istop]
            Rjj = R[jstart:jstop, jstart:jstop]
            if keep_it_real:
                x, scale, info = dtrsyl(Rii, Rjj, S)
            else:
                x, scale, info = ztrsyl(Rii, Rjj, S)
            # NOTE(review): `info` is not inspected here.
            R[istart:istop, jstart:jstop] = x * scale

    # Return the matrix square root.
    return R
def sqrtm(A, disp=True, blocksize=64):
    """
    Matrix square root.

    Parameters
    ----------
    A : (N, N) array_like
        Matrix whose square root to evaluate
    disp : bool, optional
        If True, return only the square root and print a message when no
        square root could be found.  If False, print nothing and return
        an error estimate alongside the result. (Default: True)
    blocksize : integer, optional
        If the blocksize is not degenerate with respect to the
        size of the input array, then use a blocked algorithm. (Default: 64)

    Returns
    -------
    sqrtm : (N, N) ndarray
        Value of the sqrt function at `A`.  Filled with NaN when the
        square root could not be found.

    errest : float
        (if disp == False)

        Computed as ``norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')``,
        or ``inf`` if evaluating that expression raises ``ValueError``.

    Raises
    ------
    ValueError
        If `A` is not two-dimensional or `blocksize` is less than 1.

    References
    ----------
    .. [1] Edvin Deadman, Nicholas J. Higham, Rui Ralha (2013)
           "Blocked Schur Algorithms for Computing the Matrix Square Root",
           Lecture Notes in Computer Science, 7782. pp. 171-182.

    Examples
    --------
    >>> from scipy.linalg import sqrtm
    >>> a = np.array([[1.0, 3.0], [1.0, 4.0]])
    >>> r = sqrtm(a)
    >>> r
    array([[ 0.75592895,  1.13389342],
           [ 0.37796447,  1.88982237]])
    >>> r.dot(r)
    array([[ 1.,  3.],
           [ 1.,  4.]])

    """
    A = _asarray_validated(A, check_finite=True, as_inexact=True)
    if A.ndim != 2:
        raise ValueError("Non-matrix input to matrix function.")
    if blocksize < 1:
        raise ValueError("The blocksize should be at least 1.")

    # Reduce A to triangular form T with A = Z T Z^H.
    if np.isrealobj(A):
        T, Z = schur(A)
        # The real Schur form is not upper triangular here, so switch
        # to the complex Schur form.
        if not np.array_equal(T, np.triu(T)):
            T, Z = rsf2csf(T, Z)
    else:
        T, Z = schur(A, output='complex')

    try:
        root_of_T = _sqrtm_triu(T, blocksize=blocksize)
        Z_conj_t = np.conjugate(Z).T
        X = Z.dot(root_of_T).dot(Z_conj_t)
    except SqrtmError:
        failed = True
        X = np.empty_like(A)
        X.fill(np.nan)
    else:
        failed = False

    if not disp:
        try:
            errest = norm(X.dot(X) - A, 'fro')**2 / norm(A, 'fro')
        except ValueError:
            # NaNs in matrix
            errest = np.inf
        return X, errest

    if failed:
        print("Failed to find a square root.")
    return X