Gauss-Newton Module
The Gauss-Newton module provides functions for computing approximations related to the Gauss-Newton Matrix.
Gauss-Newton Matrix Approximation
- torch_secorder.approximations.gauss_newton.gauss_newton_matrix_approximation(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, create_graph: bool = False) → Iterable[Tensor]
Computes the diagonal approximation of the Gauss-Newton Matrix (GNM).
The Gauss-Newton Matrix (GNM) is an approximation of the Hessian, commonly used in non-linear least squares problems. For a loss function defined as:
\[L = \frac{1}{2} \|f(x; \theta) - y\|^2,\]
where \(f(x; \theta)\) is the model output, the Gauss-Newton Matrix is given by:
\[G = J^\top J,\]
where \(J\) is the Jacobian of \(f(x; \theta)\) with respect to the parameters \(\theta\).
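As a short worked step on why \(G\) approximates the Hessian: writing the residual as \(r = f(x; \theta) - y\), with components \(r_k\), and differentiating \(L\) twice gives

\[\nabla_\theta^2 L = J^\top J + \sum_k r_k \, \nabla_\theta^2 f_k(x; \theta) = G + \sum_k r_k \, \nabla_\theta^2 f_k(x; \theta).\]

The Gauss-Newton Matrix keeps only the first term. The dropped term is small when the residuals \(r_k\) are small or when \(f\) is close to linear in \(\theta\), and for a model that is exactly linear in \(\theta\) it vanishes, so \(G\) equals the Hessian.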
This function returns the diagonal elements of the GNM for each parameter. It achieves this by summing the squared entries of the Jacobian column corresponding to each parameter, since \(G_{ii} = \sum_k J_{ki}^2\).
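The relationship between the Jacobian and the GNM diagonal can be checked on a tiny model. The sketch below does not call torch_secorder; it builds the full Jacobian of a small nn.Linear with plain torch.autograd.grad, which is affordable only at toy sizes, and compares the per-parameter sums of squares against the diagonal of \(J^\top J\). The model sizes and variable names are illustrative assumptions, and the library's own implementation may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy setup (sizes are arbitrary, chosen only for illustration)
model = nn.Linear(3, 2)
inputs = torch.randn(4, 3)
params = list(model.parameters())  # [weight (2, 3), bias (2,)]

# Flatten all scalar model outputs: 4 samples * 2 outputs = 8 entries
outputs = model(inputs).reshape(-1)

# Build the full Jacobian: one row per scalar output, one column per parameter
rows = []
for out in outputs:
    grads = torch.autograd.grad(out, params, retain_graph=True)
    rows.append(torch.cat([g.reshape(-1) for g in grads]))
J = torch.stack(rows)  # shape (8, 8): 8 outputs, 6 weights + 2 biases

# Diagonal of G = J^T J, computed two ways
diag_from_columns = (J ** 2).sum(dim=0)   # sum of squares down each column
diag_from_full = torch.diagonal(J.T @ J)  # explicit (expensive) reference

# Split the flat diagonal back into one tensor per parameter
sizes = [p.numel() for p in params]
diag_per_param = [
    d.view_as(p) for d, p in zip(torch.split(diag_from_columns, sizes), params)
]
print([d.shape for d in diag_per_param])
```

For this linear model the result can also be read off by hand: each bias entry has derivative 1 for every sample, so its diagonal value is the batch size, and each weight entry accumulates the squared input feature it multiplies.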
- Parameters:
model (torch.nn.Module) – The PyTorch model.
loss_fn (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) – The loss function, typically torch.nn.MSELoss or torch.nn.functional.mse_loss().
inputs (torch.Tensor) – Input tensor to the model.
targets (torch.Tensor) – Target tensor for the loss function.
create_graph (bool, optional) – If True, constructs the computation graph for higher-order derivatives. Default is False.
- Returns:
A list of tensors, where each tensor contains the diagonal elements of the GNM for the corresponding model parameter.
- Return type:
Iterable[torch.Tensor]
- Raises:
ValueError – If loss_fn is not MSE-based, as the GNM is defined specifically for least-squares problems.
Example:
import torch
import torch.nn as nn

from torch_secorder.approximations.gauss_newton import gauss_newton_matrix_approximation


# Define a simple linear model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)


model = SimpleModel()
loss_fn = nn.MSELoss()
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

# Compute the diagonal of the Gauss-Newton Matrix approximation.
# The declared return type is Iterable[Tensor], so materialize it before indexing.
gnm_diagonal = list(gauss_newton_matrix_approximation(model, loss_fn, inputs, targets))

# One tensor per parameter, in model.parameters() order (weight, then bias)
print("Gauss-Newton Matrix Diagonal (for weights):", gnm_diagonal[0].shape)
print("Gauss-Newton Matrix Diagonal (for bias):", gnm_diagonal[1].shape)
Notes
The Gauss-Newton Matrix is a positive semi-definite approximation of the Hessian, commonly used in non-linear least squares optimization.
This function provides the diagonal elements of the GNM, which can be used for diagonal approximations in optimizers or for diagnostic purposes.
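As one illustration of using the diagonal in an optimizer, the sketch below applies a damped, diagonally preconditioned gradient step, \(\theta \leftarrow \theta - \eta \, g / (\operatorname{diag}(G) + \lambda)\). The helper name preconditioned_step, the damping term, and the stand-in tensors are all assumptions made for this example; none of them is part of torch_secorder. In practice the gnm_diag argument would hold the per-parameter tensors returned by gauss_newton_matrix_approximation.

```python
import torch


def preconditioned_step(params, grads, gnm_diag, lr=0.1, damping=1e-3):
    """Hypothetical helper: in-place update p <- p - lr * g / (diag + damping).

    params, grads and gnm_diag are parallel lists of same-shaped tensors.
    The damping term keeps the division stable where the diagonal is near zero.
    """
    with torch.no_grad():
        for p, g, d in zip(params, grads, gnm_diag):
            p -= lr * g / (d + damping)


# Stand-in values so the arithmetic is easy to follow
w = torch.tensor([1.0, -2.0], requires_grad=True)
grad = [torch.tensor([0.5, 0.5])]
diag = [torch.tensor([4.0, 0.25])]

preconditioned_step([w], grad, diag, lr=1.0, damping=0.0)
print(w)  # the entry with the smaller diagonal value takes the larger step
```

Dividing by the diagonal rescales each coordinate by its own curvature estimate, so directions with small \(G_{ii}\) move further than directions with large \(G_{ii}\) for the same gradient.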
The create_graph parameter allows for computing higher-order derivatives, which is useful in meta-learning or other advanced scenarios.
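To make the role of create_graph concrete without relying on the library's internals, the sketch below reproduces the idea with plain torch.autograd.grad: the trace of the GNM (the sum of its diagonal) is computed with create_graph=True and then differentiated again with respect to the parameters. The small tanh network and all variable names are assumptions for illustration; whether torch_secorder's output behaves identically depends on its implementation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A small non-linear model, so the Jacobian itself depends on the parameters
model = nn.Sequential(nn.Linear(2, 3), nn.Tanh(), nn.Linear(3, 1))
params = list(model.parameters())
x = torch.randn(4, 2)

outputs = model(x).reshape(-1)

# trace(G) = sum over outputs and parameters of squared Jacobian entries.
# create_graph=True keeps each gradient attached to the autograd graph.
gnm_trace = 0.0
for out in outputs:
    grads = torch.autograd.grad(out, params, create_graph=True)
    gnm_trace = gnm_trace + sum((g ** 2).sum() for g in grads)

# Differentiate the trace itself. The final bias enters the model additively,
# so the trace does not depend on it; allow_unused=True tolerates that.
trace_grads = torch.autograd.grad(gnm_trace, params, allow_unused=True)
print(trace_grads[0].shape)  # gradient w.r.t. the first layer's weight
```

Without create_graph=True the first-order gradients would be detached, and the second call to torch.autograd.grad would fail because the trace would not require grad.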