Source code for mgplvm.fast_utils.linear_cg

#!/usr/bin/env python3
"""
This file is taken from GPyTorch:
https://github.com/cornellius-gp/gpytorch/blob/c74b8ed8d590fcb1d79808d9ee7cde168588ef99/gpytorch/utils/linear_cg.py

It is distributed under the following license:
MIT License

Copyright (c) 2017 Jake Gardner

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

"""

import warnings

import torch


def _default_preconditioner(x):
    return x.clone()
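
# Illustrative note (added annotation, not part of the original GPyTorch/mgplvm
# source): the `preconditioner` argument of `linear_cg` below may be any
# callable mapping a residual tensor to M^{-1} residual. For an explicit
# matrix `lhs` with a strictly positive diagonal, a simple Jacobi (diagonal)
# preconditioner could look like:
#
#     diag = lhs.diagonal(dim1=-2, dim2=-1).unsqueeze(-1)  # shape (..., n, 1)
#     jacobi_preconditioner = lambda residual: residual / diag
#
# The default above applies no preconditioning (M = I); it simply copies the
# residual.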


@torch.jit.script
def _jit_linear_cg_updates(result, alpha, residual_inner_prod, eps, beta,
                           residual, precond_residual, mul_storage, is_zero,
                           curr_conjugate_vec):
    # # Update result
    # # result_{k} = result_{k-1} + alpha_{k} p_vec_{k-1}
    result = torch.addcmul(result, alpha, curr_conjugate_vec, out=result)

    # beta_{k} = (precon_residual{k}^T r_vec_{k}) / (precon_residual{k-1}^T r_vec_{k-1})
    beta.resize_as_(residual_inner_prod).copy_(residual_inner_prod)
    torch.mul(residual, precond_residual, out=mul_storage)
    torch.sum(mul_storage, -2, keepdim=True, out=residual_inner_prod)

    # Do a safe division here
    torch.lt(beta, eps, out=is_zero)
    beta.masked_fill_(is_zero, 1)
    torch.div(residual_inner_prod, beta, out=beta)
    beta.masked_fill_(is_zero, 0)

    # Update curr_conjugate_vec
    # curr_conjugate_vec_{k} = precon_residual{k} + beta_{k} curr_conjugate_vec_{k-1}
    curr_conjugate_vec.mul_(beta).add_(precond_residual)


@torch.jit.script
def _jit_linear_cg_updates_no_precond(
    mvms,
    result,
    has_converged,
    alpha,
    residual_inner_prod,
    eps,
    beta,
    residual,
    precond_residual,
    mul_storage,
    is_zero,
    curr_conjugate_vec,
):
    torch.mul(curr_conjugate_vec, mvms, out=mul_storage)
    torch.sum(mul_storage, dim=-2, keepdim=True, out=alpha)

    # Do a safe division here
    torch.lt(alpha, eps, out=is_zero)
    alpha.masked_fill_(is_zero, 1)
    torch.div(residual_inner_prod, alpha, out=alpha)
    alpha.masked_fill_(is_zero, 0)

    # We'll cancel out any updates by setting alpha=0 for any vector that has already converged
    alpha.masked_fill_(has_converged, 0)

    # Update residual
    # residual_{k} = residual_{k-1} - alpha_{k} mat p_vec_{k-1}
    torch.addcmul(residual, -alpha, mvms, out=residual)

    # Update precond_residual
    # precon_residual{k} = M^-1 residual_{k}
    precond_residual = residual.clone()

    _jit_linear_cg_updates(
        result,
        alpha,
        residual_inner_prod,
        eps,
        beta,
        residual,
        precond_residual,
        mul_storage,
        is_zero,
        curr_conjugate_vec,
    )
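
# For reference (added annotation, not in the original source): together, the
# JIT helpers above perform one step of preconditioned CG. With A the system
# matrix, M the preconditioner, p the conjugate direction, r the residual and
# z = M^{-1} r:
#
#     alpha_k = (r_{k-1}^T z_{k-1}) / (p_{k-1}^T A p_{k-1})
#     r_k     = r_{k-1} - alpha_k A p_{k-1}
#     z_k     = M^{-1} r_k
#     x_k     = x_{k-1} + alpha_k p_{k-1}
#     beta_k  = (z_k^T r_k) / (z_{k-1}^T r_{k-1})
#     p_k     = z_k + beta_k p_{k-1}
#
# The "safe division" guards replace near-zero denominators with 1 before
# dividing and then zero out the corresponding entries, so converged columns
# stop being updated without producing NaNs.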


def linear_cg(
    matmul_closure,
    rhs,
    n_tridiag=0,
    tolerance=.01,
    eps=1e-10,
    stop_updating_after=1e-10,
    max_iter=1000,
    max_tridiag_iter=20,
    initial_guess=None,
    preconditioner=None,
    verbose=False,
):
    """
    Implements the linear conjugate gradients method for (approximately)
    solving systems of the form

        lhs result = rhs

    for positive definite and symmetric matrices.

    Args:
      - matmul_closure - a function which performs a left matrix multiplication with lhs_mat
      - rhs - the right-hand side of the equation
      - n_tridiag - returns a tridiagonalization of the first n_tridiag columns of rhs
      - tolerance - stop the solve when the max residual is less than this
      - eps - noise to add to prevent division by zero
      - stop_updating_after - will stop updating a vector after this residual norm is reached
      - max_iter - the maximum number of CG iterations
      - max_tridiag_iter - the maximum size of the tridiagonalization matrix
      - initial_guess - an initial guess at the solution `result`
      - preconditioner - a function which left-preconditions a supplied vector

    Returns:
      result - a solution to the system (if n_tridiag is 0)
      result, tridiags - a solution to the system, and corresponding tridiagonal matrices (if n_tridiag > 0)
    """
    # Unsqueeze, if necessary
    is_vector = rhs.ndimension() == 1
    if is_vector:
        rhs = rhs.unsqueeze(-1)

    # Some default arguments
    if initial_guess is None:
        initial_guess = torch.zeros_like(rhs)
    if preconditioner is None:
        preconditioner = _default_preconditioner
        precond = False
    else:
        precond = True
        # MGPLVM-GPYTORCH doesn't have other preconditioners. See gpytorch for example usage.

    # If we are running m CG iterations, we obviously can't get more than m Lanczos coefficients
    if max_tridiag_iter > max_iter:
        raise RuntimeError(
            "Getting a tridiagonalization larger than the number of CG iterations run is not possible!"
        )

    # Check matmul_closure object
    if torch.is_tensor(matmul_closure):
        matmul_closure = matmul_closure.matmul
    elif not callable(matmul_closure):
        raise RuntimeError(
            "matmul_closure must be a tensor, or a callable object!")

    # Get some constants
    num_rows = rhs.size(-2)
    n_iter = max_iter
    n_tridiag_iter = min(max_tridiag_iter, num_rows)
    eps = torch.tensor(eps, dtype=rhs.dtype, device=rhs.device)

    # Get the norm of the rhs - used for convergence checks
    # Here we're going to make almost-zero norms actually be 1 (so we don't get divide-by-zero issues)
    # But we'll store which norms were actually close to zero
    rhs_norm = rhs.norm(2, dim=-2, keepdim=True)
    rhs_is_zero = rhs_norm.lt(eps)
    rhs_norm = rhs_norm.masked_fill_(rhs_is_zero, 1)

    # Let's normalize. We'll un-normalize afterwards
    rhs = rhs.div(rhs_norm)

    # residual: residual_{0} = b_vec - lhs x_{0}
    residual = rhs - matmul_closure(initial_guess)
    batch_shape = residual.shape[:-2]

    # result <- x_{0}
    result = initial_guess.expand_as(residual).contiguous()

    if verbose:
        print(
            f"Running CG on a {rhs.shape} RHS for {n_iter} iterations (tol={tolerance}). Output: {result.shape}."
        )

    # Check for NaNs
    if not torch.equal(residual, residual):
        raise RuntimeError(
            "NaNs encountered when trying to perform matrix-vector multiplication"
        )

    # Sometimes we're lucky and the preconditioner solves the system right away
    # Check for convergence
    residual_norm = residual.norm(2, dim=-2, keepdim=True)
    has_converged = torch.lt(residual_norm, stop_updating_after)
    if has_converged.all() and not n_tridiag:
        n_iter = 0  # Skip the iteration!

    # Otherwise, let's define precond_residual and curr_conjugate_vec
    else:
        # precon_residual{0} = M^-1 residual_{0}
        precond_residual = preconditioner(residual)
        curr_conjugate_vec = precond_residual
        residual_inner_prod = precond_residual.mul(residual).sum(-2,
                                                                 keepdim=True)

        # Define storage matrices
        mul_storage = torch.empty_like(residual)
        alpha = torch.empty(*batch_shape,
                            1,
                            rhs.size(-1),
                            dtype=residual.dtype,
                            device=residual.device)
        beta = torch.empty_like(alpha)
        is_zero = torch.empty(*batch_shape,
                              1,
                              rhs.size(-1),
                              dtype=torch.bool,
                              device=residual.device)

    # Define tridiagonal matrices, if applicable
    if n_tridiag:
        t_mat = torch.zeros(n_tridiag_iter,
                            n_tridiag_iter,
                            *batch_shape,
                            n_tridiag,
                            dtype=alpha.dtype,
                            device=alpha.device)
        alpha_tridiag_is_zero = torch.empty(*batch_shape,
                                            n_tridiag,
                                            dtype=torch.bool,
                                            device=t_mat.device)
        alpha_reciprocal = torch.empty(*batch_shape,
                                       n_tridiag,
                                       dtype=t_mat.dtype,
                                       device=t_mat.device)
        prev_alpha_reciprocal = torch.empty_like(alpha_reciprocal)
        prev_beta = torch.empty_like(alpha_reciprocal)

    update_tridiag = True
    last_tridiag_iter = 0

    # It's conceivable we reach the tolerance on the last iteration, so can't just check iteration number.
    tolerance_reached = False

    # Start the iteration
    for k in range(n_iter):
        # Get next alpha
        # alpha_{k} = (residual_{k-1}^T precon_residual{k-1}) / (p_vec_{k-1}^T mat p_vec_{k-1})
        mvms = matmul_closure(curr_conjugate_vec)
        if precond:
            torch.mul(curr_conjugate_vec, mvms, out=mul_storage)
            torch.sum(mul_storage, -2, keepdim=True, out=alpha)

            # Do a safe division here
            torch.lt(alpha, eps, out=is_zero)
            alpha.masked_fill_(is_zero, 1)
            torch.div(residual_inner_prod, alpha, out=alpha)
            alpha.masked_fill_(is_zero, 0)

            # We'll cancel out any updates by setting alpha=0 for any vector that has already converged
            alpha.masked_fill_(has_converged, 0)

            # Update residual
            # residual_{k} = residual_{k-1} - alpha_{k} mat p_vec_{k-1}
            residual = torch.addcmul(residual,
                                     alpha,
                                     mvms,
                                     value=-1,
                                     out=residual)

            # Update precond_residual
            # precon_residual{k} = M^-1 residual_{k}
            precond_residual = preconditioner(residual)

            _jit_linear_cg_updates(
                result,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                mul_storage,
                is_zero,
                curr_conjugate_vec,
            )
        else:
            _jit_linear_cg_updates_no_precond(
                mvms,
                result,
                has_converged,
                alpha,
                residual_inner_prod,
                eps,
                beta,
                residual,
                precond_residual,
                mul_storage,
                is_zero,
                curr_conjugate_vec,
            )

        torch.norm(residual, 2, dim=-2, keepdim=True, out=residual_norm)
        residual_norm.masked_fill_(rhs_is_zero, 0)
        torch.lt(residual_norm, stop_updating_after, out=has_converged)

        if k >= 10 and bool(residual_norm.mean() < tolerance) and not (
                n_tridiag and k < n_tridiag_iter):
            tolerance_reached = True
            break

        # Update tridiagonal matrices, if applicable
        if n_tridiag and k < n_tridiag_iter and update_tridiag:
            alpha_tridiag = alpha.squeeze_(-2).narrow(-1, 0, n_tridiag)
            beta_tridiag = beta.squeeze_(-2).narrow(-1, 0, n_tridiag)
            torch.eq(alpha_tridiag, 0, out=alpha_tridiag_is_zero)
            alpha_tridiag.masked_fill_(alpha_tridiag_is_zero, 1)
            torch.reciprocal(alpha_tridiag, out=alpha_reciprocal)
            alpha_tridiag.masked_fill_(alpha_tridiag_is_zero, 0)

            if k == 0:
                t_mat[k, k].copy_(alpha_reciprocal)
            else:
                torch.addcmul(alpha_reciprocal,
                              prev_beta,
                              prev_alpha_reciprocal,
                              out=t_mat[k, k])
                torch.mul(prev_beta.sqrt_(),
                          prev_alpha_reciprocal,
                          out=t_mat[k, k - 1])
                t_mat[k - 1, k].copy_(t_mat[k, k - 1])
                if t_mat[k - 1, k].max() < 1e-6:
                    update_tridiag = False

            last_tridiag_iter = k

            prev_alpha_reciprocal.copy_(alpha_reciprocal)
            prev_beta.copy_(beta_tridiag)

    # Un-normalize
    result = result.mul(rhs_norm)

    if not tolerance_reached and n_iter > 0:
        warnings.warn(
            "CG terminated in {} iterations with average residual norm {},"
            " which is larger than the specified tolerance of {}."
            " If performance is affected, consider raising the maximum number"
            " of CG iterations by increasing the max_iter value.".format(
                k + 1, residual_norm.mean(), tolerance))

    if is_vector:
        result = result.squeeze(-1)

    if n_tridiag:
        t_mat = t_mat[:last_tridiag_iter + 1, :last_tridiag_iter + 1]
        return result, t_mat.permute(-1, *range(2, 2 + len(batch_shape)), 0,
                                     1).contiguous()
    else:
        return result
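

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added annotation, not part of the original
# GPyTorch/mgplvm source). It assumes only the `linear_cg` function defined
# above and a PyTorch version that provides torch.linalg.solve: build a small
# symmetric positive definite system A x = b, solve it with CG, and compare
# against a direct solve.
if __name__ == "__main__":
    torch.manual_seed(0)
    n = 50

    # Random SPD matrix: A = L L^T + n * I (well conditioned)
    L = torch.randn(n, n)
    A = L @ L.t() + n * torch.eye(n)
    b = torch.randn(n)

    # matmul_closure may be a tensor (its .matmul is used) or any callable
    x_cg = linear_cg(A, b, max_iter=n, tolerance=1e-4)
    x_direct = torch.linalg.solve(A, b)
    print("max abs difference:", (x_cg - x_direct).abs().max().item())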