Source code for torch_secorder.core.hvp

"""Hessian-Vector Product (HVP) computation utilities.

This module provides functions to compute Hessian-vector products efficiently
using PyTorch's autograd system. Both exact and approximate methods are provided.
"""

from typing import Callable, List, Union

import torch
import torch.nn as nn


[docs] def exact_hvp( func: Callable[[], torch.Tensor], params: List[torch.Tensor], v: Union[torch.Tensor, List[torch.Tensor]], create_graph: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute the exact Hessian-vector product Hv using double backpropagation. Args: func: A callable that returns a scalar tensor (loss). params: List of parameters with respect to which to compute the Hessian. v: Vector to multiply with the Hessian. Can be a single tensor or a list of tensors matching the structure of params. create_graph: If True, the graph used to compute the grad will be constructed, allowing to compute higher order derivative products. Returns: The Hessian-vector product Hv. Returns a single tensor if v is a single tensor, otherwise returns a list of tensors matching the structure of params. """ if isinstance(v, torch.Tensor): v = [v] # First compute the gradient grad = torch.autograd.grad(func(), params, create_graph=True) # Then compute the Hessian-vector product # Ensure the output is a scalar by summing all elements grad_dot_v = sum((g * v_).sum() for g, v_ in zip(grad, v)) hvp = torch.autograd.grad( grad_dot_v, params, create_graph=create_graph, ) return hvp[0] if len(hvp) == 1 else hvp
[docs] def approximate_hvp( func: Callable[[], torch.Tensor], params: List[torch.Tensor], v: Union[torch.Tensor, List[torch.Tensor]], num_samples: int = 1, damping: float = 0.0, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute an approximate Hessian-vector product using finite differences. This method uses a finite difference approximation of the Hessian-vector product, which can be more memory efficient than the exact computation. Args: func: A callable that returns a scalar tensor (loss). params: List of parameters with respect to which to compute the Hessian. v: Vector to multiply with the Hessian. Can be a single tensor or a list of tensors matching the structure of params. num_samples: Number of samples to use for the approximation. damping: Damping term to add to the diagonal of the Hessian (lambda * I). Returns: The approximate Hessian-vector product Hv. Returns a single tensor if v is a single tensor, otherwise returns a list of tensors matching the structure of params. """ if isinstance(v, torch.Tensor): v = [v] eps = 1e-6 # Small constant for finite differences # Compute the gradient at the current point grad = torch.autograd.grad(func(), params, create_graph=False) # Initialize the result hvp = [torch.zeros_like(p) for p in params] for _ in range(num_samples): # Perturb parameters safely with torch.no_grad(): for p, vec in zip(params, v): p.add_(eps * vec) # Compute gradient at perturbed point grad_perturbed = torch.autograd.grad(func(), params, create_graph=False) with torch.no_grad(): for p, vec in zip(params, v): p.sub_(eps * vec) # Update HVP approximation for h, g, g_p in zip(hvp, grad, grad_perturbed): h.add_((g_p - g) / eps) # Average over samples for h in hvp: h.div_(num_samples) # Add damping if specified if damping > 0: for h, vec in zip(hvp, v): h.add_(damping * vec) return hvp[0] if len(hvp) == 1 else hvp
[docs] def model_hvp( model: nn.Module, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: torch.Tensor, y: torch.Tensor, v: Union[torch.Tensor, List[torch.Tensor]], create_graph: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute the Hessian-vector product for a model's loss function. This is a convenience wrapper around exact_hvp that handles the model's forward pass and loss computation. Args: model: The PyTorch model. loss_fn: The loss function that takes model output and target as arguments. x: Input tensor. y: Target tensor. v: Vector to multiply with the Hessian. create_graph: If True, the graph used to compute the grad will be constructed. Returns: The Hessian-vector product Hv. """ params = list(model.parameters()) def loss_func(): output = model(x) return loss_fn(output, y) return exact_hvp(loss_func, params, v, create_graph=create_graph)
[docs] def gauss_newton_product( model: nn.Module, loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], x: torch.Tensor, y: torch.Tensor, v: Union[torch.Tensor, List[torch.Tensor]], create_graph: bool = False, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Compute the Gauss-Newton matrix-vector product for a model's loss. Args: model: The PyTorch model. loss_fn: The loss function (should be MSE or cross-entropy for GN to be valid). x: Input tensor. y: Target tensor. v: Vector to multiply with the Gauss-Newton matrix. create_graph: If True, graph of the derivative will be constructed. Returns: The Gauss-Newton matrix-vector product (same structure as v). """ params = list(model.parameters()) def output(): return model(x) out = output() out_flat = out.reshape(-1) # Compute Jv (J: Jacobian of model output wrt params) if isinstance(v, torch.Tensor): v = [v] jvp_vec = torch.zeros_like(out_flat) for i in range(out_flat.shape[0]): grad = torch.autograd.grad( out_flat[i], params, retain_graph=True, create_graph=True, allow_unused=True ) grad = [ (g if g is not None else torch.zeros_like(p)) for g, p in zip(grad, params) ] jvp_vec[i] = sum([(g * v_).sum() for g, v_ in zip(grad, v)]) # Compute Hessian of loss wrt model output (usually diagonal for MSE, cross-entropy) out = out.detach().requires_grad_(True) loss = loss_fn(out, y) grad_out = torch.autograd.grad(loss, out, create_graph=True)[0] # Reshape jvp_vec to match out's shape jvp_vec_reshaped = jvp_vec.reshape_as(out) gn_prod = torch.autograd.grad( grad_out, out, grad_outputs=jvp_vec_reshaped, create_graph=create_graph )[0] # Backpropagate to parameters gn_param = torch.autograd.grad( out, params, grad_outputs=gn_prod, create_graph=create_graph, allow_unused=True ) gn_param = [ (g if g is not None else torch.zeros_like(p)) for g, p in zip(gn_param, params) ] return gn_param[0] if len(gn_param) == 1 else gn_param
[docs] def hessian_trace( func: Callable[[], torch.Tensor], params: List[torch.Tensor], num_samples: int = 10, create_graph: bool = False, sparse: bool = False, ) -> float: """Estimate the trace of the Hessian using Hutchinson's method (random projections). Args: func: A callable that returns a scalar tensor (loss). params: List of parameters with respect to which to compute the Hessian. num_samples: Number of random projections to use. create_graph: If True, graph of the derivative will be constructed. sparse: If True, use sparse random vectors for projections. Returns: Estimated trace of the Hessian. """ trace = 0.0 for _ in range(num_samples): vs = [] for p in params: if sparse: # Sparse Rademacher vector v = torch.randint_like(p, low=0, high=2, dtype=torch.float32) * 2 - 1 mask = torch.rand_like(p) > 0.9 # 10% nonzero v = v * mask else: v = torch.randint_like(p, low=0, high=2, dtype=torch.float32) * 2 - 1 vs.append(v) hv = exact_hvp(func, params, vs, create_graph=create_graph) if isinstance(hv, torch.Tensor): hv = [hv] trace += sum([(h * v).sum().item() for h, v in zip(hv, vs)]) return trace / num_samples