Hessian-Vector Product (HVP) Module
The HVP module provides efficient implementations of Hessian-vector products and related utilities.
Exact HVP
- torch_secorder.core.hvp.exact_hvp(func: Callable[[], Tensor], params: List[Tensor], v: Tensor | List[Tensor], create_graph: bool = False) → Tensor | List[Tensor]
Compute the exact Hessian-vector product Hv using double backpropagation.
- Parameters:
func – A callable that returns a scalar tensor (loss).
params – List of parameters with respect to which to compute the Hessian.
v – Vector to multiply with the Hessian. Can be a single tensor or a list of tensors matching the structure of params.
create_graph – If True, the graph of the derivative is constructed, which allows higher-order derivative products to be computed.
- Returns:
The Hessian-vector product Hv. Returns a single tensor if v is a single tensor, otherwise returns a list of tensors matching the structure of params.
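The exact HVP computes Hv by double backpropagation: it differentiates the gradient a second time along v, so the Hessian is never formed and the result is exact up to floating-point rounding. A call costs a small constant multiple of one gradient evaluation, but the graph of the first gradient must be kept alive for the second pass, which can make it memory-intensive for large models.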
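For intuition, the double-backpropagation trick can be written directly with `torch.autograd.grad`. This is a from-scratch sketch, not the library's implementation, and the helper name `double_backprop_hvp` is hypothetical. It differentiates the scalar ⟨∇f, v⟩ once more, which yields Hv without ever building H:

```python
import torch

def double_backprop_hvp(loss_fn, params, v):
    """Hypothetical helper: Hessian-vector product via double backpropagation."""
    loss = loss_fn()
    # First backward pass; create_graph=True keeps the graph so the gradient
    # can itself be differentiated (this retained graph is the memory cost).
    grads = torch.autograd.grad(loss, params, create_graph=True)
    # The gradient of the scalar <grad, v> with respect to the parameters is H v,
    # because v is a constant that does not depend on the parameters.
    grad_dot_v = sum((g * vi).sum() for g, vi in zip(grads, v))
    # Second backward pass.
    return list(torch.autograd.grad(grad_dot_v, params))


# f(w) = sum(w**3) has Hessian diag(6 * w), so H v = 6 * w * v elementwise.
w = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
v = torch.tensor([1.0, 1.0, 2.0])
hv = double_backprop_hvp(lambda: (w ** 3).sum(), [w], [v])
# hv[0] holds the values [6., -12., 6.]
```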
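Example:
```python
import torch
import torch.nn as nn
from torch_secorder.core.hvp import exact_hvp

model = nn.Sequential(nn.Linear(10, 5), nn.Tanh(), nn.Linear(5, 1))  # nonlinear, so H != 0
x = torch.randn(4, 10)

def loss_func():
    return model(x).sum()

v = [torch.randn_like(p) for p in model.parameters()]
hvp_result = exact_hvp(loss_func, list(model.parameters()), v)
```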
Approximate HVP
- torch_secorder.core.hvp.approximate_hvp(func: Callable[[], Tensor], params: List[Tensor], v: Tensor | List[Tensor], num_samples: int = 1, damping: float = 0.0) → Tensor | List[Tensor]
Compute an approximate Hessian-vector product using finite differences.
This method uses a finite-difference approximation of the Hessian-vector product. Because it needs only first-order gradients, it can be more memory-efficient than the exact computation.
- Parameters:
func – A callable that returns a scalar tensor (loss).
params – List of parameters with respect to which to compute the Hessian.
v – Vector to multiply with the Hessian. Can be a single tensor or a list of tensors matching the structure of params.
num_samples – Number of samples to use for the approximation.
damping – Damping coefficient λ, added to the diagonal of the Hessian (λI), so the damped product (H + λI)v = Hv + λv is computed.
- Returns:
The approximate Hessian-vector product, (H + λI)v with λ = damping (plain Hv when damping is 0). Returns a single tensor if v is a single tensor, otherwise a list of tensors matching the structure of params.
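Finite differences trade accuracy for memory: only ordinary gradients are evaluated, so no gradient graph has to be kept alive, but the estimate carries truncation error that depends on the step size, plus rounding error from subtracting two nearly equal gradients. It is therefore less accurate than exact_hvp.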
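The accuracy gap can be made precise with a Taylor expansion of the gradient along v. This is a general derivation: the step size ε is notation introduced here, and the text above does not say whether approximate_hvp uses a forward or a central difference.

$$
\nabla f(\theta \pm \epsilon v) = \nabla f(\theta) \pm \epsilon\, H v + \frac{\epsilon^{2}}{2}\, \nabla^{3} f(\theta)[v, v] + O(\epsilon^{3})
$$

$$
\frac{\nabla f(\theta + \epsilon v) - \nabla f(\theta)}{\epsilon} = H v + O(\epsilon),
\qquad
\frac{\nabla f(\theta + \epsilon v) - \nabla f(\theta - \epsilon v)}{2\epsilon} = H v + O(\epsilon^{2})
$$

Shrinking ε reduces the truncation error, but the subtraction of two nearly equal gradients then loses floating-point precision, so the total error cannot be driven to zero.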
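Example:
```python
import torch
import torch.nn as nn
from torch_secorder.core.hvp import approximate_hvp

model = nn.Sequential(nn.Linear(10, 5), nn.Tanh(), nn.Linear(5, 1))
x = torch.randn(4, 10)

def loss_func():
    return model(x).sum()

v = [torch.randn_like(p) for p in model.parameters()]
hvp_result = approximate_hvp(
    loss_func,
    list(model.parameters()),
    v,
    num_samples=10,
    damping=0.1,  # adds 0.1 * I to the Hessian
)
```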
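To see what a finite-difference HVP involves, here is a from-scratch sketch that evaluates the gradient at θ + εv and θ − εv and takes their difference. It assumes a central difference with a fixed step `eps` and applies damping as +λv per the parameter description; it is not the library's implementation, `finite_difference_hvp` is a hypothetical name, and `num_samples` is not modeled.

```python
import torch

def finite_difference_hvp(loss_fn, params, v, eps=1e-3, damping=0.0):
    """Hypothetical helper: central-difference estimate of (H + damping * I) v.

    Only first-order gradients are computed, so no gradient graph is kept.
    `eps` is an assumed step size, not a documented parameter of approximate_hvp.
    """
    def grad_at(step):
        with torch.no_grad():
            for p, vi in zip(params, v):
                p.add_(step * vi)        # move the parameters to theta + step * v
        grads = torch.autograd.grad(loss_fn(), params)
        with torch.no_grad():
            for p, vi in zip(params, v):
                p.sub_(step * vi)        # move back to theta (up to rounding)
        return grads

    g_plus = grad_at(eps)
    g_minus = grad_at(-eps)
    return [(gp - gm) / (2 * eps) + damping * vi
            for gp, gm, vi in zip(g_plus, g_minus, v)]


# Quadratic f(w) = 0.5 * w^T A w has Hessian A, so H v = A v exactly.
A = torch.tensor([[2.0, 1.0], [1.0, 3.0]], dtype=torch.float64)
w = torch.tensor([0.5, -1.0], dtype=torch.float64, requires_grad=True)
v = torch.tensor([1.0, 2.0], dtype=torch.float64)
approx = finite_difference_hvp(lambda: 0.5 * w @ A @ w, [w], [v], damping=0.1)
# A v = [4., 7.]; adding 0.1 * v gives approximately [4.1, 7.2]
```

For a quadratic the central difference is exact up to rounding, which makes it a convenient sanity check; on a real network the result would differ from the exact HVP by the O(ε²) term.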
Model HVP
- torch_secorder.core.hvp.model_hvp(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], x: Tensor, y: Tensor, v: Tensor | List[Tensor], create_graph: bool = False) → Tensor | List[Tensor]
Compute the Hessian-vector product for a model’s loss function.
This is a convenience wrapper around exact_hvp that handles the model’s forward pass and loss computation.
- Parameters:
model – The PyTorch model.
loss_fn – The loss function that takes model output and target as arguments.
x – Input tensor.
y – Target tensor.
v – Vector to multiply with the Hessian.
create_graph – If True, the graph used to compute the grad will be constructed.
- Returns:
The Hessian-vector product Hv.
Because it delegates to exact_hvp, it shares that function's accuracy (exact up to floating-point rounding) and its memory cost.
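The wrapper's job, per the description above, is to run the forward pass and loss and then compute an exact HVP. The same computation can be reproduced with plain `torch.autograd` (a sketch, not the library's code) and checked against a closed form: for `nn.Linear` with `nn.MSELoss` (mean reduction) over N samples, the Hessian with respect to (weight, bias) is (2/N) ZᵀZ, where row i of Z is [x_i, 1].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1).double()
x = torch.randn(5, 3, dtype=torch.float64)
y = torch.randn(5, 1, dtype=torch.float64)
params = list(model.parameters())            # [weight of shape (1, 3), bias of shape (1,)]
v = [torch.randn_like(p) for p in params]

# What a model-level wrapper does conceptually: forward pass + loss, then double backprop
loss = nn.MSELoss()(model(x), y)
grads = torch.autograd.grad(loss, params, create_graph=True)
hv = torch.autograd.grad(sum((g * vi).sum() for g, vi in zip(grads, v)), params)

# Closed form for a linear model under mean-squared error over N samples:
# H = (2 / N) * Z^T Z, where row i of Z is [x_i, 1] (weight entries first, then bias)
N = x.shape[0]
Z = torch.cat([x, torch.ones(N, 1, dtype=torch.float64)], dim=1)
H = (2.0 / N) * Z.T @ Z
v_flat = torch.cat([vi.reshape(-1) for vi in v])
hv_flat = torch.cat([h.reshape(-1) for h in hv])
# hv_flat and H @ v_flat agree up to rounding
```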
Example:
```python
import torch
import torch.nn as nn
from torch_secorder.core.hvp import model_hvp

model = nn.Linear(10, 1)
x = torch.randn(1, 10)
y = torch.randn(1, 1)
v = [torch.randn_like(p) for p in model.parameters()]
hvp_result = model_hvp(
    model,
    nn.MSELoss(),
    x,
    y,
    v
)
```
Gauss-Newton Product
- torch_secorder.core.hvp.gauss_newton_product(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], x: Tensor, y: Tensor, v: Tensor | List[Tensor], create_graph: bool = False) → Tensor | List[Tensor]
Compute the Gauss-Newton matrix-vector product for a model’s loss.
- Parameters:
model – The PyTorch model.
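loss_fn – The loss function. The Gauss-Newton matrix is guaranteed to be positive semi-definite only if the loss is convex in the model output (e.g. MSE, or cross-entropy on logits).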
x – Input tensor.
y – Target tensor.
v – Vector to multiply with the Gauss-Newton matrix.
create_graph – If True, graph of the derivative will be constructed.
- Returns:
The Gauss-Newton matrix-vector product (same structure as v).
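Computes the Gauss-Newton matrix-vector product. The Gauss-Newton matrix approximates the Hessian by dropping the terms that involve second derivatives of the model output with respect to the parameters; for a loss that is convex in the model output, the result is positive semi-definite.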
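The positive semi-definiteness follows from a standard decomposition. Write the loss as L(θ) = ℓ(f(θ), y), where f ∈ ℝ^K is the model output, J = ∂f/∂θ its Jacobian, and H_ℓ = ∂²ℓ/∂f² the Hessian of the loss with respect to the output (notation introduced here, not part of the API):

$$
\nabla^{2}_{\theta} L = \underbrace{J^{\top} H_{\ell} J}_{G \;(\text{Gauss-Newton})} + \sum_{k=1}^{K} \frac{\partial \ell}{\partial f_{k}}\, \nabla^{2}_{\theta} f_{k},
\qquad
v^{\top} G v = (J v)^{\top} H_{\ell} (J v) \ge 0 \ \text{ if } H_{\ell} \succeq 0 .
$$

The product Gv can be formed as Jᵀ(H_ℓ(Jv)), using one Jacobian-vector product, one small K×K product and one vector-Jacobian product, without building G. The dropped sum vanishes when the model is linear in θ or when every ∂ℓ/∂f_k is zero.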
Example:
```python
import torch
import torch.nn as nn
from torch_secorder.core.hvp import gauss_newton_product

model = nn.Linear(10, 1)
x = torch.randn(1, 10)
y = torch.randn(1, 1)
v = [torch.randn_like(p) for p in model.parameters()]
gn_result = gauss_newton_product(
    model,
    nn.MSELoss(),
    x,
    y,
    v
)
```
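On a tiny problem the Gauss-Newton matrix can be built explicitly to check its properties. This sketch uses `torch.autograd.functional.jacobian` and `hessian` on a hypothetical four-parameter model with flat parameters, rather than calling gauss_newton_product itself. It shows that G = JᵀH_ℓJ is symmetric and positive semi-definite even for a tanh model, and that for a model linear in its parameters G coincides with the true Hessian.

```python
import torch
from torch.autograd.functional import hessian, jacobian

torch.manual_seed(0)
X = torch.randn(4, 3, dtype=torch.float64)   # 4 samples, 3 features
Y = torch.randn(4, 1, dtype=torch.float64)

def tanh_model(theta):
    # Hypothetical model with flat parameters theta = [w_1, w_2, w_3, b]
    return torch.tanh(X @ theta[:3].reshape(3, 1) + theta[3])

def linear_model(theta):
    return X @ theta[:3].reshape(3, 1) + theta[3]

def mse(out):
    # Loss as a function of the model output only
    return ((out - Y) ** 2).mean()

def gauss_newton_matrix(model_fn, theta):
    # Explicit G = J^T H_l J; only sensible for tiny models
    J = jacobian(model_fn, theta).reshape(-1, theta.numel())   # (outputs, params)
    out = model_fn(theta).detach()
    H_l = hessian(mse, out).reshape(out.numel(), out.numel())  # (outputs, outputs)
    return J.T @ H_l @ J

theta = torch.randn(4, dtype=torch.float64)
v = torch.randn(4, dtype=torch.float64)

G = gauss_newton_matrix(tanh_model, theta)
gn_v = G @ v                               # the Gauss-Newton product G v
min_eig = torch.linalg.eigvalsh(G).min()   # non-negative up to rounding (G is PSD)

# For a model linear in theta the dropped term vanishes, so G equals the Hessian
G_lin = gauss_newton_matrix(linear_model, theta)
H_lin = hessian(lambda t: mse(linear_model(t)), theta)
```

Building G explicitly costs memory quadratic in the number of parameters, which is what a matrix-vector product interface avoids; the explicit form is only useful for checking small cases.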
Trace Estimation Methods
The library provides two methods for estimating the trace of the Hessian matrix:
HVP-based Trace (hessian_trace in this module):
- Uses Hutchinson's method with Hessian-vector products, averaging over random vectors to estimate the trace.
- Is more memory-efficient for large models.
- Is better suited when computing the full diagonal is expensive.
Diagonal-based Trace (from torch_secorder.core.hessian_trace):
- Computes the exact diagonal elements of the Hessian.
- Is more accurate but more computationally expensive.
- Is better suited for smaller models.
- Can use custom vectors for more control.
Both methods compute the same quantity, tr(H), but take different approaches. The choice depends on your needs:
- Use the HVP-based trace for large models where memory efficiency is crucial; it is a stochastic estimate whose accuracy improves as num_samples grows.
- Use the diagonal-based trace when accuracy matters more than computational cost.
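The contrast between the two routes can be sketched with plain autograd on a single flat parameter vector. This is not the implementation of either hvp.hessian_trace or torch_secorder.core.hessian_trace, and the helpers `hvp`, `exact_trace` and `hutchinson_trace` are hypothetical. Here the "diagonal" route is modeled as one HVP per unit vector, which is one way to get exact diagonal entries; the library may obtain them differently.

```python
import torch

def hvp(loss_fn, w, v):
    # Double-backprop Hessian-vector product for one flat parameter tensor w
    (g,) = torch.autograd.grad(loss_fn(w), w, create_graph=True)
    (hv,) = torch.autograd.grad(g @ v, w)
    return hv

def exact_trace(loss_fn, w):
    # Diagonal route: H_ii = e_i^T H e_i, one HVP per coordinate (n HVPs in total)
    eye = torch.eye(w.numel(), dtype=w.dtype)
    return sum(float(hvp(loss_fn, w, e)[i]) for i, e in enumerate(eye))

def hutchinson_trace(loss_fn, w, num_samples=10, generator=None):
    # Hutchinson route: average z^T H z over Rademacher vectors z (entries +1 or -1)
    total = 0.0
    for _ in range(num_samples):
        z = torch.randint(0, 2, w.shape, generator=generator).to(w.dtype) * 2 - 1
        total += float(z @ hvp(loss_fn, w, z))
    return total / num_samples


# Quadratic loss with Hessian A, so tr(H) = 4 + 3 + 2 = 9
A = torch.tensor([[4.0, 1.0, 0.0], [1.0, 3.0, 1.0], [0.0, 1.0, 2.0]], dtype=torch.float64)
w = torch.randn(3, dtype=torch.float64, requires_grad=True)

def quad(t):
    return 0.5 * t @ A @ t

exact = exact_trace(quad, w)
estimate = hutchinson_trace(quad, w, num_samples=200,
                            generator=torch.Generator().manual_seed(0))
```

The exact route needs as many HVPs as there are parameters, while Hutchinson's cost is set by num_samples alone; the price is sampling noise from the off-diagonal entries of H.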
Hessian Trace
- torch_secorder.core.hvp.hessian_trace(func: Callable[[], Tensor], params: List[Tensor], num_samples: int = 10, create_graph: bool = False, sparse: bool = False) → float
Estimate the trace of the Hessian using Hutchinson’s method (random projections).
- Parameters:
func – A callable that returns a scalar tensor (loss).
params – List of parameters with respect to which to compute the Hessian.
num_samples – Number of random projections to use.
create_graph – If True, graph of the derivative will be constructed.
sparse – If True, use sparse random vectors for projections.
- Returns:
Estimated trace of the Hessian.
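Estimates the trace of the Hessian matrix with Hutchinson's method. Each random projection costs one Hessian-vector product, so the cost grows linearly with num_samples and the Hessian is never formed explicitly.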
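Random projections work because of a short expectation argument. It holds for any random vector z with E[zzᵀ] = I, such as a Rademacher vector with independent ±1 entries. That distribution is a common choice, but the text above does not specify which one hessian_trace uses or how the sparse option changes it. With m = num_samples:

$$
\mathbb{E}\!\left[z^{\top} H z\right] = \sum_{i,j} H_{ij}\, \mathbb{E}[z_i z_j] = \sum_{i} H_{ii} = \operatorname{tr}(H),
\qquad
\widehat{\operatorname{tr}}(H) = \frac{1}{m} \sum_{s=1}^{m} z_s^{\top} H z_s
$$

$$
\operatorname{Var}\!\left[z^{\top} H z\right] = 2 \sum_{i \neq j} H_{ij}^{2} \quad \text{(Rademacher } z\text{, symmetric } H\text{)}
$$

The estimator is unbiased and its standard error falls as 1/√m. For Rademacher probes it is exact whenever H is diagonal, since the variance term then vanishes.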
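Example:
```python
import torch
import torch.nn as nn
from torch_secorder.core.hvp import hessian_trace

model = nn.Sequential(nn.Linear(10, 5), nn.Tanh(), nn.Linear(5, 1))
x = torch.randn(4, 10)

def loss_func():
    return model(x).sum()

trace = hessian_trace(
    loss_func,
    list(model.parameters()),
    num_samples=10,
    sparse=True,
)
```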