Source code for torch_secorder.approximations.gauss_newton

"""Functions for computing Gauss-Newton matrix approximations."""

from typing import Callable, Iterable

import torch
from torch import Tensor
from torch.nn import Module


[docs] def gauss_newton_matrix_approximation( model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, create_graph: bool = False, ) -> Iterable[Tensor]: r""" Computes the diagonal approximation of the Gauss-Newton Matrix (GNM). The Gauss-Newton Matrix (GNM) is an approximation of the Hessian, commonly used in non-linear least squares problems. For a loss function defined as: .. math:: L = \frac{1}{2} \|f(x; \theta) - y\|^2, where :math:`f(x; \theta)` is the model output, the Gauss-Newton Matrix is given by: .. math:: G = J^\top J, where :math:`J` is the Jacobian of :math:`f(x; \theta)` with respect to the parameters :math:`\theta`. This function returns the diagonal elements of the GNM for each parameter. It achieves this by computing the sum of squared Jacobian rows corresponding to each parameter. Parameters ---------- model : torch.nn.Module The PyTorch model. loss_fn : Callable[[torch.Tensor, torch.Tensor], torch.Tensor] The loss function, typically :class:`torch.nn.MSELoss` or :func:`torch.nn.functional.mse_loss`. inputs : torch.Tensor Input tensor to the model. targets : torch.Tensor Target tensor for the loss function. create_graph : bool, optional If True, constructs the computation graph for higher-order derivatives. Default is False. Returns ------- Iterable[torch.Tensor] A list of tensors, where each tensor contains the diagonal elements of the GNM for the corresponding model parameter. Raises ------ ValueError If ``loss_fn`` is not MSE-based, as GNM is defined specifically for least-squares problems. """ # Verify that the loss function is MSE-based # This is a heuristic check; a more robust check might involve inspecting the loss_fn's type # or relying on user to pass appropriate loss. if not ( isinstance(loss_fn, torch.nn.MSELoss) or ( callable(loss_fn) and ("mse_loss" in str(loss_fn) or "MSELoss" in str(loss_fn.__class__)) ) ): raise ValueError( "Gauss-Newton Matrix approximation is typically used with MSE-based loss functions " "for least-squares problems. Please use `nn.MSELoss` or `F.mse_loss`." ) # The residual r(x;theta) = f(x;theta) - y # The loss L = 1/2 * ||r(x;theta)||^2 # The Jacobian of the loss w.r.t parameters is J_L = r(x;theta)^T * J_r # where J_r is the Jacobian of the residuals w.r.t parameters. # The GNM is J_r^T J_r # Compute the model outputs outputs = model(inputs) # Compute residuals: r = outputs - targets residuals = outputs - targets # Ensure residuals require grad if create_graph is True for higher-order derivatives if create_graph and not residuals.requires_grad: # This ensures the gradient computation for residuals can be differentiated again residuals.requires_grad_(True) param_gen = [p for p in model.parameters() if p.requires_grad] gnm_diagonal = [] # Iterate over each output dimension to construct the Jacobian rows implicitly # This is effectively computing J^T J by summing outer products of gradients (for each output component) # For the diagonal, we only need the sum of squared gradients for each parameter. # J_r^T J_r = sum over output dimensions of (grad(r_i) * grad(r_i)^T) # The diagonal of J_r^T J_r is sum over output dimensions of (grad(r_i) .^ 2) # Loop through each element of the residual and compute its gradient w.r.t. parameters # Then sum the squares of these gradients to get the diagonal of GNM. # This is effectively computing diag(J_r^T J_r) for i in range(residuals.numel()): # Select a single element of the residual # Ensure the element is part of the graph for higher-order derivatives if create_graph is True residual_element = residuals.view(-1)[i] # Compute gradients of this single residual element w.r.t. all parameters that require grad # Use retain_graph=True because we'll call .grad multiple times in the loop grads_for_element = torch.autograd.grad( residual_element, param_gen, # Use the filtered list of parameters retain_graph=True, create_graph=create_graph, allow_unused=True, ) if i == 0: # Initialize gnm_diagonal with zeros of the correct shape based on first gradients # Ensure these initial zeros also have requires_grad=True if create_graph is True gnm_diagonal = [ torch.zeros_like(g, requires_grad=create_graph) for g in grads_for_element if g is not None ] for j, g_elem in enumerate(grads_for_element): if g_elem is not None: # Sum the squares of the gradients for each parameter # Ensure that g_elem.pow(2) maintains the graph if create_graph is True # and that the accumulation also respects it. if create_graph: # If create_graph is True, we need to ensure g_elem has grad_fn # and that its square retains grad_fn. if g_elem.grad_fn is None and g_elem.requires_grad: # This case implies g_elem is an input that directly requires grad # or a leaf tensor. If not, it should have a grad_fn. pass # Handled by earlier create_graph=True for residual_element gnm_diagonal[j] = gnm_diagonal[j] + g_elem.pow(2) else: # If create_graph is False, simple accumulation is fine gnm_diagonal[j] = gnm_diagonal[j] + g_elem.pow(2) # Clear the graph to avoid memory issues if create_graph was True and not needed further if not create_graph: model.zero_grad() return gnm_diagonal