Hessian Diagonal Module#

The Hessian Diagonal module provides functions to compute the diagonal elements of the Hessian matrix.

Hessian Diagonal#

torch_secorder.core.hessian_diagonal.hessian_diagonal(func: Callable[[], Tensor], params: List[Tensor], v: List[Tensor] | None = None, create_graph: bool = False, strict: bool = False) List[Tensor][source]#

Compute the diagonal elements of the Hessian matrix.

This function computes the diagonal elements of the Hessian matrix by using double backward for each parameter element.

Parameters:
  • func – A callable that returns a scalar tensor (the loss).

  • params – List of parameters with respect to which the Hessian is computed.

  • v – Optional list of vectors to use for computing the diagonal. If None, computes the true diagonal. If provided, computes v_i * H_ii for each i.

  • create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.

  • strict – If True, an error will be raised if any parameter requires grad but has no gradient.

Returns:

List of tensors containing the diagonal elements of the Hessian for each parameter.

Computes the diagonal elements of the Hessian matrix of a scalar output with respect to inputs.

Example:#

import torch
from torch_secorder.core.hessian_diagonal import hessian_diagonal

# Create a simple quadratic function: f(x) = x^T A x
A = torch.tensor([[2.0, 1.0], [1.0, 3.0]])
x = torch.tensor([1.0, 2.0], requires_grad=True)

def quadratic():
    return x @ A @ x

# Compute the diagonal of the Hessian
diag = hessian_diagonal(quadratic, [x])
print(diag[0])  # Should print tensor([4., 6.]) (diagonal of 2A)

Model Hessian Diagonal#

torch_secorder.core.hessian_diagonal.model_hessian_diagonal(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, create_graph: bool = False, strict: bool = False) List[Tensor][source]#

A convenience function to compute the diagonal of the Hessian of a model’s loss with respect to its parameters.

Example:#

import torch
import torch.nn as nn
from torch_secorder.core.hessian_diagonal import model_hessian_diagonal

# Create a simple neural network
model = nn.Linear(10, 1)
x = torch.randn(5, 10)
y = torch.randn(5, 1)

# Compute the diagonal of the Hessian for the MSE loss
diag = model_hessian_diagonal(model, nn.functional.mse_loss, x, y)

# diag[0] contains the diagonal for the weight matrix
# diag[1] contains the diagonal for the bias vector
print(f"Weight diagonal shape: {diag[0].shape}")
print(f"Bias diagonal shape: {diag[1].shape}")

Notes#

  1. The Hessian diagonal computation uses Hessian-vector products with unit vectors to compute the diagonal elements efficiently.

  2. The create_graph parameter allows for computing higher-order derivatives if needed, but this increases memory usage.

  3. The strict parameter controls whether an error should be raised if any parameter requires gradients but has none.