# Source path: ai-content-maker/.venv/Lib/site-packages/scipy/sparse/linalg/_isolve/minres.py
# (file-viewer metadata: 388 lines, 11 KiB, Python)

import warnings
from numpy import inner, zeros, inf, finfo
from numpy.linalg import norm
from math import sqrt
from .utils import make_system
from scipy._lib.deprecation import _NoValue, _deprecate_positional_args
# Names exported by ``from <this module> import *``.
__all__ = ['minres']
@_deprecate_positional_args(version="1.14.0")
def minres(A, b, x0=None, *, shift=0.0, tol=_NoValue, maxiter=None,
M=None, callback=None, show=False, check=False, rtol=1e-5):
"""
Use MINimum RESidual iteration to solve Ax=b
MINRES minimizes norm(Ax - b) for a real symmetric matrix A. Unlike
the Conjugate Gradient method, A can be indefinite or singular.
If shift != 0 then the method solves (A - shift*I)x = b
Parameters
----------
A : {sparse matrix, ndarray, LinearOperator}
The real symmetric N-by-N matrix of the linear system
Alternatively, ``A`` can be a linear operator which can
produce ``Ax`` using, e.g.,
``scipy.sparse.linalg.LinearOperator``.
b : ndarray
Right hand side of the linear system. Has shape (N,) or (N,1).
Returns
-------
x : ndarray
The converged solution.
info : integer
Provides convergence information:
0 : successful exit
>0 : convergence to tolerance not achieved, number of iterations
<0 : illegal input or breakdown
Other Parameters
----------------
x0 : ndarray
Starting guess for the solution.
shift : float
Value to apply to the system ``(A - shift * I)x = b``. Default is 0.
rtol : float
Tolerance to achieve. The algorithm terminates when the relative
residual is below ``rtol``.
maxiter : integer
Maximum number of iterations. Iteration will stop after maxiter
steps even if the specified tolerance has not been achieved.
M : {sparse matrix, ndarray, LinearOperator}
Preconditioner for A. The preconditioner should approximate the
inverse of A. Effective preconditioning dramatically improves the
rate of convergence, which implies that fewer iterations are needed
to reach a given error tolerance.
callback : function
User-supplied function to call after each iteration. It is called
as callback(xk), where xk is the current solution vector.
show : bool
If ``True``, print out a summary and metrics related to the solution
during iterations. Default is ``False``.
check : bool
If ``True``, run additional input validation to check that `A` and
`M` (if specified) are symmetric. Default is ``False``.
tol : float, optional, deprecated
.. deprecated:: 1.12.0
`minres` keyword argument ``tol`` is deprecated in favor of ``rtol``
and will be removed in SciPy 1.14.0.
Examples
--------
>>> import numpy as np
>>> from scipy.sparse import csc_matrix
>>> from scipy.sparse.linalg import minres
>>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
>>> A = A + A.T
>>> b = np.array([2, 4, -1], dtype=float)
>>> x, exitCode = minres(A, b)
>>> print(exitCode) # 0 indicates successful convergence
0
>>> np.allclose(A.dot(x), b)
True
References
----------
Solution of sparse indefinite systems of linear equations,
C. C. Paige and M. A. Saunders (1975),
SIAM J. Numer. Anal. 12(4), pp. 617-629.
https://web.stanford.edu/group/SOL/software/minres/
This file is a translation of the following MATLAB implementation:
https://web.stanford.edu/group/SOL/software/minres/minres-matlab.zip
"""
A, M, x, b, postprocess = make_system(A, M, x0, b)
if tol is not _NoValue:
msg = ("'scipy.sparse.linalg.minres' keyword argument `tol` is "
"deprecated in favor of `rtol` and will be removed in SciPy "
"v1.14. Until then, if set, it will override `rtol`.")
warnings.warn(msg, category=DeprecationWarning, stacklevel=4)
rtol = float(tol) if tol is not None else rtol
matvec = A.matvec
psolve = M.matvec
first = 'Enter minres. '
last = 'Exit minres. '
n = A.shape[0]
if maxiter is None:
maxiter = 5 * n
msg = [' beta2 = 0. If M = I, b and x are eigenvectors ', # -1
' beta1 = 0. The exact solution is x0 ', # 0
' A solution to Ax = b was found, given rtol ', # 1
' A least-squares solution was found, given rtol ', # 2
' Reasonable accuracy achieved, given eps ', # 3
' x has converged to an eigenvector ', # 4
' acond has exceeded 0.1/eps ', # 5
' The iteration limit was reached ', # 6
' A does not define a symmetric matrix ', # 7
' M does not define a symmetric matrix ', # 8
' M does not define a pos-def preconditioner '] # 9
if show:
print(first + 'Solution of symmetric Ax = b')
print(first + f'n = {n:3g} shift = {shift:23.14e}')
print(first + f'itnlim = {maxiter:3g} rtol = {rtol:11.2e}')
print()
istop = 0
itn = 0
Anorm = 0
Acond = 0
rnorm = 0
ynorm = 0
xtype = x.dtype
eps = finfo(xtype).eps
# Set up y and v for the first Lanczos vector v1.
# y = beta1 P' v1, where P = C**(-1).
# v is really P' v1.
if x0 is None:
r1 = b.copy()
else:
r1 = b - A@x
y = psolve(r1)
beta1 = inner(r1, y)
if beta1 < 0:
raise ValueError('indefinite preconditioner')
elif beta1 == 0:
return (postprocess(x), 0)
bnorm = norm(b)
if bnorm == 0:
x = b
return (postprocess(x), 0)
beta1 = sqrt(beta1)
if check:
# are these too strict?
# see if A is symmetric
w = matvec(y)
r2 = matvec(w)
s = inner(w,w)
t = inner(y,r2)
z = abs(s - t)
epsa = (s + eps) * eps**(1.0/3.0)
if z > epsa:
raise ValueError('non-symmetric matrix')
# see if M is symmetric
r2 = psolve(y)
s = inner(y,y)
t = inner(r1,r2)
z = abs(s - t)
epsa = (s + eps) * eps**(1.0/3.0)
if z > epsa:
raise ValueError('non-symmetric preconditioner')
# Initialize other quantities
oldb = 0
beta = beta1
dbar = 0
epsln = 0
qrnorm = beta1
phibar = beta1
rhs1 = beta1
rhs2 = 0
tnorm2 = 0
gmax = 0
gmin = finfo(xtype).max
cs = -1
sn = 0
w = zeros(n, dtype=xtype)
w2 = zeros(n, dtype=xtype)
r2 = r1
if show:
print()
print()
print(' Itn x(1) Compatible LS norm(A) cond(A) gbar/|A|')
while itn < maxiter:
itn += 1
s = 1.0/beta
v = s*y
y = matvec(v)
y = y - shift * v
if itn >= 2:
y = y - (beta/oldb)*r1
alfa = inner(v,y)
y = y - (alfa/beta)*r2
r1 = r2
r2 = y
y = psolve(r2)
oldb = beta
beta = inner(r2,y)
if beta < 0:
raise ValueError('non-symmetric matrix')
beta = sqrt(beta)
tnorm2 += alfa**2 + oldb**2 + beta**2
if itn == 1:
if beta/beta1 <= 10*eps:
istop = -1 # Terminate later
# Apply previous rotation Qk-1 to get
# [deltak epslnk+1] = [cs sn][dbark 0 ]
# [gbar k dbar k+1] [sn -cs][alfak betak+1].
oldeps = epsln
delta = cs * dbar + sn * alfa # delta1 = 0 deltak
gbar = sn * dbar - cs * alfa # gbar 1 = alfa1 gbar k
epsln = sn * beta # epsln2 = 0 epslnk+1
dbar = - cs * beta # dbar 2 = beta2 dbar k+1
root = norm([gbar, dbar])
Arnorm = phibar * root
# Compute the next plane rotation Qk
gamma = norm([gbar, beta]) # gammak
gamma = max(gamma, eps)
cs = gbar / gamma # ck
sn = beta / gamma # sk
phi = cs * phibar # phik
phibar = sn * phibar # phibark+1
# Update x.
denom = 1.0/gamma
w1 = w2
w2 = w
w = (v - oldeps*w1 - delta*w2) * denom
x = x + phi*w
# Go round again.
gmax = max(gmax, gamma)
gmin = min(gmin, gamma)
z = rhs1 / gamma
rhs1 = rhs2 - delta*z
rhs2 = - epsln*z
# Estimate various norms and test for convergence.
Anorm = sqrt(tnorm2)
ynorm = norm(x)
epsa = Anorm * eps
epsx = Anorm * ynorm * eps
epsr = Anorm * ynorm * rtol
diag = gbar
if diag == 0:
diag = epsa
qrnorm = phibar
rnorm = qrnorm
if ynorm == 0 or Anorm == 0:
test1 = inf
else:
test1 = rnorm / (Anorm*ynorm) # ||r|| / (||A|| ||x||)
if Anorm == 0:
test2 = inf
else:
test2 = root / Anorm # ||Ar|| / (||A|| ||r||)
# Estimate cond(A).
# In this version we look at the diagonals of R in the
# factorization of the lower Hessenberg matrix, Q @ H = R,
# where H is the tridiagonal matrix from Lanczos with one
# extra row, beta(k+1) e_k^T.
Acond = gmax/gmin
# See if any of the stopping criteria are satisfied.
# In rare cases, istop is already -1 from above (Abar = const*I).
if istop == 0:
t1 = 1 + test1 # These tests work if rtol < eps
t2 = 1 + test2
if t2 <= 1:
istop = 2
if t1 <= 1:
istop = 1
if itn >= maxiter:
istop = 6
if Acond >= 0.1/eps:
istop = 4
if epsx >= beta1:
istop = 3
# if rnorm <= epsx : istop = 2
# if rnorm <= epsr : istop = 1
if test2 <= rtol:
istop = 2
if test1 <= rtol:
istop = 1
# See if it is time to print something.
prnt = False
if n <= 40:
prnt = True
if itn <= 10:
prnt = True
if itn >= maxiter-10:
prnt = True
if itn % 10 == 0:
prnt = True
if qrnorm <= 10*epsx:
prnt = True
if qrnorm <= 10*epsr:
prnt = True
if Acond <= 1e-2/eps:
prnt = True
if istop != 0:
prnt = True
if show and prnt:
str1 = f'{itn:6g} {x[0]:12.5e} {test1:10.3e}'
str2 = f' {test2:10.3e}'
str3 = f' {Anorm:8.1e} {Acond:8.1e} {gbar/Anorm:8.1e}'
print(str1 + str2 + str3)
if itn % 10 == 0:
print()
if callback is not None:
callback(x)
if istop != 0:
break # TODO check this
if show:
print()
print(last + f' istop = {istop:3g} itn ={itn:5g}')
print(last + f' Anorm = {Anorm:12.4e} Acond = {Acond:12.4e}')
print(last + f' rnorm = {rnorm:12.4e} ynorm = {ynorm:12.4e}')
print(last + f' Arnorm = {Arnorm:12.4e}')
print(last + msg[istop+1])
if istop == 6:
info = maxiter
else:
info = 0
return (postprocess(x),info)