Source code for torch_secorder.core.hessian_diagonal
"""Module for computing Hessian diagonal efficiently.
This module provides functions to compute the diagonal elements of the Hessian matrix
which are useful for various second-order analysis tasks.
"""
from typing import Callable, List, Optional
import torch
from torch import Tensor
from torch.nn import Module
[docs]
def hessian_diagonal(
func: Callable[[], Tensor],
params: List[Tensor],
v: Optional[List[Tensor]] = None,
create_graph: bool = False,
strict: bool = False,
) -> List[Tensor]:
"""Compute the diagonal elements of the Hessian matrix.
This function computes the diagonal elements of the Hessian matrix by using
double backward for each parameter element.
Args:
func: A callable that returns a scalar tensor (the loss).
params: List of parameters with respect to which the Hessian is computed.
v: Optional list of vectors to use for computing the diagonal. If None,
computes the true diagonal. If provided, computes v_i * H_ii for each i.
create_graph: If True, the computational graph will be constructed,
allowing for higher-order derivatives.
strict: If True, an error will be raised if any parameter requires grad
but has no gradient.
Returns:
List of tensors containing the diagonal elements of the Hessian for each parameter.
"""
# Filter out parameters that don't require grad if not in strict mode
if not strict:
grad_params = [p for p in params if p.requires_grad]
if not grad_params:
# If no parameters require grad, return zeros. Ensure they are traceable if create_graph is True.
if create_graph:
return [(p.sum() * 0.0).expand_as(p) for p in params]
else:
return [torch.zeros_like(p) for p in params]
else:
grad_params = params
if not all(p.requires_grad for p in params):
raise RuntimeError(
"One of the differentiated Tensors does not require grad"
)
diag_results = []
# Compute first-order gradients. create_graph=True ensures these gradients themselves have grad_fn.
first_grads = torch.autograd.grad(
func(), grad_params, create_graph=True, allow_unused=True
)
# Create a mapping from parameter object to its first gradient for efficient lookup
param_to_grad_map = {p: g for p, g in zip(grad_params, first_grads)}
for p_orig_idx, p in enumerate(params):
if p.requires_grad:
g = param_to_grad_map.get(p)
if g is None:
# This parameter `p` did not receive a gradient (e.g., due to allow_unused=True).
# Its Hessian diagonal elements will be zero. Ensure they are traceable if create_graph is True.
if create_graph:
diag_results.append((p.sum() * 0.0).expand_as(p))
else:
diag_results.append(torch.zeros_like(p))
continue
# Flatten parameter and gradient for element-wise processing
g_flat = g.flatten()
per_param_diagonal_elements = []
for i in range(
g_flat.numel()
): # Iterate over flattened elements of the gradient
# Only proceed if g_flat[i] is part of the graph and can be differentiated further
if g_flat[i].requires_grad and g_flat[i].grad_fn is not None:
# Compute the gradient of the i-th element of g_flat with respect to the entire parameter p.
grad2 = torch.autograd.grad(
g_flat[i],
p,
retain_graph=True,
create_graph=create_graph,
allow_unused=True,
)[0]
if grad2 is not None:
diag_elem = grad2.flatten()[i]
else:
# g_flat[i] required grad and had grad_fn, but still no dependency on p.
# This should still yield a traceable zero.
if create_graph:
diag_elem = (p.sum() * 0.0).expand_as(g_flat[i])
else:
diag_elem = torch.zeros_like(g_flat[i])
else:
# If g_flat[i] does not require grad, or has no grad_fn (i.e., it's a leaf not connected for second derivatives),
# its second derivative w.r.t. p is zero.
if create_graph:
diag_elem = (p.sum() * 0.0).expand_as(g_flat[i])
else:
diag_elem = torch.zeros_like(g_flat[i])
if v is not None:
v_param_elem = v[p_orig_idx].flatten()[i]
per_param_diagonal_elements.append(diag_elem * v_param_elem)
else:
per_param_diagonal_elements.append(diag_elem)
if per_param_diagonal_elements:
# Stack elements and reshape to original parameter shape
# Ensure the stacked tensor also retains grad_fn if create_graph is True
stacked_elements = torch.stack(per_param_diagonal_elements)
diag_results.append(stacked_elements.reshape(p.shape))
else:
if create_graph:
diag_results.append((p.sum() * 0.0).expand_as(p))
else:
diag_results.append(torch.zeros_like(p))
else:
diag_results.append(torch.zeros_like(p))
return diag_results
[docs]
def model_hessian_diagonal(
model: Module,
loss_fn: Callable[[Tensor, Tensor], Tensor],
inputs: Tensor,
targets: Tensor,
create_graph: bool = False,
strict: bool = False,
) -> List[Tensor]:
def loss_func():
outputs = model(inputs)
return loss_fn(outputs, targets)
return hessian_diagonal(
loss_func,
list(model.parameters()),
create_graph=create_graph,
strict=strict,
)