"""Sherman-Morrison-Woodbury formula implementations."""
from __future__ import annotations
import numpy as np
from scipy.linalg import lapack
from .linalg import (
cholesky_decomposition,
cholesky_solve,
matrix_inverse_symm,
matrix_mult,
matrix_slogdet_symm,
)
def smw_inverse(N_inv, V_N_inv, V_Ninv_VT, Lambda_diag, threshold=1e-30):
    """
    Sherman-Morrison-Woodbury inverse: (N + V^T Lambda V)^{-1}.

    Uses the identity

        (N + V^T Lambda V)^{-1}
            = N^{-1} - N^{-1} V^T (Lambda^{-1} + V N^{-1} V^T)^{-1} V N^{-1},

    so only the (k, k) kernel ``Lambda^{-1} + V N^{-1} V^T`` is inverted.
    ``V_N_inv.T`` is used in place of ``N^{-1} V^T``, which relies on N
    (and therefore N^{-1}) being symmetric.

    Parameters
    ----------
    N_inv : numpy.ndarray
        Precomputed inverse of N, shape (n, n).
    V_N_inv : numpy.ndarray
        Precomputed V @ N^{-1}, shape (k, n).
    V_Ninv_VT : numpy.ndarray
        Precomputed V @ N^{-1} @ V^T, shape (k, k).
    Lambda_diag : numpy.ndarray
        Diagonal of Lambda, shape (k,).
    threshold : float, optional
        Minimum value for Lambda diagonal elements; smaller entries are
        clamped to it before inversion. Default is 1e-30.

    Returns
    -------
    numpy.ndarray
        The inverse (N + V^T Lambda V)^{-1}, shape (n, n).
    """
    # smw_kernel returns a fresh Fortran-ordered array owned by this call, so
    # the symmetric inverse is allowed to overwrite it in place.
    kernel = smw_kernel(V_Ninv_VT, Lambda_diag, threshold=threshold)
    kernel_inv = matrix_inverse_symm(kernel, overwrite=True)
    return N_inv - V_N_inv.T @ kernel_inv @ V_N_inv
def smw_logdet(log_det_N, V_Ninv_VT, Lambda_diag, threshold=1e-30):
    """
    Sherman-Morrison-Woodbury log determinant: log|N + V^T Lambda V|.

    Uses the matrix determinant lemma

        log|N + V^T Lambda V|
            = log|N| + log|Lambda| + log|Lambda^{-1} + V N^{-1} V^T|,

    so only the (k, k) kernel determinant has to be computed.

    Parameters
    ----------
    log_det_N : float
        Precomputed log|N|.
    V_Ninv_VT : numpy.ndarray
        Precomputed V @ N^{-1} @ V^T, shape (k, k).
    Lambda_diag : numpy.ndarray
        Diagonal of Lambda, shape (k,).
    threshold : float, optional
        Minimum value for Lambda diagonal elements; smaller entries are
        clamped to it. Default is 1e-30.

    Returns
    -------
    float
        The log determinant log|N + V^T Lambda V|.
    """
    # Same lower clamp as the Lambda^{-1} term of the kernel, so that
    # log|Lambda| and the kernel describe the same (clamped) Lambda.
    log_det_Lambda = np.sum(np.log(np.maximum(Lambda_diag, threshold)))
    kernel = smw_kernel(V_Ninv_VT, Lambda_diag, threshold=threshold)
    # NOTE(review): the sign from matrix_slogdet_symm is discarded. The
    # kernel should be positive definite, but confirm how that helper
    # reports a non-positive-definite input.
    _, log_det_K = matrix_slogdet_symm(kernel)
    return log_det_N + log_det_Lambda + log_det_K
def smw_kernel(V_Ninv_VT, Lambda_diag, threshold=1e-30):
    """
    Build the SMW kernel matrix K = Lambda^{-1} + V N^{-1} V^T.

    Diagonal entries of Lambda that are not greater than ``threshold``
    (zero, negative or NaN) are replaced by ``threshold`` before being
    inverted. For a positive ``threshold`` the diagonal added to
    ``V_Ninv_VT`` is therefore finite, and no divide-by-zero warning is
    emitted.

    Parameters
    ----------
    V_Ninv_VT : array_like
        Precomputed V @ N^{-1} @ V^T, shape (k, k). It is not modified.
    Lambda_diag : array_like
        Diagonal of Lambda, shape (k,). A scalar is applied to every
        diagonal element.
    threshold : float, optional
        Minimum value for Lambda diagonal elements. Default is 1e-30.

    Returns
    -------
    numpy.ndarray
        The kernel matrix K, shape (k, k), as a new array in Fortran order.
        Integer input is promoted to floating point; floating or complex
        input keeps its dtype.

    Raises
    ------
    ValueError
        If ``V_Ninv_VT`` is not a square 2-D matrix, or ``Lambda_diag``
        cannot be broadcast onto its diagonal.
    """
    V = np.asarray(V_Ninv_VT)
    if V.ndim != 2 or V.shape[0] != V.shape[1]:
        raise ValueError(
            f"V_Ninv_VT must be a square (k, k) matrix, got shape {V.shape}"
        )
    # Clamp before dividing so 1/0 is never evaluated. fmax (unlike maximum)
    # also maps NaN entries to ``threshold`` instead of propagating them.
    lambda_inv_diag = 1.0 / np.fmax(np.asarray(Lambda_diag), threshold)
    # One Fortran-ordered copy. result_type with a Python float promotes
    # integer input to float64 (so the in-place add below cannot fail on a
    # casting rule) while leaving float32/complex input at its own dtype.
    K = np.array(V, dtype=np.result_type(V.dtype, 1.0), order="F", copy=True)
    diag = np.arange(V.shape[0])
    K[diag, diag] += lambda_inv_diag
    return K