Hessian Trace Module
The Hessian Trace module provides functions to estimate the trace of the Hessian matrix.
Hessian Trace
- torch_secorder.core.hessian_trace.hessian_trace(func: Callable[[], Tensor], params: List[Tensor], num_samples: int = 1, create_graph: bool = False, strict: bool = False) → Tensor
Compute the trace of the Hessian matrix by summing the diagonal elements.
- Parameters:
func – A callable that returns a scalar tensor (the loss).
params – List of parameters with respect to which the Hessian is computed.
num_samples – Ignored (kept for API compatibility).
create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.
strict – If True, an error will be raised if any parameter requires grad but has no gradient.
- Returns:
A tensor containing the trace of the Hessian.
The trace is obtained by summing the exact diagonal elements of the Hessian (via hessian_diagonal). No random sampling is involved, so num_samples has no effect. Hutchinson-style estimation is available through torch_secorder.core.hvp.
Example:
import torch
from torch_secorder.core.hessian_trace import hessian_trace
x = torch.tensor([1.0, 2.0], requires_grad=True)
def quadratic():
    return x.pow(2).sum()
# Compute the trace from the exact Hessian diagonal (num_samples is not needed here)
trace = hessian_trace(quadratic, [x])
print(trace)  # Should be 4.0: the Hessian of x0² + x1² is 2I, whose trace is 4
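A result like the one above can be sanity-checked without autograd by approximating each diagonal entry with a central finite difference, ∂²f/∂θᵢ² ≈ (f(θ + h·eᵢ) − 2f(θ) + f(θ − h·eᵢ)) / h², and summing. The plain-Python sketch below is an independent numerical cross-check, not how torch_secorder computes the diagonal; the helper name fd_hessian_trace and the step size h are illustrative choices.

```python
from typing import Callable, List


def fd_hessian_trace(f: Callable[[List[float]], float], theta: List[float], h: float = 1e-3) -> float:
    """Approximate tr(H) of f at theta by summing central-difference second derivatives."""
    f0 = f(theta)
    total = 0.0
    for i in range(len(theta)):
        plus = list(theta)
        minus = list(theta)
        plus[i] += h
        minus[i] -= h
        # Second derivative along coordinate i: (f(θ + h·e_i) - 2 f(θ) + f(θ - h·e_i)) / h²
        total += (f(plus) - 2.0 * f0 + f(minus)) / (h * h)
    return total


# Same function as the example above: f(x) = x0² + x1², whose Hessian is 2I.
def quadratic(x: List[float]) -> float:
    return sum(xi ** 2 for xi in x)


print(round(fd_hessian_trace(quadratic, [1.0, 2.0]), 6))
```

For a quadratic function the central difference is exact up to rounding; for other functions it carries an O(h²) error, and it needs 2n + 1 function evaluations, so it is only practical for a handful of parameters.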
Model Hessian Trace
- torch_secorder.core.hessian_trace.model_hessian_trace(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, num_samples: int = 1, create_graph: bool = False, strict: bool = False) → Tensor
Compute the trace of the Hessian for a model’s loss function.
A convenience function to estimate the trace of the Hessian of a model’s loss with respect to its parameters.
- Parameters:
model – The PyTorch model.
loss_fn – The loss function, e.g., nn.MSELoss() or F.cross_entropy.
inputs – Input tensor to the model.
targets – Target tensor for the loss function.
num_samples – Number of random vectors for Hutchinson’s method (if used). Ignored for diagonal-based trace.
create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.
strict – If True, an error will be raised if any parameter requires grad but has no gradient.
- Returns:
A tensor containing the estimated trace of the Hessian.
Example:
import torch
import torch.nn as nn
from torch_secorder.core.hessian_trace import model_hessian_trace
# Create a simple neural network
model = nn.Linear(10, 1)
x = torch.randn(5, 10)
y = torch.randn(5, 1)
# Estimate the trace of the Hessian for the MSE loss
trace = model_hessian_trace(model, nn.functional.mse_loss, x, y)
print(f"Estimated Hessian Trace: {trace}")
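For this particular model the expected value can be worked out by hand. With nn.Linear(10, 1) and mse_loss (mean reduction), the loss is L = (1/N) Σᵢ (w·xᵢ + b − yᵢ)², which is quadratic in the parameters, so ∂²L/∂w_j² = (2/N) Σᵢ x_ij² and ∂²L/∂b² = 2. The trace is therefore (2/N) Σᵢ ‖xᵢ‖² + 2, independent of the current weights and of the targets. The plain-Python sketch below computes this reference value; the helper linear_mse_hessian_trace is hypothetical, and the comparison assumes model_hessian_trace differentiates with respect to all of model.parameters() (weight and bias).

```python
import random
from typing import List


def linear_mse_hessian_trace(xs: List[List[float]]) -> float:
    """Closed-form tr(H) of mean-squared-error loss for a one-output linear model with bias.

    The loss (1/N) * sum_i (w . x_i + b - y_i)^2 is quadratic in (w, b), so the
    result depends only on the inputs, not on the weights or the targets.
    """
    n = len(xs)
    # Weight block: sum_j d2L/dw_j^2 = (2/N) * sum_i ||x_i||^2
    weight_part = 2.0 / n * sum(v * v for row in xs for v in row)
    # Bias block: d2L/db^2 = (2/N) * N = 2
    bias_part = 2.0
    return weight_part + bias_part


# Synthetic inputs with the same shape as the example above (5 samples, 10 features).
rng = random.Random(0)
xs = [[rng.gauss(0.0, 1.0) for _ in range(10)] for _ in range(5)]
print(f"Reference Hessian trace: {linear_mse_hessian_trace(xs):.4f}")
```

To compare against the library, pass the same data as x.tolist(); the two values should agree up to floating-point error if the assumption above holds.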
Trace Estimation Methods
The library provides two methods for estimating the trace of the Hessian matrix:
HVP-based Trace (from torch_secorder.core.hvp):
- Uses Hutchinson’s method: averages vᵀHv over random vectors v, each term computed with a Hessian-vector product
- More memory-efficient for large models
- Better suited when computing the full diagonal is expensive
Diagonal-based Trace (from this module):
- Computes the exact diagonal elements of the Hessian (via hessian_diagonal) and sums them
- More accurate but more computationally expensive
- Better suited for smaller models
- Can use custom vectors for more control
Both methods compute the same quantity, the sum of the Hessian’s diagonal entries, but trade accuracy against cost differently:
- Use the HVP-based trace for large models where memory efficiency is crucial. It is a stochastic estimate whose error shrinks as more random vectors are used.
- Use the diagonal-based trace when accuracy matters more than computational cost. It is exact up to floating-point error but requires more computation.
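Hutchinson’s method rests on the identity E[vᵀHv] = tr(H) for random vectors v with E[vvᵀ] = I, such as Rademacher vectors whose entries are ±1 with equal probability. The plain-Python sketch below only needs a function that returns H·v, which in PyTorch would be a Hessian-vector product; here a small explicit matrix stands in for it. The names hutchinson_trace and dense_matvec are illustrative and are not the torch_secorder.core.hvp API.

```python
import random
from typing import Callable, List, Optional

Vector = List[float]


def hutchinson_trace(
    matvec: Callable[[Vector], Vector],
    dim: int,
    num_samples: int = 1,
    rng: Optional[random.Random] = None,
) -> float:
    """Estimate tr(H) as the average of v^T H v over Rademacher vectors v."""
    rng = rng or random.Random()
    total = 0.0
    for _ in range(num_samples):
        # Rademacher vector: each entry is +1 or -1 with probability 1/2, so E[v v^T] = I
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        hv = matvec(v)  # in an autograd setting this would be one Hessian-vector product
        total += sum(vi * hvi for vi, hvi in zip(v, hv))
    return total / num_samples


def dense_matvec(h: List[Vector]) -> Callable[[Vector], Vector]:
    """Wrap an explicit symmetric matrix as a matrix-vector product function."""
    return lambda v: [sum(hij * vj for hij, vj in zip(row, v)) for row in h]


H = [[2.0, 0.5, 0.0],
     [0.5, 3.0, -1.0],
     [0.0, -1.0, 1.0]]  # exact trace: 6.0
rng = random.Random(42)
for m in (1, 10, 1000):
    print(m, hutchinson_trace(dense_matvec(H), dim=3, num_samples=m, rng=rng))
```

If H is diagonal, vᵀHv = Σᵢ Hᵢᵢvᵢ² = tr(H) for every Rademacher v, so a single sample is already exact; the estimator’s error comes entirely from the off-diagonal entries.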
Notes
By default, this module computes the Hessian trace by summing the diagonal elements obtained from hessian_diagonal. For Hutchinson-style trace estimation, use the hessian_trace function from torch_secorder.core.hvp.
With Hutchinson’s method, the estimate becomes more accurate as the number of samples increases, at the cost of increased computation time. This applies only to the HVP-based estimator; the diagonal-based functions in this module ignore num_samples.
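How much accuracy each extra sample buys can be quantified. For m Rademacher probe vectors v₁, …, vₘ and a symmetric Hessian H, the Hutchinson estimate is unbiased and its variance depends only on the off-diagonal entries, so the standard error shrinks as 1/√m. This assumes Rademacher probes; with standard Gaussian probes the variance is instead 2‖H‖_F² / m.

```latex
\hat{t}_m = \frac{1}{m}\sum_{k=1}^{m} v_k^\top H v_k, \qquad
\mathbb{E}\big[\hat{t}_m\big] = \operatorname{tr}(H), \qquad
\operatorname{Var}\big[\hat{t}_m\big] = \frac{2}{m}\sum_{i \neq j} H_{ij}^{2}
```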
The create_graph parameter allows for computing higher-order derivatives if needed, but this increases memory usage.
The strict parameter controls whether an error should be raised if any parameter requires gradients but has none.