Fisher Diagonal Module#

The Fisher Diagonal module provides functions to compute the diagonal elements of Fisher Information Matrices.

Empirical Fisher Diagonal#

torch_secorder.approximations.fisher_diagonal.empirical_fisher_diagonal(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor) → List[Tensor]

Compute the diagonal elements of the Empirical Fisher Information Matrix (EFIM).

The diagonal of the Empirical Fisher is approximated by the element-wise squares of the gradients of the loss with respect to the model parameters, evaluated on the supplied targets. This function returns those diagonal elements for each parameter.

Parameters:
  • model – The PyTorch model.

  • loss_fn – The loss function, e.g., nn.CrossEntropyLoss() or nn.MSELoss().

  • inputs – Input tensor to the model.

  • targets – Target tensor for the loss function.

Returns:

A list of tensors, each containing the diagonal elements of the EFIM for the corresponding parameter.

Example:#

import torch
import torch.nn as nn
from torch_secorder.approximations.fisher_diagonal import empirical_fisher_diagonal

# Define a simple model and loss function
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(5, 10)  # Batch size 5, input dim 10
targets = torch.randint(0, 2, (5,)) # Batch size 5, 2 classes

# Compute the Empirical Fisher Diagonal
efim_diagonal = empirical_fisher_diagonal(model, loss_fn, inputs, targets)

print("Empirical Fisher Diagonal:", efim_diagonal)
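To make "squared gradients" concrete, the following is a minimal sketch in plain PyTorch and is not the library's implementation. The source does not say whether the gradients are squared per sample and then averaged, or whether the batch-averaged gradient is squared once, so both variants are shown. The helper names `squared_batch_gradient` and `mean_squared_per_sample_gradient` are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def squared_batch_gradient(model, loss_fn, inputs, targets):
    """Hypothetical helper: square the gradient of the batch loss."""
    params = [p for p in model.parameters() if p.requires_grad]
    loss = loss_fn(model(inputs), targets)
    grads = torch.autograd.grad(loss, params)
    return [g ** 2 for g in grads]


def mean_squared_per_sample_gradient(model, loss_fn, inputs, targets):
    """Hypothetical helper: average the squared per-sample gradients."""
    params = [p for p in model.parameters() if p.requires_grad]
    totals = [torch.zeros_like(p) for p in params]
    for x, y in zip(inputs, targets):
        # Re-add the batch dimension so the model sees a batch of size 1.
        loss = loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0))
        grads = torch.autograd.grad(loss, params)
        for total, g in zip(totals, grads):
            total += g ** 2
    return [total / len(inputs) for total in totals]


model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(5, 10)
targets = torch.randint(0, 2, (5,))

diag_batch = squared_batch_gradient(model, loss_fn, inputs, targets)
diag_per_sample = mean_squared_per_sample_gradient(model, loss_fn, inputs, targets)

for (name, _), d in zip(model.named_parameters(), diag_per_sample):
    print(name, tuple(d.shape))
```

Both variants return one non-negative tensor per parameter, with the same shape as that parameter. With a mean-reduced loss, the per-sample variant is element-wise at least as large as the batch variant, so the two are not interchangeable.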

Generalized Fisher Diagonal#

torch_secorder.approximations.fisher_diagonal.generalized_fisher_diagonal(model: Module, outputs: Tensor, targets: Tensor, loss_type: str = 'nll', create_graph: bool = False) → List[Tensor]

Compute the diagonal elements of the Generalized Fisher Information Matrix.

The Generalized Fisher Information Matrix (GFIM) is defined as the expectation of the outer product of the gradients of the log-likelihood with respect to the parameters. This function computes the diagonal elements of this matrix.
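In symbols, the definition above can be written as follows. The notation is introduced here for illustration only: θ denotes the parameters and p_θ(y | x) the model likelihood. The source does not state which distribution the expectation is taken over, so it is left unsubscripted.

```latex
F(\theta) = \mathbb{E}\!\left[ \nabla_\theta \log p_\theta(y \mid x)\, \nabla_\theta \log p_\theta(y \mid x)^{\top} \right],
\qquad
F(\theta)_{ii} = \mathbb{E}\!\left[ \left( \frac{\partial \log p_\theta(y \mid x)}{\partial \theta_i} \right)^{2} \right].
```

The diagonal entry for parameter θ_i is therefore the expected squared partial derivative of the log-likelihood, which is why every returned element is non-negative.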

Parameters:
  • model – The PyTorch model.

  • outputs – The raw outputs (e.g., logits) from the model.

  • targets – The target tensor (e.g., class labels or regression targets).

  • loss_type – Specifies the type of likelihood. Currently supports 'nll' (Negative Log Likelihood).

  • create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.

Returns:

A list of tensors, each containing the diagonal elements of the GFIM for the corresponding parameter.

Raises:

NotImplementedError – If an unsupported loss_type or output shape is provided.
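The effect of a create_graph flag can be illustrated with PyTorch's own `torch.autograd.grad`, which takes an argument of the same name. This is a sketch of the underlying autograd behavior, on the assumption (not confirmed by the source) that generalized_fisher_diagonal forwards the flag to autograd in a similar way.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

# With create_graph=True the gradient is itself part of a graph,
# so it can be differentiated a second time.
y = x ** 3
(dy_dx,) = torch.autograd.grad(y, x, create_graph=True)
(d2y_dx2,) = torch.autograd.grad(dy_dx, x)
print(dy_dx.item(), d2y_dx2.item())  # 3*x**2 = 27.0, 6*x = 18.0

# With the default create_graph=False the gradient is a plain tensor
# with no history, which uses less memory but cannot be differentiated.
y_plain = x ** 3
(dy_dx_plain,) = torch.autograd.grad(y_plain, x)
print(dy_dx_plain.requires_grad)  # False
```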

Example:#

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_secorder.approximations.fisher_diagonal import generalized_fisher_diagonal

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)

model = SimpleModel()
inputs = torch.randn(5, 10)
outputs = model(inputs)  # Logits
targets = torch.randint(0, 2, (5,)) # Class labels

# Compute the Generalized Fisher Diagonal (NLL loss type)
gfim_diagonal = generalized_fisher_diagonal(model, outputs, targets, loss_type="nll")

print("Generalized Fisher Diagonal:", gfim_diagonal)

Notes#

  1. The Fisher diagonal can be used as an approximation to the diagonal of the Hessian in second-order optimization, for example as a per-parameter preconditioner.

  2. The create_graph parameter for generalized_fisher_diagonal allows for computing higher-order derivatives if needed, but this increases memory usage.
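As an illustration of the first note, a diagonal curvature estimate can rescale a gradient step parameter by parameter. The sketch below is one common pattern and is not part of torch_secorder. The helper `preconditioned_step` and the `lr` and `damping` values are hypothetical, and the squared batch gradient stands in for the list that a Fisher diagonal function would return.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def preconditioned_step(params, grads, fisher_diag, lr=0.1, damping=1e-3):
    """Hypothetical helper: in-place update p <- p - lr * g / (f + damping)."""
    with torch.no_grad():
        for p, g, f in zip(params, grads, fisher_diag):
            # damping keeps the division stable where the diagonal is near zero
            p -= lr * g / (f + damping)


model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(5, 10)
targets = torch.randint(0, 2, (5,))

params = list(model.parameters())
loss = loss_fn(model(inputs), targets)
grads = torch.autograd.grad(loss, params)

# Stand-in for a Fisher diagonal: one non-negative tensor per parameter.
fisher_diag = [g ** 2 for g in grads]

preconditioned_step(params, grads, fisher_diag)
```

The damping term is a design choice rather than a requirement: without it, parameters whose diagonal entry is close to zero would receive arbitrarily large updates.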