Fisher Trace Module
The Fisher Trace module provides functions to estimate the trace of Fisher Information Matrices.
Empirical Fisher Trace
- torch_secorder.approximations.fisher_trace.empirical_fisher_trace(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, num_samples: int = 1) -> Tensor
Estimate the trace of the Empirical Fisher Information Matrix using Hutchinson’s method.
Hutchinson’s method approximates the trace of a matrix as the average of vᵀFv over random probe vectors v, so the EFIM never has to be formed explicitly; only Jacobian-vector (or gradient-vector) products are required.
- Parameters:
model – The PyTorch model.
loss_fn – The loss function, e.g., nn.CrossEntropyLoss() or nn.MSELoss().
inputs – Input tensor to the model.
targets – Target tensor for the loss function.
num_samples – Number of random vectors to use for Hutchinson’s estimation. Higher values lead to more accurate estimates but increase computation.
- Returns:
A scalar tensor representing the estimated trace of the EFIM.
Example:
import torch
import torch.nn as nn
from torch_secorder.approximations.fisher_trace import empirical_fisher_trace
# Define a simple model and loss function
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)
model = SimpleModel()
loss_fn = nn.MSELoss()
inputs = torch.randn(5, 10)
targets = torch.randn(5, 1)
# Estimate the Empirical Fisher Trace
efim_trace = empirical_fisher_trace(model, loss_fn, inputs, targets)
print("Empirical Fisher Trace:", efim_trace)
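For readers unfamiliar with Hutchinson’s method, the following is a minimal sketch of the estimator on an explicit matrix. It is not the library’s implementation: the helper name `hutchinson_trace` and the `matvec` callback are hypothetical, introduced only for illustration, and Rademacher (±1) probe vectors are assumed. For a Fisher matrix, `matvec` would be built from gradient or Jacobian-vector products rather than a stored matrix.

```python
import torch


def hutchinson_trace(matvec, dim, num_samples=1):
    """Estimate tr(A) using only a function that computes A @ v.

    Hypothetical helper for illustration; not part of torch_secorder.
    """
    estimate = torch.zeros(())
    for _ in range(num_samples):
        # Rademacher probe: each entry is +1 or -1 with equal probability.
        v = torch.randint(0, 2, (dim,)).float() * 2 - 1
        # v^T A v is an unbiased estimate of tr(A).
        estimate = estimate + torch.dot(v, matvec(v))
    return estimate / num_samples


torch.manual_seed(0)
B = torch.randn(6, 6)
A = B @ B.T  # symmetric positive semi-definite, like a Fisher matrix

approx = hutchinson_trace(lambda v: A @ v, dim=6, num_samples=2000)
exact = torch.trace(A)
print("Hutchinson estimate:", approx.item())
print("Exact trace:", exact.item())
```

Each probe costs one matrix-vector product, which is why a larger `num_samples` lowers the variance of the estimate at the price of more computation.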
Generalized Fisher Trace
- torch_secorder.approximations.fisher_trace.generalized_fisher_trace(model: Module, outputs: Tensor, targets: Tensor, loss_type: str = 'nll', num_samples: int = 1, create_graph: bool = False) -> Tensor
Estimate the trace of the Generalized Fisher Information Matrix using Hutchinson’s method.
The Generalized Fisher Information Matrix (GFIM) is defined as the expectation of the outer product of the gradients of the log-likelihood with respect to the parameters. Because the trace of an outer product ggᵀ equals the squared norm ‖g‖², the trace of the GFIM reduces to the expected squared gradient norm. This function therefore estimates the trace as the sum of squared gradients of the negative log-likelihood, a common practical approximation for classification tasks.
- Parameters:
model – The PyTorch model.
outputs – The raw outputs (e.g., logits) from the model.
targets – The target tensor (e.g., class labels or regression targets).
loss_type – Specifies the type of likelihood. Currently supports ‘nll’ (Negative Log Likelihood).
num_samples – Number of random vectors for Hutchinson’s estimation. Higher values lead to more accurate estimates but increase computation. (Note: for ‘nll’ in the current implementation, this parameter is effectively ignored, because the trace is computed directly as the sum of squared gradients, which is exact for the EFIM.)
create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.
- Returns:
A scalar tensor representing the estimated trace of the GFIM.
- Raises:
NotImplementedError – If an unsupported loss_type or output shape is provided.
ValueError – If num_samples is less than 1.
Example:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_secorder.approximations.fisher_trace import generalized_fisher_trace
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)
model = SimpleModel()
inputs = torch.randn(5, 10)
outputs = model(inputs) # Logits
targets = torch.randint(0, 2, (5,)) # Class labels
# Estimate the Generalized Fisher Trace (NLL loss type)
gfim_trace = generalized_fisher_trace(model, outputs, targets, loss_type="nll")
print("Generalized Fisher Trace:", gfim_trace)
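The "sum of squared gradients" computation described for this function can be sketched in plain PyTorch. This is an illustration under stated assumptions, not the library’s code: the helper `fisher_trace_sum_sq_grads` is hypothetical, it takes inputs rather than precomputed outputs, and it assumes the convention of averaging per-sample squared gradient norms. The library may instead use the batch gradient or a different normalization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fisher_trace_sum_sq_grads(model, inputs, targets):
    """Mean over samples of the squared gradient norm of the per-sample NLL.

    Hypothetical helper for illustration; not part of torch_secorder.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    total = torch.zeros(())
    for x, y in zip(inputs, targets):
        # Per-sample negative log-likelihood from the logits.
        logits = model(x.unsqueeze(0))
        nll = F.cross_entropy(logits, y.unsqueeze(0))
        grads = torch.autograd.grad(nll, params)
        # tr(g g^T) = ||g||^2, accumulated over all parameter tensors.
        total = total + sum((g ** 2).sum() for g in grads)
    return total / inputs.shape[0]


torch.manual_seed(0)
model = nn.Linear(10, 2)
inputs = torch.randn(5, 10)
targets = torch.randint(0, 2, (5,))
trace = fisher_trace_sum_sq_grads(model, inputs, targets)
print("Sum-of-squared-gradients trace:", trace.item())
```

No random probe vectors appear here, which is consistent with the note that `num_samples` has no effect when the trace is computed this way.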
Notes
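The Fisher trace can serve as a proxy for the Hessian’s trace in second-order optimization. The two coincide in expectation only when the loss is a negative log-likelihood and the targets are distributed according to the model’s own predictions; otherwise the Fisher trace is an approximation.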
The create_graph parameter for generalized_fisher_trace allows for computing higher-order derivatives if needed, but this increases memory usage.
The num_samples parameter for empirical_fisher_trace is effectively ignored in the current implementation when the trace is computed directly by summing squared gradients, which is exact for EFIM.
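To illustrate what a create_graph-style flag enables, the sketch below uses `torch.autograd.grad` directly rather than `generalized_fisher_trace`, whose internals are not shown here. It assumes a small linear classifier and differentiates a squared gradient norm with respect to the parameters, which is only possible when the graph of the first gradient computation is retained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(10, 2)
inputs = torch.randn(5, 10)
targets = torch.randint(0, 2, (5,))
params = list(model.parameters())

loss = F.cross_entropy(model(inputs), targets)

# create_graph=True records the gradient computation itself,
# so the squared gradient norm below stays differentiable.
grads = torch.autograd.grad(loss, params, create_graph=True)
sq_norm = sum((g ** 2).sum() for g in grads)

# Second-order step: differentiate the squared gradient norm.
sq_norm.backward()
for name, p in model.named_parameters():
    print(name, "second-order grad norm:", p.grad.norm().item())
```

Retaining the graph keeps intermediate tensors alive until the second backward pass finishes, which is the source of the extra memory usage mentioned in the note above.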