import warnings
from numpy import inner, zeros, inf, finfo
from numpy.linalg import norm
from math import sqrt

from .utils import make_system
from scipy._lib.deprecation import _NoValue, _deprecate_positional_args

__all__ = ['minres']


@_deprecate_positional_args(version="1.14.0")
def minres(A, b, x0=None, *, shift=0.0, tol=_NoValue, maxiter=None,
           M=None, callback=None, show=False, check=False, rtol=1e-5):
    """
    Use MINimum RESidual iteration to solve Ax=b

    MINRES minimizes norm(Ax - b) for a real symmetric matrix A.  Unlike
    the Conjugate Gradient method, A can be indefinite or singular.

    If shift != 0 then the method solves (A - shift*I)x = b

    Parameters
    ----------
    A : {sparse matrix, ndarray, LinearOperator}
        The real symmetric N-by-N matrix of the linear system
        Alternatively, ``A`` can be a linear operator which can
        produce ``Ax`` using, e.g.,
        ``scipy.sparse.linalg.LinearOperator``.
    b : ndarray
        Right hand side of the linear system. Has shape (N,) or (N,1).

    Returns
    -------
    x : ndarray
        The converged solution.
    info : integer
        Provides convergence information:
            0  : successful exit
            >0 : convergence to tolerance not achieved, number of iterations
            <0 : illegal input or breakdown

        Note that this implementation only ever returns ``0`` or
        ``maxiter`` (the latter when the iteration limit is reached).

    Other Parameters
    ----------------
    x0 : ndarray
        Starting guess for the solution.
    shift : float
        Value to apply to the system ``(A - shift * I)x = b``. Default is 0.
    rtol : float
        Tolerance to achieve. The algorithm terminates when the relative
        residual is below ``rtol``.
    maxiter : integer
        Maximum number of iterations.  Iteration will stop after maxiter
        steps even if the specified tolerance has not been achieved.
        Defaults to ``5 * N`` when not given.
    M : {sparse matrix, ndarray, LinearOperator}
        Preconditioner for A.  The preconditioner should approximate the
        inverse of A.  Effective preconditioning dramatically improves the
        rate of convergence, which implies that fewer iterations are needed
        to reach a given error tolerance.
    callback : function
        User-supplied function to call after each iteration.  It is called
        as callback(xk), where xk is the current solution vector.
    show : bool
        If ``True``, print out a summary and metrics related to the solution
        during iterations. Default is ``False``.
    check : bool
        If ``True``, run additional input validation to check that `A` and
        `M` (if specified) are symmetric. Default is ``False``.
    tol : float, optional, deprecated

        .. deprecated:: 1.12.0
           `minres` keyword argument ``tol`` is deprecated in favor of
           ``rtol`` and will be removed in SciPy 1.14.0.

    Raises
    ------
    ValueError
        If the preconditioner is found to be indefinite
        (``inner(r, M r) < 0`` for the initial residual), if the same
        quantity becomes negative during the iteration, or, with
        ``check=True``, if `A` or `M` fails the symmetry test.

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import minres
    >>> A = csc_matrix([[3, 2, 0], [1, -1, 0], [0, 5, 1]], dtype=float)
    >>> A = A + A.T
    >>> b = np.array([2, 4, -1], dtype=float)
    >>> x, exitCode = minres(A, b)
    >>> print(exitCode)            # 0 indicates successful convergence
    0
    >>> np.allclose(A.dot(x), b)
    True

    References
    ----------
    Solution of sparse indefinite systems of linear equations,
        C. C. Paige and M. A. Saunders (1975),
        SIAM J. Numer. Anal. 12(4), pp. 617-629.
        https://web.stanford.edu/group/SOL/software/minres/

    This file is a translation of the following MATLAB implementation:
        https://web.stanford.edu/group/SOL/software/minres/minres-matlab.zip

    """
    A, M, x, b, postprocess = make_system(A, M, x0, b)

    if tol is not _NoValue:
        msg = ("'scipy.sparse.linalg.minres' keyword argument `tol` is "
               "deprecated in favor of `rtol` and will be removed in SciPy "
               "v1.14. Until then, if set, it will override `rtol`.")
        warnings.warn(msg, category=DeprecationWarning, stacklevel=4)
        # An explicit ``tol=None`` keeps the value of ``rtol``.
        rtol = float(tol) if tol is not None else rtol

    matvec = A.matvec
    psolve = M.matvec

    first = 'Enter minres. '
    last = 'Exit minres. '

    n = A.shape[0]

    if maxiter is None:
        maxiter = 5 * n

    # Exit messages, indexed by ``istop + 1`` (istop ranges from -1 to 9).
    msg = [' beta2 = 0. If M = I, b and x are eigenvectors ',   # -1
           ' beta1 = 0. The exact solution is x0 ',             # 0
           ' A solution to Ax = b was found, given rtol ',      # 1
           ' A least-squares solution was found, given rtol ',  # 2
           ' Reasonable accuracy achieved, given eps ',         # 3
           ' x has converged to an eigenvector ',               # 4
           ' acond has exceeded 0.1/eps ',                      # 5
           ' The iteration limit was reached ',                 # 6
           ' A does not define a symmetric matrix ',            # 7
           ' M does not define a symmetric matrix ',            # 8
           ' M does not define a pos-def preconditioner ']      # 9

    if show:
        print(first + 'Solution of symmetric Ax = b')
        print(first + f'n = {n:3g} shift = {shift:23.14e}')
        print(first + f'itnlim = {maxiter:3g} rtol = {rtol:11.2e}')
        print()

    istop = 0
    itn = 0
    Anorm = 0
    Acond = 0
    rnorm = 0
    ynorm = 0
    # Arnorm is otherwise assigned only inside the main loop; give it a
    # value here so the final ``show`` summary cannot fail with an unbound
    # local when the loop body never executes (maxiter <= 0).
    Arnorm = 0

    xtype = x.dtype

    eps = finfo(xtype).eps

    # Set up y and v for the first Lanczos vector v1.
    # y = beta1 P' v1, where P = C**(-1).
    # v is really P' v1.

    if x0 is None:
        r1 = b.copy()
    else:
        r1 = b - A@x
    y = psolve(r1)

    beta1 = inner(r1, y)

    if beta1 < 0:
        raise ValueError('indefinite preconditioner')
    elif beta1 == 0:
        # Zero initial residual: the starting guess is returned as is.
        return (postprocess(x), 0)

    bnorm = norm(b)
    if bnorm == 0:
        # Zero right-hand side: return b itself (the zero vector).
        x = b
        return (postprocess(x), 0)

    beta1 = sqrt(beta1)

    if check:
        # are these too strict?

        # see if A is symmetric
        w = matvec(y)
        r2 = matvec(w)
        s = inner(w, w)
        t = inner(y, r2)
        z = abs(s - t)
        epsa = (s + eps) * eps**(1.0/3.0)
        if z > epsa:
            raise ValueError('non-symmetric matrix')

        # see if M is symmetric
        r2 = psolve(y)
        s = inner(y, y)
        t = inner(r1, r2)
        z = abs(s - t)
        epsa = (s + eps) * eps**(1.0/3.0)
        if z > epsa:
            raise ValueError('non-symmetric preconditioner')

    # Initialize other quantities
    oldb = 0
    beta = beta1
    dbar = 0
    epsln = 0
    qrnorm = beta1
    phibar = beta1
    rhs1 = beta1
    rhs2 = 0
    tnorm2 = 0
    gmax = 0
    gmin = finfo(xtype).max
    cs = -1
    sn = 0
    w = zeros(n, dtype=xtype)
    w2 = zeros(n, dtype=xtype)
    r2 = r1

    if show:
        print()
        print()
        print(' Itn x(1) Compatible LS norm(A) cond(A) gbar/|A|')

    while itn < maxiter:
        itn += 1

        s = 1.0/beta
        v = s*y

        y = matvec(v)
        y = y - shift * v

        if itn >= 2:
            y = y - (beta/oldb)*r1

        alfa = inner(v, y)
        y = y - (alfa/beta)*r2
        r1 = r2
        r2 = y
        y = psolve(r2)
        oldb = beta
        beta = inner(r2, y)
        if beta < 0:
            raise ValueError('non-symmetric matrix')
        beta = sqrt(beta)
        tnorm2 += alfa**2 + oldb**2 + beta**2

        if itn == 1:
            if beta/beta1 <= 10*eps:
                istop = -1  # Terminate later

        # Apply previous rotation Qk-1 to get
        #   [deltak epslnk+1] = [cs  sn][dbark    0   ]
        #   [gbar k dbar k+1]   [sn -cs][alfak betak+1].

        oldeps = epsln
        delta = cs * dbar + sn * alfa   # delta1 = 0         deltak
        gbar = sn * dbar - cs * alfa    # gbar 1 = alfa1     gbar k
        epsln = sn * beta               # epsln2 = 0         epslnk+1
        dbar = - cs * beta              # dbar 2 = beta2     dbar k+1
        root = norm([gbar, dbar])
        Arnorm = phibar * root

        # Compute the next plane rotation Qk

        gamma = norm([gbar, beta])      # gammak
        gamma = max(gamma, eps)         # keeps the divisions below finite
        cs = gbar / gamma               # ck
        sn = beta / gamma               # sk
        phi = cs * phibar               # phik
        phibar = sn * phibar            # phibark+1

        # Update x.

        denom = 1.0/gamma
        w1 = w2
        w2 = w
        w = (v - oldeps*w1 - delta*w2) * denom
        x = x + phi*w

        # Go round again.

        gmax = max(gmax, gamma)
        gmin = min(gmin, gamma)
        z = rhs1 / gamma
        rhs1 = rhs2 - delta*z
        rhs2 = - epsln*z

        # Estimate various norms and test for convergence.

        Anorm = sqrt(tnorm2)
        ynorm = norm(x)
        epsx = Anorm * ynorm * eps
        epsr = Anorm * ynorm * rtol

        qrnorm = phibar
        rnorm = qrnorm
        if ynorm == 0 or Anorm == 0:
            test1 = inf
        else:
            test1 = rnorm / (Anorm*ynorm)    # ||r||  / (||A|| ||x||)
        if Anorm == 0:
            test2 = inf
        else:
            test2 = root / Anorm             # ||Ar|| / (||A|| ||r||)

        # Estimate cond(A).
        # In this version we look at the diagonals of R in the
        # factorization of the lower Hessenberg matrix, Q @ H = R,
        # where H is the tridiagonal matrix from Lanczos with one
        # extra row, beta(k+1) e_k^T.

        Acond = gmax/gmin

        # See if any of the stopping criteria are satisfied.
        # In rare cases, istop is already -1 from above (Abar = const*I).
        # The assignments below are ordered on purpose: later tests
        # override earlier ones.

        if istop == 0:
            t1 = 1 + test1      # These tests work if rtol < eps
            t2 = 1 + test2
            if t2 <= 1:
                istop = 2
            if t1 <= 1:
                istop = 1

            if itn >= maxiter:
                istop = 6
            if Acond >= 0.1/eps:
                istop = 4
            if epsx >= beta1:
                istop = 3
            # if rnorm <= epsx   : istop = 2
            # if rnorm <= epsr   : istop = 1
            if test2 <= rtol:
                istop = 2
            if test1 <= rtol:
                istop = 1

        # See if it is time to print something.

        prnt = False
        if n <= 40:
            prnt = True
        if itn <= 10:
            prnt = True
        if itn >= maxiter-10:
            prnt = True
        if itn % 10 == 0:
            prnt = True
        if qrnorm <= 10*epsx:
            prnt = True
        if qrnorm <= 10*epsr:
            prnt = True
        if Acond <= 1e-2/eps:
            prnt = True
        if istop != 0:
            prnt = True

        if show and prnt:
            str1 = f'{itn:6g} {x[0]:12.5e} {test1:10.3e}'
            str2 = f' {test2:10.3e}'
            str3 = f' {Anorm:8.1e} {Acond:8.1e} {gbar/Anorm:8.1e}'

            print(str1 + str2 + str3)

            if itn % 10 == 0:
                print()

        if callback is not None:
            callback(x)

        if istop != 0:
            break  # TODO check this

    if show:
        print()
        print(last + f' istop = {istop:3g} itn ={itn:5g}')
        print(last + f' Anorm = {Anorm:12.4e} Acond = {Acond:12.4e}')
        print(last + f' rnorm = {rnorm:12.4e} ynorm = {ynorm:12.4e}')
        print(last + f' Arnorm = {Arnorm:12.4e}')
        print(last + msg[istop+1])

    # Only hitting the iteration limit is reported as a non-zero info.
    if istop == 6:
        info = maxiter
    else:
        info = 0

    return (postprocess(x), info)