Approximations

torch_secorder.approximations.empirical_fisher_diagonal(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor) → List[Tensor]

Compute the diagonal elements of the Empirical Fisher Information Matrix (EFIM).

The Empirical Fisher replaces the expectation in the Fisher Information Matrix with outer products of the observed loss gradients, so its diagonal consists of squared gradients of the loss with respect to the model parameters. This function computes these diagonal elements for each parameter.
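For reference, the standard definition of the Empirical Fisher over \(N\) examples with per-example losses \(\ell_n\), and its diagonal, are:

\[\hat{F} = \frac{1}{N} \sum_{n=1}^{N} \nabla_\theta \ell_n \, \nabla_\theta \ell_n^\top, \qquad \hat{F}_{jj} = \frac{1}{N} \sum_{n=1}^{N} \left(\frac{\partial \ell_n}{\partial \theta_j}\right)^2.\]

Squaring the gradient of the averaged batch loss instead gives \(\left(\frac{1}{N}\sum_n \partial \ell_n / \partial \theta_j\right)^2\), which by Jensen's inequality is never larger than \(\hat{F}_{jj}\) and is a different quantity in general.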

Parameters:
  • model – The PyTorch model.

  • loss_fn – The loss function, e.g., nn.CrossEntropyLoss() or nn.MSELoss().

  • inputs – Input tensor to the model.

  • targets – Target tensor for the loss function.

Returns:

A list of tensors, each containing the diagonal elements of the EFIM for the corresponding parameter.
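A minimal sketch of the per-example form of this computation in plain PyTorch. The helper `empirical_fisher_diag_sketch` is hypothetical and is not the torch_secorder implementation; it assumes the diagonal is the mean of squared per-example gradients, one tensor per trainable parameter.

```python
import torch
from torch import nn


def empirical_fisher_diag_sketch(model, loss_fn, inputs, targets):
    """Hypothetical helper: mean of squared per-example gradients per parameter."""
    params = [p for p in model.parameters() if p.requires_grad]
    diag = [torch.zeros_like(p) for p in params]
    n = inputs.shape[0]
    for i in range(n):
        # Loss of one example; slicing keeps a batch dimension of size 1.
        loss = loss_fn(model(inputs[i : i + 1]), targets[i : i + 1])
        grads = torch.autograd.grad(loss, params)
        for d, g in zip(diag, grads):
            d += g ** 2  # accumulate element-wise squared gradients in place
    return [d / n for d in diag]


torch.manual_seed(0)
model = nn.Linear(3, 2)
x = torch.randn(8, 3)
y = torch.randint(0, 2, (8,))
diag = empirical_fisher_diag_sketch(model, nn.CrossEntropyLoss(), x, y)
print([tuple(d.shape) for d in diag])  # [(2, 3), (2,)]
```

Each returned tensor has the same shape as its parameter, matching the "one tensor per parameter" layout described above. The loop costs one backward pass per example, which keeps the sketch simple but is slow for large batches.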

torch_secorder.approximations.empirical_fisher_trace(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, num_samples: int = 1) → Tensor

Estimate the trace of the Empirical Fisher Information Matrix using Hutchinson’s method.

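Hutchinson’s method estimates the trace of a matrix \(F\) as the average of \(v^\top F v\) over random probe vectors \(v\) satisfying \(\mathbb{E}[v v^\top] = I\), such as Rademacher vectors with entries \(\pm 1\). Because the EFIM is built from outer products of gradients, each quadratic form reduces to squared inner products between gradients and \(v\), so the matrix itself is never formed.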

Parameters:
  • model – The PyTorch model.

  • loss_fn – The loss function, e.g., nn.CrossEntropyLoss() or nn.MSELoss().

  • inputs – Input tensor to the model.

  • targets – Target tensor for the loss function.

  • num_samples – Number of random vectors to use for Hutchinson’s estimation. Higher values reduce the variance of the estimate at the cost of more computation.

Returns:

A scalar tensor representing the estimated trace of the EFIM.
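A sketch of Hutchinson’s estimator for the EFIM trace in plain PyTorch, assuming Rademacher probes and per-example gradients. Both helpers (`per_example_grads`, `hutchinson_ef_trace_sketch`) are hypothetical illustrations, not the library’s code. Stacking flattened per-example gradients into an \(N \times P\) matrix lets all probe products be computed with one matrix multiply, and on a toy model the exact trace (the mean squared gradient norm) is cheap enough to compare against.

```python
import torch
from torch import nn


def per_example_grads(model, loss_fn, inputs, targets):
    """Hypothetical helper: stack flattened per-example gradients into (N, P)."""
    params = [p for p in model.parameters() if p.requires_grad]
    rows = []
    for i in range(inputs.shape[0]):
        loss = loss_fn(model(inputs[i : i + 1]), targets[i : i + 1])
        grads = torch.autograd.grad(loss, params)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows)


def hutchinson_ef_trace_sketch(model, loss_fn, inputs, targets, num_samples=1):
    """Hypothetical helper: Hutchinson estimate of tr(F), F = mean_n g_n g_n^T."""
    G = per_example_grads(model, loss_fn, inputs, targets)  # (N, P)
    # Rademacher probes: each entry is +1 or -1 with equal probability.
    V = torch.randint(0, 2, (G.shape[1], num_samples)).to(G.dtype) * 2 - 1
    # v^T F v = mean_n (g_n . v)^2; also average over the num_samples probes.
    return ((G @ V) ** 2).mean()


torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.randn(16, 4)
y = torch.randint(0, 3, (16,))
loss_fn = nn.CrossEntropyLoss()

exact = (per_example_grads(model, loss_fn, x, y) ** 2).sum(dim=1).mean()
approx = hutchinson_ef_trace_sketch(model, loss_fn, x, y, num_samples=2000)
print(float(exact), float(approx))
```

With many probes the estimate approaches the exact value; with `num_samples=1` it is unbiased but noisy, which is the trade-off the `num_samples` parameter controls.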

torch_secorder.approximations.gauss_newton_matrix_approximation(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, create_graph: bool = False) → Iterable[Tensor]

Compute the diagonal elements of the Gauss-Newton Matrix (GNM).

The GNM approximates the Hessian of the loss and is commonly used in non-linear least-squares problems. For a loss function defined as:

\[L = \frac{1}{2} \|f(x; \theta) - y\|^2,\]

where \(f(x; \theta)\) is the model output, the Gauss-Newton Matrix is given by:

\[G = J^\top J,\]

where \(J\) is the Jacobian of \(f(x; \theta)\) with respect to the parameters \(\theta\).
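The reason \(G\) approximates the Hessian follows from differentiating \(L\) twice. Writing the residual as \(r = f(x; \theta) - y\), so that \(J = \partial r / \partial \theta = \partial f / \partial \theta\):

\[\nabla_\theta L = J^\top r, \qquad \nabla_\theta^2 L = J^\top J + \sum_i r_i \, \nabla_\theta^2 f_i(x; \theta).\]

Dropping the residual-weighted second-order term leaves \(G = J^\top J\). It is positive semi-definite because \(v^\top G v = \|J v\|^2 \ge 0\), and it equals the Hessian whenever the residuals vanish or \(f\) is linear in \(\theta\).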

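This function returns the diagonal elements of the GNM for each parameter. The diagonal entry for parameter \(\theta_j\) is \(G_{jj} = \sum_i J_{ij}^2\), the sum of the squared entries in the column of \(J\) that corresponds to \(\theta_j\), with one entry per scalar model output.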

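Parameters:
  • model – The PyTorch model.

  • loss_fn – The loss function; must be MSE-based, e.g., nn.MSELoss().

  • inputs – Input tensor to the model.

  • targets – Target tensor for the loss function.

  • create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.

Returns: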

A list of tensors, where each tensor contains the diagonal elements of the GNM for the corresponding model parameter.

Return type:

Iterable[torch.Tensor]

Raises:

ValueError – If loss_fn is not MSE-based, because the formula \(G = J^\top J\) assumes a squared-error loss.
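A minimal sketch of the column-sum computation in plain PyTorch. It assumes the \(\frac{1}{2}\|f(x; \theta) - y\|^2\) loss above, so the targets do not enter \(J^\top J\) and the hypothetical helper `gauss_newton_diag_sketch` takes only the model and inputs; it is not the torch_secorder implementation. Note that nn.MSELoss() with its default reduction='mean' computes \(\frac{1}{n}\sum_i (f_i - y_i)^2\) over \(n\) output elements, which scales the GNM by \(2/n\).

```python
import torch
from torch import nn


def gauss_newton_diag_sketch(model, inputs):
    """Hypothetical helper: diag(J^T J) for the flattened outputs f(x; theta)."""
    params = [p for p in model.parameters() if p.requires_grad]
    outputs = model(inputs).reshape(-1)  # one entry per scalar model output
    diag = [torch.zeros_like(p) for p in params]
    for i in range(outputs.numel()):
        # Gradient of output i w.r.t. the parameters is row i of J.
        row = torch.autograd.grad(outputs[i], params, retain_graph=True)
        for d, g in zip(diag, row):
            d += g ** 2  # G_jj = sum over rows i of J_ij^2
    return diag


torch.manual_seed(0)
model = nn.Linear(3, 2)
x = torch.randn(5, 3)
diag_w, diag_b = gauss_newton_diag_sketch(model, x)
# For a linear layer, df_{n,k}/dW_{k,l} = x_{n,l} and df_{n,k}/db_k = 1.
print(diag_w)  # every row equals (x ** 2).sum(0)
print(diag_b)  # tensor([5., 5.])
```

Looping over outputs costs one backward pass per scalar output, so this form only suits small models; the linear-layer case is useful because its diagonal can be checked by hand.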

torch_secorder.approximations.generalized_fisher_diagonal(model: Module, outputs: Tensor, targets: Tensor, loss_type: str = 'nll', create_graph: bool = False) → List[Tensor]

Compute the diagonal elements of the Generalized Fisher Information Matrix.

The Generalized Fisher Information Matrix (GFIM) is defined as the expectation of the outer product of the gradients of the log-likelihood with respect to the parameters. This function computes the diagonal elements of this matrix.
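Written out, with \(g = \nabla_\theta \log p(y \mid x; \theta)\), the definition and its diagonal are:

\[F = \mathbb{E}\left[g \, g^\top\right], \qquad F_{jj} = \mathbb{E}\left[g_j^2\right].\]

The diagonal therefore needs only squared partial derivatives, never the full matrix. When the expectation is replaced by an average over observed targets, this reduces to the Empirical Fisher.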

Parameters:
  • model – The PyTorch model.

  • outputs – The raw outputs (e.g., logits) from the model.

  • targets – The target tensor (e.g., class labels or regression targets).

  • loss_type – Specifies the type of likelihood. Currently supports ‘nll’ (Negative Log Likelihood).

  • create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.

Returns:

A list of tensors, each containing the diagonal elements of the GFIM for the corresponding parameter.

Raises:

NotImplementedError – If an unsupported loss_type or output shape is provided.
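A sketch of the ‘nll’ case in plain PyTorch, assuming the diagonal is the mean of squared per-example negative log-likelihood gradients computed from precomputed logits. The helper `nll_fisher_diag_sketch` is hypothetical, not the library’s code. It shows two practical points implied by the signature: `outputs` must still be attached to the model’s autograd graph, and `create_graph=True` makes the result differentiable.

```python
import torch
import torch.nn.functional as F
from torch import nn


def nll_fisher_diag_sketch(model, outputs, targets, create_graph=False):
    """Hypothetical helper: mean squared per-example NLL gradients from logits.

    `outputs` must still be attached to the model's autograd graph.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    n = outputs.shape[0]
    # Per-example negative log-likelihood of the observed class.
    nll = F.cross_entropy(outputs, targets, reduction="none")
    diag = [torch.zeros_like(p) for p in params]
    for i in range(n):
        # retain_graph=True: the same forward graph is reused for each example.
        grads = torch.autograd.grad(
            nll[i], params, retain_graph=True, create_graph=create_graph
        )
        # Out-of-place sums keep the result differentiable when create_graph=True.
        diag = [d + g ** 2 for d, g in zip(diag, grads)]
    return [d / n for d in diag]


torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.randn(6, 4)
y = torch.randint(0, 3, (6,))

logits = model(x)  # computed without torch.no_grad(), so the graph is kept
diag = nll_fisher_diag_sketch(model, logits, y, create_graph=True)

# With create_graph=True the diagonal can itself be differentiated.
second = torch.autograd.grad(sum(d.sum() for d in diag), list(model.parameters()))
print([tuple(s.shape) for s in second])  # [(3, 4), (3,)]
```

Passing logits produced under torch.no_grad() would fail here, because there would be no graph connecting them to the parameters.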

torch_secorder.approximations.generalized_fisher_trace(model: Module, outputs: Tensor, targets: Tensor, loss_type: str = 'nll', num_samples: int = 1, create_graph: bool = False) → Tensor

Estimate the trace of the Generalized Fisher Information Matrix using Hutchinson’s method.

The Generalized Fisher Information Matrix (GFIM) is defined as the expectation of the outer product of the gradients of the log-likelihood with respect to the parameters. This function estimates its trace using the sum of squared gradients of the negative log-likelihood, which is a common practical approximation for classification tasks.
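The direct computation rests on a trace identity. For a gradient vector \(g\), the outer product \(g g^\top\) has trace \(\sum_j g_j^2\), which is also the expected value of Hutchinson’s quadratic form for any probe distribution with \(\mathbb{E}[v v^\top] = I\):

\[\operatorname{tr}(g g^\top) = \sum_j g_j^2 = \|g\|^2, \qquad \mathbb{E}_v\!\left[(g^\top v)^2\right] = g^\top \mathbb{E}[v v^\top] \, g = \|g\|^2.\]

Summing squared gradient entries therefore gives exactly the value that random probes only estimate, and by linearity of the trace the same holds for an average of such outer products over a batch.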

Parameters:
  • model – The PyTorch model.

  • outputs – The raw outputs (e.g., logits) from the model.

  • targets – The target tensor (e.g., class labels or regression targets).

  • loss_type – Specifies the type of likelihood. Currently supports ‘nll’ (Negative Log Likelihood).

  • num_samples – Number of random vectors for Hutchinson’s estimation. Higher values reduce the variance of the estimate at the cost of more computation. (Note: for ‘nll’, the current implementation effectively ignores this parameter, because it computes the trace directly as the sum of squared gradients, which gives the exact trace of the Empirical Fisher Information Matrix.)

  • create_graph – If True, the computational graph will be constructed, allowing for higher-order derivatives.

Returns:

A scalar tensor representing the estimated trace of the GFIM.

Raises: