Per-Layer Curvature Module
The per-layer curvature module provides utilities for computing block-diagonal approximations of the Hessian and Fisher Information matrices, where each block corresponds to a layer’s parameters. This is useful for understanding the curvature of individual layers and for implementing layer-wise optimization strategies.
Per-Layer Hessian Diagonal
- torch_secorder.core.per_layer_curvature.per_layer_hessian_diagonal(model: Module, loss_fn: Module, inputs: Tensor, targets: Tensor, layer_types: List[type] | None = None, create_graph: bool = False) → Dict[str, Tensor]
Compute the diagonal of the Hessian matrix for each layer separately.
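This function computes the diagonal elements of the Hessian matrix for each layer in the model independently, differentiating only with respect to that layer's parameters and treating other layers' parameters as fixed. For the exact Hessian, each diagonal entry ∂²L/∂θ_i² involves a single parameter, so holding the other layers fixed does not change it: the per-layer results together form the diagonal of the full Hessian, grouped by layer, i.e. the diagonal of each block in a block-diagonal view of the Hessian.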
- Parameters:
model – The neural network model.
loss_fn – The loss function used for training.
inputs – Input tensor for the model.
targets – Target tensor for the loss function.
layer_types – List of layer types to include. If None, defaults to nn.Linear and nn.Conv2d layers (see Notes).
create_graph – Whether to create the computational graph for higher-order derivatives.
- Returns:
Dictionary mapping layer names to their Hessian diagonal tensors.
Example:
import torch
import torch.nn as nn
from torch_secorder.core.per_layer_curvature import per_layer_hessian_diagonal

# Define a simple model with multiple layers
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 20)
        self.linear2 = nn.Linear(20, 5)
        self.linear3 = nn.Linear(5, 1)

    def forward(self, x):
        return self.linear3(self.linear2(self.linear1(x)))

model = SimpleModel()
loss_fn = nn.MSELoss()
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)

# Compute per-layer Hessian diagonals
layer_hessians = per_layer_hessian_diagonal(model, loss_fn, inputs, targets)

# Print shapes of Hessian diagonals for each layer
for layer_name, hessian in layer_hessians.items():
    print(f"{layer_name} Hessian diagonal shape: {hessian.shape}")
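For small models, the per-layer computation described above can be sketched directly with torch.autograd. This is a minimal sketch, not the library's implementation: the helper name `hessian_diagonal_by_layer`, its default layer types, and the choice to key results by parameter name (for example `linear1.weight`) rather than by layer are assumptions made for illustration. Each diagonal entry is obtained by differentiating the gradient a second time with respect to that layer's parameter only, which costs one extra backward pass per parameter element.

```python
import torch
import torch.nn as nn


def hessian_diagonal_by_layer(model, loss_fn, inputs, targets,
                              layer_types=(nn.Linear, nn.Conv2d)):
    """Hypothetical helper: exact Hessian diagonal, keyed by parameter name.

    Only practical for small models: it runs one extra backward pass
    per parameter element.
    """
    loss = loss_fn(model(inputs), targets)
    result = {}
    for layer_name, module in model.named_modules():
        if not isinstance(module, tuple(layer_types)):
            continue
        for param_name, param in module.named_parameters(recurse=False):
            key = f"{layer_name}.{param_name}"
            # First derivative w.r.t. this layer's parameter only; the other
            # layers' parameters are not differentiated (held fixed).
            (grad,) = torch.autograd.grad(loss, param, create_graph=True,
                                          allow_unused=True)
            if grad is None or not grad.requires_grad:
                # Gradient is constant in the parameters: zero curvature.
                result[key] = torch.zeros_like(param)
                continue
            flat_grad = grad.reshape(-1)
            diag = torch.zeros(flat_grad.numel(), dtype=param.dtype,
                               device=param.device)
            for i in range(flat_grad.numel()):
                # Differentiate dL/dtheta_i again w.r.t. the same tensor;
                # element i of the result is d^2 L / dtheta_i^2.
                (row,) = torch.autograd.grad(flat_grad[i], param,
                                             retain_graph=True,
                                             allow_unused=True)
                if row is not None:
                    diag[i] = row.reshape(-1)[i]
            result[key] = diag.reshape(param.shape)
    return result
```

A convenient sanity check: for a single-output linear layer trained with nn.MSELoss, the weight entries equal 2 times the mean of the squared inputs for that feature, and the bias entry equals 2.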
Per-Layer Fisher Diagonal
- torch_secorder.core.per_layer_curvature.per_layer_fisher_diagonal(model: Module, loss_fn: Module, inputs: Tensor, targets: Tensor, layer_types: List[type] | None = None, create_graph: bool = False) → Dict[str, Tensor]
Compute the diagonal of the Fisher Information matrix for each layer separately.
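This function computes the diagonal elements of the Fisher Information matrix for each layer in the model independently, treating other layers' parameters as fixed. Each diagonal entry F_ii depends only on the gradient with respect to θ_i, so holding the other layers fixed does not change it: the per-layer results together form the diagonal of the full Fisher matrix, grouped by layer, i.e. the diagonal of each block in a block-diagonal view of the Fisher matrix.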
- Parameters:
model – The neural network model.
loss_fn – The loss function used for training.
inputs – Input tensor for the model.
targets – Target tensor for the loss function.
layer_types – List of layer types to include. If None, defaults to nn.Linear and nn.Conv2d layers (see Notes).
create_graph – Whether to create the computational graph for higher-order derivatives.
- Returns:
Dictionary mapping layer names to their Fisher diagonal tensors.
Example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_secorder.core.per_layer_curvature import per_layer_fisher_diagonal

# Define a simple model with multiple layers
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(10, 5)
        self.linear2 = nn.Linear(5, 3)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return self.linear2(x)

# Generate synthetic data
torch.manual_seed(42)
X = torch.randn(8, 10)
y = torch.randint(0, 3, (8,))  # 3 classes

# Instantiate the model and a loss module (loss_fn is typed as nn.Module)
model = SimpleModel()
loss_fn = nn.CrossEntropyLoss()

# Compute the per-layer Fisher diagonals. Each entry is the diagonal of one
# layer's block of the Fisher Information Matrix; K-FAC works with the same
# layer blocks but approximates the full (non-diagonal) block instead.
layer_fisher_diagonals = per_layer_fisher_diagonal(model, loss_fn, X, y)

print("Per-layer Fisher diagonal:")
for layer_name, diag in layer_fisher_diagonals.items():
    print(f"Layer: {layer_name}, Shape: {diag.shape}, First few values: {diag.flatten()[:5].tolist()}")

print("\nK-FAC approximates each full layer block (not only its diagonal) as a Kronecker product of two smaller matrices.")
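A per-layer Fisher diagonal can be sketched in the same style. This sketch assumes the empirical Fisher, F_ii = (1/N) Σ_n (∂ℓ_n/∂θ_i)², computed from the observed targets with one backward pass per sample; whether per_layer_fisher_diagonal uses this estimator or samples targets from the model's predictive distribution is not stated here, and the helper name `empirical_fisher_diagonal_by_layer` is hypothetical. Only first derivatives are needed, so the cost scales with the number of samples rather than the number of parameters.

```python
import torch
import torch.nn as nn


def empirical_fisher_diagonal_by_layer(model, loss_fn, inputs, targets,
                                       layer_types=(nn.Linear, nn.Conv2d)):
    """Hypothetical helper: empirical Fisher diagonal, keyed by parameter name."""
    params = {
        f"{layer_name}.{param_name}": param
        for layer_name, module in model.named_modules()
        if isinstance(module, tuple(layer_types))
        for param_name, param in module.named_parameters(recurse=False)
    }
    fisher = {key: torch.zeros_like(p) for key, p in params.items()}
    num_samples = inputs.shape[0]
    for n in range(num_samples):
        # Loss of a single sample (slicing keeps the batch dimension).
        loss_n = loss_fn(model(inputs[n:n + 1]), targets[n:n + 1])
        grads = torch.autograd.grad(loss_n, list(params.values()),
                                    allow_unused=True)
        for key, g in zip(params, grads):
            if g is not None:
                # Accumulate squared per-sample gradients element-wise.
                fisher[key] += g.detach() ** 2
    return {key: f / num_samples for key, f in fisher.items()}
```

It accepts the same arguments as the example above, e.g. `empirical_fisher_diagonal_by_layer(model, nn.CrossEntropyLoss(), X, y)`.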
Layer Curvature Statistics
- torch_secorder.core.per_layer_curvature.get_layer_curvature_stats(layer_curvatures: Dict[str, Tensor]) → Dict[str, Dict[str, float]]
Compute basic statistics for each layer’s curvature information.
- Parameters:
layer_curvatures – Dictionary mapping layer names to their curvature tensors (Hessian or Fisher diagonals).
- Returns:
Dictionary mapping layer names to their curvature statistics (mean, std, max, min).
Example:
from torch_secorder.core.per_layer_curvature import get_layer_curvature_stats

# Reuses layer_hessians from the per_layer_hessian_diagonal example
stats = get_layer_curvature_stats(layer_hessians)

# Print statistics for each layer
for layer_name, layer_stats in stats.items():
    print(f"\n{layer_name} statistics:")
    print(f"  Mean: {layer_stats['mean']:.4f}")
    print(f"  Std: {layer_stats['std']:.4f}")
    print(f"  Max: {layer_stats['max']:.4f}")
    print(f"  Min: {layer_stats['min']:.4f}")
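The four statistics amount to simple tensor reductions, which the following sketch spells out to make each field's meaning explicit. It is a hypothetical re-implementation (`curvature_stats_sketch` is not part of the library), and it assumes the population standard deviation (dividing by n), whereas the library may use the unbiased estimator (dividing by n - 1).

```python
import torch


def curvature_stats_sketch(layer_curvatures):
    """Hypothetical re-implementation: mean/std/max/min per curvature tensor."""
    stats = {}
    for name, curvature in layer_curvatures.items():
        flat = curvature.detach().reshape(-1).float()
        mean = flat.mean()
        stats[name] = {
            "mean": mean.item(),
            # Population standard deviation (divides by n, not n - 1).
            "std": (flat - mean).pow(2).mean().sqrt().item(),
            "max": flat.max().item(),
            "min": flat.min().item(),
        }
    return stats
```

For a curvature tensor [1.0, 2.0, 3.0], this gives mean 2.0, max 3.0, min 1.0 and a population standard deviation of about 0.816.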
Notes
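The per-layer curvature computations follow a block-diagonal view of the full Hessian/Fisher matrix, in which each block corresponds to one layer's parameters and the cross-layer (off-diagonal) blocks are ignored. The diagonal functions documented above return the diagonal of each such block.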
This approximation is useful for:
- Understanding the curvature of individual layers
- Implementing layer-wise optimization strategies
- Diagnosing training issues at the layer level
- Reducing computational complexity compared to full matrix computations
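To make the complexity argument concrete, the sketch below counts how many curvature entries each representation stores. It is plain arithmetic on parameter counts, and the helper `curvature_entry_counts` is hypothetical: for P parameters the full matrix has P² entries, a block-diagonal approximation over layers with p_l parameters each keeps Σ_l p_l² entries, and a diagonal keeps P.

```python
import torch.nn as nn


def curvature_entry_counts(model, layer_types=(nn.Linear, nn.Conv2d)):
    """Hypothetical helper: number of stored entries per curvature representation."""
    # Parameter count of each included layer (its own parameters only).
    layer_sizes = [
        sum(p.numel() for p in module.parameters(recurse=False))
        for module in model.modules()
        if isinstance(module, tuple(layer_types))
    ]
    total = sum(layer_sizes)
    return {
        "full": total ** 2,
        "block_diagonal": sum(size ** 2 for size in layer_sizes),
        "diagonal": total,
    }


# Same layer shapes as the SimpleModel in the Hessian example: 10 -> 20 -> 5 -> 1.
model = nn.Sequential(nn.Linear(10, 20), nn.Linear(20, 5), nn.Linear(5, 1))
print(curvature_entry_counts(model))
# {'full': 109561, 'block_diagonal': 59461, 'diagonal': 331}
```

In this small model the first layer holds most of the parameters, so the block-diagonal form only roughly halves the storage; the savings grow when the parameters are spread over many comparably sized layers, and the diagonal is cheaper still.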
The layer_types parameter allows you to specify which types of layers to include in the computation. By default, it includes nn.Linear and nn.Conv2d layers.
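The create_graph parameter keeps the curvature computation in the autograd graph, so the returned tensors can themselves be differentiated to obtain higher-order derivatives. This is useful in meta-learning and other settings where a curvature quantity is part of an objective that is optimized.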
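What create_graph enables can be shown with plain torch.autograd, independently of this module: when the first derivative is taken with create_graph=True, the gradient, and any curvature computed from it, remains part of the graph and can be differentiated again. In the toy function f(w) = Σ w_i⁴ below, the Hessian diagonal is 12 w_i², and differentiating its sum once more gives 24 w_i.

```python
import torch

w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = (w ** 4).sum()

# First derivative 4 * w**3, kept differentiable by create_graph=True.
(grad,) = torch.autograd.grad(loss, w, create_graph=True)

# Hessian diagonal: entry i is d(grad_i)/dw_i = 12 * w_i**2.
# create_graph=True again, so the diagonal itself stays differentiable.
hess_diag = torch.stack([
    torch.autograd.grad(grad[i], w, create_graph=True)[0][i]
    for i in range(w.numel())
])

# A third-order quantity: d/dw of sum(hess_diag) = 24 * w.
(third,) = torch.autograd.grad(hess_diag.sum(), w)
```

Without create_graph=True on the intermediate calls, hess_diag would not be connected to w and the final call would fail.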