# Source code for torch_secorder.analysis.eigensolvers

"""Eigensolvers for computing top-K eigenvalues and eigenvectors of Hessian/Fisher matrices.

This module provides iterative methods for computing eigenvalues and eigenvectors
of large matrices that can only be accessed through matrix-vector products.
"""

from typing import Callable, List, Optional, Tuple

import torch

from ..core.utils import flatten_params, unflatten_params


def _orthonormalize_columns(
    vectors: torch.Tensor, fallback: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Return a copy of ``vectors`` whose columns are orthonormalized.

    Uses modified Gram-Schmidt, left to right. ``vectors`` itself is never
    modified. If a column becomes exactly zero after the earlier columns are
    projected out (e.g. a rank-deficient or zero operator), dividing by its norm
    would produce NaNs. In that case the same column of ``fallback`` is projected
    and normalized instead. The column is left as zeros only if that also
    vanishes.
    """
    basis = vectors.clone()
    for i in range(basis.shape[1]):
        candidates = [basis[:, i]]
        if fallback is not None:
            candidates.append(fallback[:, i])
        for candidate in candidates:
            column = candidate
            for j in range(i):
                column = column - torch.dot(column, basis[:, j]) * basis[:, j]
            norm = torch.norm(column)
            if norm > 0:
                basis[:, i] = column / norm
                break
        else:
            # Every candidate collapsed: keep a zero column rather than NaNs.
            basis[:, i] = column
    return basis


def power_iteration(
    matrix_vector_product: Callable[[torch.Tensor], torch.Tensor],
    dim: int,
    num_iterations: int = 100,
    num_vectors: int = 1,
    tol: float = 1e-6,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute dominant eigenpairs with block power (subspace) iteration.

    Each iteration orthonormalizes the latest block ``A @ V`` with modified
    Gram-Schmidt and takes the Rayleigh quotients ``v_i^T A v_i`` as eigenvalue
    estimates. The product computed for the Rayleigh quotients is reused as the
    next power step. As a result, ``matrix_vector_product`` is called once per
    iteration, plus once before the loop.

    Power iteration converges towards the eigenvalues of largest *magnitude*.
    For an indefinite matrix (e.g. a Hessian with negative curvature), the
    returned values may therefore be negative. They are not explicitly sorted.

    Args:
        matrix_vector_product: Function returning ``A @ X`` for a 2-D block
            ``X`` of shape ``(dim, num_vectors)``. Its output is read but
            never modified in place.
        dim: Dimension of the square matrix.
        num_iterations: Maximum number of iterations.
        num_vectors: Number of eigenpairs to compute; must not exceed ``dim``.
        tol: Stop once every eigenvalue estimate changes by less than ``tol``
            between consecutive iterations.
        device: Device for the iterates; defaults to CUDA when available,
            otherwise CPU.
        dtype: Floating dtype for the iterates; defaults to
            ``torch.get_default_dtype()``.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)`` of shapes ``(num_vectors,)`` and
        ``(dim, num_vectors)``. The eigenvector columns are orthonormal, except
        in the degenerate case where ``_orthonormalize_columns`` leaves a zero
        column.

    Raises:
        ValueError: If ``num_vectors`` exceeds ``dim``.
    """
    if num_vectors > dim:
        raise ValueError(f"num_vectors ({num_vectors}) must not exceed dim ({dim})")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Random unit-norm starting block.
    basis = torch.randn(dim, num_vectors, device=device, dtype=dtype)
    basis = basis / torch.norm(basis, dim=0, keepdim=True)
    eigenvalues = torch.zeros(num_vectors, device=device, dtype=basis.dtype)
    prev_eigenvalues = torch.zeros_like(eigenvalues)
    if num_iterations <= 0:
        return eigenvalues, basis

    image = matrix_vector_product(basis)
    for _ in range(num_iterations):
        # Power step: the new basis spans A @ (previous basis).
        basis = _orthonormalize_columns(image, fallback=basis)
        image = matrix_vector_product(basis)
        # diag(V^T A V) without forming the full product matrix.
        eigenvalues = (basis * image).sum(dim=0)
        if torch.all(torch.abs(eigenvalues - prev_eigenvalues) < tol):
            break
        prev_eigenvalues = eigenvalues
    return eigenvalues, basis
def lanczos(
    matrix_vector_product: Callable[[torch.Tensor], torch.Tensor],
    dim: int,
    num_iterations: int = 100,
    num_vectors: int = 1,
    tol: float = 1e-6,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the top-K eigenvalues and eigenvectors with the Lanczos algorithm.

    Builds an orthonormal Krylov basis ``V`` and the symmetric tridiagonal
    matrix ``T = V^T A V`` from a random unit start vector. Each new Lanczos
    vector is fully reorthogonalized against the whole basis, which prevents
    the loss of orthogonality that otherwise produces duplicate ("ghost")
    eigenvalues. The eigenpairs of ``T`` (Ritz pairs) are mapped back through
    ``V``.

    Eigenvalues are the largest *algebraic* ones, sorted in descending order.

    Args:
        matrix_vector_product: Function returning ``A @ v`` for a 1-D vector
            ``v`` of length ``dim``.
        dim: Dimension of the square matrix.
        num_iterations: Maximum number of Lanczos steps (size of the Krylov
            basis). Capped at ``dim``, since a Krylov space cannot be larger.
        num_vectors: Number of top eigenpairs to return.
        tol: Breakdown tolerance. Iteration stops early when the norm of the
            new residual falls below ``tol``, which means an invariant subspace
            has been found.
        device: Device for the iterates; defaults to CUDA when available,
            otherwise CPU.
        dtype: Floating dtype for the iterates; defaults to
            ``torch.get_default_dtype()``.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)`` with shapes ``(m,)`` and
        ``(dim, m)``, where ``m = min(num_vectors, steps performed)``. Fewer
        than ``num_vectors`` pairs are returned when the Krylov space is
        exhausted early.

    Raises:
        ValueError: If ``dim`` or ``num_iterations`` is less than 1.
    """
    if dim < 1 or num_iterations < 1:
        raise ValueError("dim and num_iterations must both be at least 1")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    max_steps = min(num_iterations, dim)
    basis = torch.zeros(dim, max_steps, device=device, dtype=dtype)
    alphas: List[torch.Tensor] = []  # diagonal of T
    betas: List[torch.Tensor] = []  # off-diagonal of T

    v = torch.randn(dim, device=device, dtype=dtype)
    v = v / torch.norm(v)
    for step in range(max_steps):
        basis[:, step] = v
        w = matrix_vector_product(v)
        if step > 0:
            # Three-term recurrence: remove the previous Lanczos direction.
            w = w - betas[-1] * basis[:, step - 1]
        alpha = torch.dot(w, v)
        alphas.append(alpha)
        w = w - alpha * v
        # Full reorthogonalization: remove the rounding-error components
        # along every earlier basis vector.
        krylov = basis[:, : step + 1]
        w = w - krylov @ (krylov.T @ w)
        beta = torch.norm(w)
        if step == max_steps - 1 or beta < tol:
            # Only the steps actually filled in enter T and V below, so an
            # early breakdown cannot introduce a spurious zero Ritz value.
            break
        betas.append(beta)
        v = w / beta

    num_steps = len(alphas)
    tridiag = torch.diag(torch.stack(alphas))
    if betas:
        off_diag = torch.stack(betas)
        tridiag = tridiag + torch.diag(off_diag, 1) + torch.diag(off_diag, -1)

    ritz_values, ritz_vectors = torch.linalg.eigh(tridiag)
    ritz_values, order = torch.sort(ritz_values, descending=True)
    ritz_vectors = ritz_vectors[:, order]

    # Map the Ritz vectors back to the original space.
    eigenvectors = basis[:, :num_steps] @ ritz_vectors[:, :num_vectors]
    return ritz_values[:num_vectors], eigenvectors
def _fill_unused_grads(
    grads: Tuple[Optional[torch.Tensor], ...], params: List[torch.Tensor]
) -> List[torch.Tensor]:
    """Replace ``None`` gradients (from ``allow_unused=True``) with zeros.

    ``torch.autograd.grad`` returns ``None`` for a parameter the differentiated
    output does not depend on. Substituting a zero tensor shaped like that
    parameter keeps the flattened vector aligned with the parameter vector.
    """
    return [
        grad if grad is not None else torch.zeros_like(param)
        for grad, param in zip(grads, params)
    ]


def model_eigenvalues(
    model: torch.nn.Module,
    loss_fn: Callable[[torch.nn.Module, torch.Tensor, torch.Tensor], torch.Tensor],
    data: torch.Tensor,
    target: torch.Tensor,
    num_eigenvalues: int = 1,
    method: str = "lanczos",
    num_iterations: int = 100,
    tol: float = 1e-6,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, List[List[torch.Tensor]]]:
    """Compute top-K eigenvalues and eigenvectors of a model's loss Hessian.

    The Hessian is taken with respect to the parameters that require
    gradients. It is never materialized; the solver only sees Hessian-vector
    products (HVPs).

    ``loss_fn`` is called exactly once. The resulting first-order gradient
    graph is built once and retained (``retain_graph=True``) for every HVP of
    the solve, then released when this function returns. This avoids a forward
    and backward pass per HVP, keeps the operator fixed even if the loss is
    stochastic, and runs any forward-pass side effects only once.

    ``'power_iteration'`` converges to the largest-magnitude eigenvalues;
    ``'lanczos'`` returns the largest algebraic ones.

    Args:
        model: PyTorch model. Parameters with ``requires_grad=False`` are
            excluded from the Hessian.
        loss_fn: Loss function called as ``loss_fn(model, data, target)``.
        data: Input data.
        target: Target data.
        num_eigenvalues: Number of top eigenvalues to compute.
        method: Solver to use, ``'power_iteration'`` or ``'lanczos'``
            (case-insensitive).
        num_iterations: Maximum number of solver iterations.
        tol: Convergence tolerance passed to the solver.
        device: Device passed to the solver; defaults to the device of the
            first trainable parameter.

    Returns:
        Tuple ``(eigenvalues, eigenvectors)``. ``eigenvalues`` is the 1-D
        tensor returned by the solver. ``eigenvectors`` holds one entry per
        returned eigenvalue, each produced by ``unflatten_params`` from the
        flat eigenvector and the trainable parameters' shapes. The solver may
        return fewer than ``num_eigenvalues`` pairs.

    Raises:
        ValueError: If the model has no trainable parameters, ``method`` is
            unknown, or the solver passes the HVP a tensor that is not 1-D or
            2-D.
    """
    # autograd.grad rejects tensors that do not require grad, so frozen
    # parameters are left out of the Hessian.
    params = [p for p in model.parameters() if p.requires_grad]
    if not params:
        raise ValueError("model has no parameters that require gradients")
    if device is None:
        device = params[0].device

    method_key = method.lower()
    if method_key == "power_iteration":
        solver = power_iteration
    elif method_key == "lanczos":
        solver = lanczos
    else:
        raise ValueError(f"Unknown method: {method}")

    param_shapes = [p.shape for p in params]
    num_params = sum(p.numel() for p in params)

    # enable_grad lets this work even when called inside torch.no_grad().
    with torch.enable_grad():
        loss = loss_fn(model, data, target)
        grads = torch.autograd.grad(loss, params, create_graph=True, allow_unused=True)
        flat_grads = flatten_params(_fill_unused_grads(grads, params))

    def hvp_single(vec: torch.Tensor) -> torch.Tensor:
        """Hessian-vector product for a single flat vector."""
        if not flat_grads.requires_grad:
            # The gradient does not depend on the parameters, so H is zero.
            return torch.zeros_like(vec)
        with torch.enable_grad():
            grad_dot_vec = torch.dot(flat_grads, vec)
        second = torch.autograd.grad(
            grad_dot_vec,
            params,
            retain_graph=True,  # the graph is shared by every HVP of the solve
            allow_unused=True,
        )
        return flatten_params(_fill_unused_grads(second, params))

    def hvp(v: torch.Tensor) -> torch.Tensor:
        """HVP for a 1-D vector, or column-wise for a 2-D batch of vectors."""
        if v.ndim == 1:
            return hvp_single(v)
        if v.ndim == 2:
            return torch.stack([hvp_single(col) for col in v.unbind(dim=1)], dim=1)
        raise ValueError("Input vector v must be 1D or 2D tensor.")

    eigenvalues, eigenvectors = solver(
        hvp, num_params, num_iterations, num_eigenvalues, tol, device
    )

    # Index by what the solver actually returned; it may be fewer than requested.
    param_eigenvectors = [
        unflatten_params(eigenvectors[:, i], param_shapes)
        for i in range(eigenvectors.shape[1])
    ]
    return eigenvalues, param_eigenvectors
# Alias names: each refers to the same function object as its right-hand side.
estimate_eigenvalues, lanczos_iteration = model_eigenvalues, lanczos

# Names exported by ``from <module> import *``.
__all__ = [
    "power_iteration",
    "lanczos",
    "model_eigenvalues",
    "estimate_eigenvalues",
    "lanczos_iteration",
]