# Source code for torch_secorder.core.gvp

"""Jacobian-Vector Product (JVP) and Vector-Jacobian Product (VJP) utilities.

This module provides functions to compute JVPs and VJPs efficiently using PyTorch's autograd system.
Both functional and model-based APIs are provided.
"""

from typing import Callable, List, Union

import torch
import torch.nn as nn


def jvp(
    func: Callable[[], torch.Tensor],
    params: List[torch.Tensor],
    v: Union[torch.Tensor, List[torch.Tensor]],
    create_graph: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Compute the Jacobian-vector product (JVP): J v.

    The product is assembled one output element at a time: each element of
    the flattened output is differentiated with respect to ``params`` and the
    resulting gradients are contracted with ``v``. This therefore performs
    one backward pass per output element.

    Args:
        func: A callable that returns a tensor output (can be vector-valued).
        params: List of parameters with respect to which to compute the
            Jacobian.
        v: Vector to multiply with the Jacobian. Can be a single tensor or a
            list of tensors matching the structure of params.
        create_graph: If True, graph of the derivative will be constructed,
            allowing to compute higher order derivative products.

    Returns:
        The JVP (same shape as the output of func).

    Raises:
        ValueError: If ``v`` does not contain exactly one tensor per entry
            of ``params``.
    """
    # Materialise both sequences: they are iterated once per output element.
    params = list(params)
    v = [v] if isinstance(v, torch.Tensor) else list(v)

    # zip() would silently drop the unmatched tail and return a partial
    # product, so reject a structural mismatch explicitly.
    if len(v) != len(params):
        raise ValueError(
            f"v has {len(v)} tensor(s) but params has {len(params)}; "
            "v must match the structure of params"
        )

    output = func()
    flat_output = output.reshape(-1)
    jvp_result = torch.zeros_like(flat_output)
    for i in range(flat_output.shape[0]):
        # retain_graph=True keeps the graph alive for the remaining elements.
        grads = torch.autograd.grad(
            flat_output[i],
            params,
            retain_graph=True,
            create_graph=create_graph,
            allow_unused=True,
        )
        # Parameters that do not influence this output contribute zero.
        grads = [
            g if g is not None else torch.zeros_like(p)
            for g, p in zip(grads, params)
        ]
        jvp_result[i] = sum((g * v_).sum() for g, v_ in zip(grads, v))
    return jvp_result.reshape(output.shape)
def vjp(
    func: Callable[[], torch.Tensor],
    params: List[torch.Tensor],
    v: torch.Tensor,
    create_graph: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Compute the vector-Jacobian product (VJP): v^T J.

    Args:
        func: A callable that returns a tensor output (can be vector-valued).
        params: List of parameters with respect to which to compute the
            Jacobian.
        v: Vector to multiply with the Jacobian. It is reshaped to the shape
            of the output of func before use.
        create_graph: If True, graph of the derivative will be constructed,
            allowing to compute higher order derivative products.

    Returns:
        A list with one tensor per entry of params, or the bare tensor when
        params holds exactly one entry. Parameters that the output does not
        depend on receive a zero tensor of their own shape.
    """
    out = func()
    cotangent = v.reshape_as(out)
    raw_grads = torch.autograd.grad(
        out,
        params,
        grad_outputs=cotangent,
        create_graph=create_graph,
        allow_unused=True,
    )

    # autograd reports unused parameters as None; replace those with zeros.
    filled = []
    for grad, param in zip(raw_grads, params):
        if grad is None:
            grad = torch.zeros_like(param)
        filled.append(grad)

    if len(filled) == 1:
        return filled[0]
    return filled
def model_jvp(
    model: nn.Module,
    x: torch.Tensor,
    v: Union[torch.Tensor, List[torch.Tensor]],
    create_graph: bool = False,
) -> torch.Tensor:
    """Compute the JVP for a model's output with respect to its parameters.

    Thin wrapper around ``jvp``: the differentiated function is ``model(x)``
    and the parameters are ``list(model.parameters())``.

    Args:
        model: The PyTorch model.
        x: Input tensor passed to the model.
        v: Vector to multiply with the Jacobian (should match the structure
            of model.parameters()).
        create_graph: If True, graph of the derivative will be constructed.

    Returns:
        The JVP (same shape as the model output).

    Raises:
        TypeError: If model is not a torch.nn.Module.
    """
    if isinstance(model, nn.Module):
        return jvp(
            lambda: model(x),
            list(model.parameters()),
            v,
            create_graph=create_graph,
        )
    raise TypeError("model must be a torch.nn.Module")
def model_vjp(
    model: nn.Module,
    x: torch.Tensor,
    v: torch.Tensor,
    create_graph: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Compute the VJP for a model's output with respect to its parameters.

    Thin wrapper around ``vjp``: the differentiated function is ``model(x)``
    and the parameters are ``list(model.parameters())``.

    Args:
        model: The PyTorch model.
        x: Input tensor passed to the model.
        v: Vector to multiply with the Jacobian (should match the output
            shape of model(x)).
        create_graph: If True, graph of the derivative will be constructed.

    Returns:
        Whatever ``vjp`` returns for the model's parameter list.

    Raises:
        TypeError: If model is not a torch.nn.Module.
    """
    if not isinstance(model, nn.Module):
        raise TypeError("model must be a torch.nn.Module")

    model_params = list(model.parameters())
    return vjp(lambda: model(x), model_params, v, create_graph=create_graph)
def batch_jvp(
    func: Callable[[], torch.Tensor],
    params: List[torch.Tensor],
    vs: Union[torch.Tensor, List[torch.Tensor]],
    create_graph: bool = False,
) -> torch.Tensor:
    """Compute a batch of Jacobian-vector products (JVPs).

    Each batch entry is handled by an independent call to ``jvp``, so
    ``func`` is evaluated once per vector in the batch.

    Args:
        func: A callable that returns a tensor output (can be vector-valued).
        params: List of parameters with respect to which to compute the
            Jacobian.
        vs: Batch of vectors to multiply with the Jacobian. Should be a
            tensor of shape (batch, ...) or a list of such tensors. The batch
            size is read from the first tensor.
        create_graph: If True, graph of the derivative will be constructed.

    Returns:
        Tensor of shape (batch, ...) with the JVPs for each vector in the
        batch. For an empty batch this is a zero-sized tensor of shape
        (0, *output_shape).
    """
    if isinstance(vs, torch.Tensor):
        vs = [vs]
    batch_size = vs[0].shape[0]

    if batch_size == 0:
        # torch.stack rejects an empty list; evaluate func once only to learn
        # the trailing shape and dtype of the (empty) result.
        output = func()
        return output.new_zeros((0, *output.shape))

    results = [
        jvp(func, params, [v[i] for v in vs], create_graph=create_graph)
        for i in range(batch_size)
    ]
    return torch.stack(results)
def batch_vjp(
    func: Callable[[], torch.Tensor],
    params: List[torch.Tensor],
    vs: torch.Tensor,
    create_graph: bool = False,
) -> List[torch.Tensor]:
    """Compute a batch of vector-Jacobian products (VJPs).

    Each batch entry is handled by an independent call to ``vjp``, so
    ``func`` is evaluated once per vector in the batch.

    Args:
        func: A callable that returns a tensor output (can be vector-valued).
        params: List of parameters with respect to which to compute the
            Jacobian.
        vs: Batch of vectors to multiply with the Jacobian (should match the
            output shape of func, with batch dimension first).
        create_graph: If True, graph of the derivative will be constructed.

    Returns:
        List of tensors, each of shape (batch, ...) matching the structure of
        params. For an empty batch each tensor has shape (0, *param_shape).
    """
    batch_size = vs.shape[0]

    if batch_size == 0:
        # torch.stack rejects an empty list; build the empty results directly.
        return [p.new_zeros((0, *p.shape)) for p in params]

    results: List[List[torch.Tensor]] = [[] for _ in params]
    for i in range(batch_size):
        vjp_i = vjp(func, params, vs[i], create_graph=create_graph)
        # vjp unwraps a single-parameter result to a bare tensor.
        if isinstance(vjp_i, torch.Tensor):
            vjp_i = [vjp_i]
        for per_param, grad in zip(results, vjp_i):
            per_param.append(grad)
    return [torch.stack(r) for r in results]
def full_jacobian(
    func: Callable[[], torch.Tensor],
    params: List[torch.Tensor],
    create_graph: bool = False,
) -> List[torch.Tensor]:
    """Compute the full Jacobian matrix of func with respect to params.

    The output is flattened and differentiated one element at a time. Each
    backward pass yields one Jacobian row for every parameter at once, so the
    number of backward passes equals the number of output elements.

    Args:
        func: A callable that returns a tensor output (can be vector-valued).
        params: List of parameters with respect to which to compute the
            Jacobian.
        create_graph: If True, graph of the derivative will be constructed.

    Returns:
        List of Jacobian tensors, one for each parameter, each of shape
        (output_numel, param_numel). Parameters that the output does not
        depend on get an all-zero Jacobian.
    """
    params = list(params)
    output = func()
    flat_output = output.reshape(-1)

    if not params:
        return []

    num_outputs = flat_output.shape[0]
    if num_outputs == 0:
        # torch.stack rejects an empty list; return empty Jacobians directly.
        return [p.new_zeros((0, p.numel())) for p in params]

    rows_per_param: List[List[torch.Tensor]] = [[] for _ in params]
    for i in range(num_outputs):
        # One backward pass gives the i-th row for all parameters together.
        grads = torch.autograd.grad(
            flat_output[i],
            params,
            retain_graph=True,
            create_graph=create_graph,
            allow_unused=True,
        )
        for rows, grad, p in zip(rows_per_param, grads, params):
            if grad is None:
                grad = torch.zeros_like(p)
            rows.append(grad.reshape(-1))

    return [torch.stack(rows, dim=0) for rows in rows_per_param]