Source code for torch_secorder.core.hessian_trace

"""Module for estimating the trace of the Hessian matrix.

This module provides functions to compute the trace of the Hessian matrix
by summing the diagonal elements returned by ``hessian_diagonal``
(diagonal-based). The ``num_samples`` arguments are kept for compatibility
with Hutchinson-style (HVP-based) estimators and are currently ignored.
"""

from typing import Callable, List

from torch import Tensor
from torch.nn import Module

# Import hessian_diagonal from its new location to use it for trace estimation
from .hessian_diagonal import hessian_diagonal


def hessian_trace(
    func: Callable[[], Tensor],
    params: List[Tensor],
    num_samples: int = 1,
    create_graph: bool = False,
    strict: bool = False,
) -> Tensor:
    """Compute the trace of the Hessian matrix by summing the diagonal elements.

    The per-parameter diagonals are obtained from ``hessian_diagonal`` and
    reduced to a single scalar by summing every entry.

    Args:
        func: A callable that returns a scalar tensor (the loss).
        params: List of parameters with respect to which the Hessian is
            computed.
        num_samples: Ignored (kept for API compatibility).
        create_graph: If True, the computational graph will be constructed,
            allowing for higher-order derivatives. Forwarded to
            ``hessian_diagonal``.
        strict: If True, an error will be raised if any parameter requires
            grad but has no gradient. Forwarded to ``hessian_diagonal``.

    Returns:
        A tensor containing the trace of the Hessian. If ``hessian_diagonal``
        yields no entries, a zero-valued scalar tensor is returned so the
        result is always a ``Tensor``.
    """
    diag = hessian_diagonal(func, params, create_graph=create_graph, strict=strict)

    # Accumulate tensor-to-tensor instead of seeding with the Python int 0
    # (as the builtin ``sum`` does), so the result never degrades to an int.
    trace = None
    for entry in diag:
        entry_total = entry.sum()
        trace = entry_total if trace is None else trace + entry_total

    if trace is None:
        # Nothing to reduce: honour the ``-> Tensor`` contract.
        import torch

        return torch.zeros(())
    return trace


def model_hessian_trace(
    model: Module,
    loss_fn: Callable[[Tensor, Tensor], Tensor],
    inputs: Tensor,
    targets: Tensor,
    num_samples: int = 1,
    create_graph: bool = False,
    strict: bool = False,
) -> Tensor:
    """Compute the trace of the Hessian for a model's loss function.

    A convenience function that wraps the model's forward pass and loss
    evaluation in a closure and delegates to ``hessian_trace`` with all of
    the model's parameters.

    Args:
        model: The PyTorch model.
        loss_fn: The loss function, e.g., ``nn.MSELoss()`` or
            ``F.cross_entropy``. It is called as ``loss_fn(outputs, targets)``.
        inputs: Input tensor to the model.
        targets: Target tensor for the loss function.
        num_samples: Number of random vectors for Hutchinson's method (if
            used). Forwarded to ``hessian_trace``, which documents it as
            ignored for the diagonal-based trace.
        create_graph: If True, the computational graph will be constructed,
            allowing for higher-order derivatives.
        strict: If True, an error will be raised if any parameter requires
            grad but has no gradient.

    Returns:
        A tensor containing the estimated trace of the Hessian.
    """

    def loss_closure() -> Tensor:
        # Re-run the forward pass on every call so the loss is rebuilt
        # from the current parameters.
        return loss_fn(model(inputs), targets)

    return hessian_trace(
        loss_closure,
        list(model.parameters()),
        # Pass the caller's value through rather than silently dropping it.
        num_samples=num_samples,
        create_graph=create_graph,
        strict=strict,
    )