388 lines
11 KiB
Python
388 lines
11 KiB
Python
import warnings
|
|
from numpy import inner, zeros, inf, finfo
|
|
from numpy.linalg import norm
|
|
from math import sqrt
|
|
|
|
from .utils import make_system
|
|
from scipy._lib.deprecation import _NoValue, _deprecate_positional_args
|
|
|
|
# Public names exported by ``from <module> import *``.
__all__ = ["minres"]


@_deprecate_positional_args(version="1.14.0")
|
|
def minres(A, b, x0=None, *, shift=0.0, tol=_NoValue, maxiter=None,
|
|
M=None, callback=None, show=False, check=False, rtol=1e-5):
|
|
"""
|
|
Use MINimum RESidual iteration to solve Ax=b
|
|
|
|
MINRES minimizes norm(Ax - b) for a real symmetric matrix A. Unlike
|
|
the Conjugate Gradient method, A can be indefinite or singular.
|
|
|
|
If shift != 0 then the method solves (A - shift*I)x = b
|
|
|
|
Parameters
|
|
----------
|
|
A : {sparse matrix, ndarray, LinearOperator}
|
|
The real symmetric N-by-N matrix of the linear system
|
|
Alternatively, ``A`` can be a linear operator which can
|
|
produce ``Ax`` using, e.g.,
|
|
``scipy.sparse.linalg.LinearOperator``.
|
|
b : ndarray
|
|
Right hand side of the linear system. Has shape (N,) or (N,1).
|
|
|
|
Returns
|
|
-------
|
|
x : ndarray
|
|
The converged solution.
|
|
info : integer
|
|
Provides convergence information:
|
|
0 : successful exit
|
|
>0 : convergence to tolerance not achieved, number of iterations
|
|
<0 : illegal input or breakdown
|
|
|
|
Other Parameters
|
|
----------------
|
|
x0 : ndarray
|
|
Starting guess for the solution.
|
|
shift : float
|
|
Value to apply to the system ``(A - shift * I)x = b``. Default is 0.
|
|
rtol : float
|
|
Tolerance to achieve. The algorithm terminates when the relative
|
|
residual is below ``rtol``.
|
|
maxiter : integer
|
|
Maximum number of iterations. Iteration will stop after maxiter
|
|
steps even if the specified tolerance has not been achieved.
|
|
M : {sparse matrix, ndarray, LinearOperator}
|
|
Preconditioner for A. The preconditioner should approximate the
|
|
inverse of A. Effective preconditioning dramatically improves the
|
|
rate of convergence, which implies that fewer iterations are needed
|
|
to reach a given error tolerance.
|
|
callback : function
|
|
User-supplied function to call after each iteration. It is called
|
|
as callback(xk), where xk is the current solution vector.
|
|
show : bool
|
|
If ``True``, print out a summary and metrics related to the solution
|
|
during iterations. Default is ``False``.
|
|
check : bool
|
|
If ``True``, run additional input validation to check that `A` and
|
|
`M` (if specified) are symmetric. Default is ``False``.
|
|
tol : float, optional, deprecated
|
|
|
|
.. deprecated:: 1.12.0
|
|
`minres` keyword argument ``tol`` is deprecated in favor of ``rtol``
|
|
and will be removed in SciPy 1.14.0.
|
|
|
|
Examples
|
|
--------
|
|
>>> import numpy as np
|
|
>>> from scipy.sparse import csc_matrix
|
|
>>> from scipy.sparse.linalg import minres
|
|
>>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
|
|
>>> A = A + A.T
|
|
>>> b = np.array([2, 4, -1], dtype=float)
|
|
>>> x, exitCode = minres(A, b)
|
|
>>> print(exitCode) # 0 indicates successful convergence
|
|
0
|
|
>>> np.allclose(A.dot(x), b)
|
|
True
|
|
|
|
References
|
|
----------
|
|
Solution of sparse indefinite systems of linear equations,
|
|
C. C. Paige and M. A. Saunders (1975),
|
|
SIAM J. Numer. Anal. 12(4), pp. 617-629.
|
|
https://web.stanford.edu/group/SOL/software/minres/
|
|
|
|
This file is a translation of the following MATLAB implementation:
|
|
https://web.stanford.edu/group/SOL/software/minres/minres-matlab.zip
|
|
|
|
"""
|
|
A, M, x, b, postprocess = make_system(A, M, x0, b)
|
|
|
|
if tol is not _NoValue:
|
|
msg = ("'scipy.sparse.linalg.minres' keyword argument `tol` is "
|
|
"deprecated in favor of `rtol` and will be removed in SciPy "
|
|
"v1.14. Until then, if set, it will override `rtol`.")
|
|
warnings.warn(msg, category=DeprecationWarning, stacklevel=4)
|
|
rtol = float(tol) if tol is not None else rtol
|
|
|
|
matvec = A.matvec
|
|
psolve = M.matvec
|
|
|
|
first = 'Enter minres. '
|
|
last = 'Exit minres. '
|
|
|
|
n = A.shape[0]
|
|
|
|
if maxiter is None:
|
|
maxiter = 5 * n
|
|
|
|
msg = [' beta2 = 0. If M = I, b and x are eigenvectors ', # -1
|
|
' beta1 = 0. The exact solution is x0 ', # 0
|
|
' A solution to Ax = b was found, given rtol ', # 1
|
|
' A least-squares solution was found, given rtol ', # 2
|
|
' Reasonable accuracy achieved, given eps ', # 3
|
|
' x has converged to an eigenvector ', # 4
|
|
' acond has exceeded 0.1/eps ', # 5
|
|
' The iteration limit was reached ', # 6
|
|
' A does not define a symmetric matrix ', # 7
|
|
' M does not define a symmetric matrix ', # 8
|
|
' M does not define a pos-def preconditioner '] # 9
|
|
|
|
if show:
|
|
print(first + 'Solution of symmetric Ax = b')
|
|
print(first + f'n = {n:3g} shift = {shift:23.14e}')
|
|
print(first + f'itnlim = {maxiter:3g} rtol = {rtol:11.2e}')
|
|
print()
|
|
|
|
istop = 0
|
|
itn = 0
|
|
Anorm = 0
|
|
Acond = 0
|
|
rnorm = 0
|
|
ynorm = 0
|
|
|
|
xtype = x.dtype
|
|
|
|
eps = finfo(xtype).eps
|
|
|
|
# Set up y and v for the first Lanczos vector v1.
|
|
# y = beta1 P' v1, where P = C**(-1).
|
|
# v is really P' v1.
|
|
|
|
if x0 is None:
|
|
r1 = b.copy()
|
|
else:
|
|
r1 = b - A@x
|
|
y = psolve(r1)
|
|
|
|
beta1 = inner(r1, y)
|
|
|
|
if beta1 < 0:
|
|
raise ValueError('indefinite preconditioner')
|
|
elif beta1 == 0:
|
|
return (postprocess(x), 0)
|
|
|
|
bnorm = norm(b)
|
|
if bnorm == 0:
|
|
x = b
|
|
return (postprocess(x), 0)
|
|
|
|
beta1 = sqrt(beta1)
|
|
|
|
if check:
|
|
# are these too strict?
|
|
|
|
# see if A is symmetric
|
|
w = matvec(y)
|
|
r2 = matvec(w)
|
|
s = inner(w,w)
|
|
t = inner(y,r2)
|
|
z = abs(s - t)
|
|
epsa = (s + eps) * eps**(1.0/3.0)
|
|
if z > epsa:
|
|
raise ValueError('non-symmetric matrix')
|
|
|
|
# see if M is symmetric
|
|
r2 = psolve(y)
|
|
s = inner(y,y)
|
|
t = inner(r1,r2)
|
|
z = abs(s - t)
|
|
epsa = (s + eps) * eps**(1.0/3.0)
|
|
if z > epsa:
|
|
raise ValueError('non-symmetric preconditioner')
|
|
|
|
# Initialize other quantities
|
|
oldb = 0
|
|
beta = beta1
|
|
dbar = 0
|
|
epsln = 0
|
|
qrnorm = beta1
|
|
phibar = beta1
|
|
rhs1 = beta1
|
|
rhs2 = 0
|
|
tnorm2 = 0
|
|
gmax = 0
|
|
gmin = finfo(xtype).max
|
|
cs = -1
|
|
sn = 0
|
|
w = zeros(n, dtype=xtype)
|
|
w2 = zeros(n, dtype=xtype)
|
|
r2 = r1
|
|
|
|
if show:
|
|
print()
|
|
print()
|
|
print(' Itn x(1) Compatible LS norm(A) cond(A) gbar/|A|')
|
|
|
|
while itn < maxiter:
|
|
itn += 1
|
|
|
|
s = 1.0/beta
|
|
v = s*y
|
|
|
|
y = matvec(v)
|
|
y = y - shift * v
|
|
|
|
if itn >= 2:
|
|
y = y - (beta/oldb)*r1
|
|
|
|
alfa = inner(v,y)
|
|
y = y - (alfa/beta)*r2
|
|
r1 = r2
|
|
r2 = y
|
|
y = psolve(r2)
|
|
oldb = beta
|
|
beta = inner(r2,y)
|
|
if beta < 0:
|
|
raise ValueError('non-symmetric matrix')
|
|
beta = sqrt(beta)
|
|
tnorm2 += alfa**2 + oldb**2 + beta**2
|
|
|
|
if itn == 1:
|
|
if beta/beta1 <= 10*eps:
|
|
istop = -1 # Terminate later
|
|
|
|
# Apply previous rotation Qk-1 to get
|
|
# [deltak epslnk+1] = [cs sn][dbark 0 ]
|
|
# [gbar k dbar k+1] [sn -cs][alfak betak+1].
|
|
|
|
oldeps = epsln
|
|
delta = cs * dbar + sn * alfa # delta1 = 0 deltak
|
|
gbar = sn * dbar - cs * alfa # gbar 1 = alfa1 gbar k
|
|
epsln = sn * beta # epsln2 = 0 epslnk+1
|
|
dbar = - cs * beta # dbar 2 = beta2 dbar k+1
|
|
root = norm([gbar, dbar])
|
|
Arnorm = phibar * root
|
|
|
|
# Compute the next plane rotation Qk
|
|
|
|
gamma = norm([gbar, beta]) # gammak
|
|
gamma = max(gamma, eps)
|
|
cs = gbar / gamma # ck
|
|
sn = beta / gamma # sk
|
|
phi = cs * phibar # phik
|
|
phibar = sn * phibar # phibark+1
|
|
|
|
# Update x.
|
|
|
|
denom = 1.0/gamma
|
|
w1 = w2
|
|
w2 = w
|
|
w = (v - oldeps*w1 - delta*w2) * denom
|
|
x = x + phi*w
|
|
|
|
# Go round again.
|
|
|
|
gmax = max(gmax, gamma)
|
|
gmin = min(gmin, gamma)
|
|
z = rhs1 / gamma
|
|
rhs1 = rhs2 - delta*z
|
|
rhs2 = - epsln*z
|
|
|
|
# Estimate various norms and test for convergence.
|
|
|
|
Anorm = sqrt(tnorm2)
|
|
ynorm = norm(x)
|
|
epsa = Anorm * eps
|
|
epsx = Anorm * ynorm * eps
|
|
epsr = Anorm * ynorm * rtol
|
|
diag = gbar
|
|
|
|
if diag == 0:
|
|
diag = epsa
|
|
|
|
qrnorm = phibar
|
|
rnorm = qrnorm
|
|
if ynorm == 0 or Anorm == 0:
|
|
test1 = inf
|
|
else:
|
|
test1 = rnorm / (Anorm*ynorm) # ||r|| / (||A|| ||x||)
|
|
if Anorm == 0:
|
|
test2 = inf
|
|
else:
|
|
test2 = root / Anorm # ||Ar|| / (||A|| ||r||)
|
|
|
|
# Estimate cond(A).
|
|
# In this version we look at the diagonals of R in the
|
|
# factorization of the lower Hessenberg matrix, Q @ H = R,
|
|
# where H is the tridiagonal matrix from Lanczos with one
|
|
# extra row, beta(k+1) e_k^T.
|
|
|
|
Acond = gmax/gmin
|
|
|
|
# See if any of the stopping criteria are satisfied.
|
|
# In rare cases, istop is already -1 from above (Abar = const*I).
|
|
|
|
if istop == 0:
|
|
t1 = 1 + test1 # These tests work if rtol < eps
|
|
t2 = 1 + test2
|
|
if t2 <= 1:
|
|
istop = 2
|
|
if t1 <= 1:
|
|
istop = 1
|
|
|
|
if itn >= maxiter:
|
|
istop = 6
|
|
if Acond >= 0.1/eps:
|
|
istop = 4
|
|
if epsx >= beta1:
|
|
istop = 3
|
|
# if rnorm <= epsx : istop = 2
|
|
# if rnorm <= epsr : istop = 1
|
|
if test2 <= rtol:
|
|
istop = 2
|
|
if test1 <= rtol:
|
|
istop = 1
|
|
|
|
# See if it is time to print something.
|
|
|
|
prnt = False
|
|
if n <= 40:
|
|
prnt = True
|
|
if itn <= 10:
|
|
prnt = True
|
|
if itn >= maxiter-10:
|
|
prnt = True
|
|
if itn % 10 == 0:
|
|
prnt = True
|
|
if qrnorm <= 10*epsx:
|
|
prnt = True
|
|
if qrnorm <= 10*epsr:
|
|
prnt = True
|
|
if Acond <= 1e-2/eps:
|
|
prnt = True
|
|
if istop != 0:
|
|
prnt = True
|
|
|
|
if show and prnt:
|
|
str1 = f'{itn:6g} {x[0]:12.5e} {test1:10.3e}'
|
|
str2 = f' {test2:10.3e}'
|
|
str3 = f' {Anorm:8.1e} {Acond:8.1e} {gbar/Anorm:8.1e}'
|
|
|
|
print(str1 + str2 + str3)
|
|
|
|
if itn % 10 == 0:
|
|
print()
|
|
|
|
if callback is not None:
|
|
callback(x)
|
|
|
|
if istop != 0:
|
|
break # TODO check this
|
|
|
|
if show:
|
|
print()
|
|
print(last + f' istop = {istop:3g} itn ={itn:5g}')
|
|
print(last + f' Anorm = {Anorm:12.4e} Acond = {Acond:12.4e}')
|
|
print(last + f' rnorm = {rnorm:12.4e} ynorm = {ynorm:12.4e}')
|
|
print(last + f' Arnorm = {Arnorm:12.4e}')
|
|
print(last + msg[istop+1])
|
|
|
|
if istop == 6:
|
|
info = maxiter
|
|
else:
|
|
info = 0
|
|
|
|
return (postprocess(x),info)
|