Loss Landscape Visualization
This module provides tools for computing 1D and 2D slices of a neural network's loss landscape. Plotting these slices can give researchers and practitioners insight into optimization dynamics, generalization properties, and the geometry of the loss surface.
Compute 1D Loss Surface
- torch_secorder.analysis.landscape.compute_loss_surface_1d(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, direction: List[Tensor], alpha_range: Tuple[float, float] = (-1.0, 1.0), num_points: int = 50) -> Tuple[Tensor, Tensor]
Computes loss values along a 1D slice in the parameter space.
- Parameters:
model – The PyTorch model.
loss_fn – The loss function.
inputs – Input tensor to the model.
targets – Target tensor for the loss function.
direction – A list of tensors representing the direction vector in parameter space. Must have the same structure and shape as model.parameters().
alpha_range – A tuple (min_alpha, max_alpha) defining the range for the 1D slice.
num_points – Number of points to sample along the 1D slice.
- Returns:
A tuple (alphas, losses) where alphas holds the sampled scale factors applied to the direction and losses holds the loss value at each of them.
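To make the idea of a 1D slice concrete, the following is a minimal sketch in plain PyTorch of evaluating the loss at theta_0 + alpha * direction for a range of alpha values. It is not the implementation of compute_loss_surface_1d; the helper name loss_along_direction and the toy model and data are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def loss_along_direction(model, loss_fn, inputs, targets, direction, alphas):
    """Hypothetical helper: loss at theta_0 + alpha * direction for each alpha."""
    params = list(model.parameters())
    originals = [p.detach().clone() for p in params]  # snapshot of theta_0
    losses = []
    with torch.no_grad():
        try:
            for alpha in alphas:
                # Move every parameter tensor along its slice of the direction.
                for p, p0, d in zip(params, originals, direction):
                    p.copy_(p0 + alpha * d)
                losses.append(loss_fn(model(inputs), targets).item())
        finally:
            # Put theta_0 back, even if a forward pass raised.
            for p, p0 in zip(params, originals):
                p.copy_(p0)
    return torch.tensor(losses)


torch.manual_seed(0)
model = nn.Linear(2, 1)
x, y = torch.randn(8, 2), torch.randn(8, 1)
direction = [torch.randn_like(p) for p in model.parameters()]
alphas = torch.linspace(-1.0, 1.0, 5)

before = [p.detach().clone() for p in model.parameters()]
losses = loss_along_direction(model, F.mse_loss, x, y, direction, alphas)
```

In this sketch the middle entry of losses (alpha = 0) is the loss at the unperturbed parameters, which is a convenient sanity check when reading a 1D plot.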
Compute 2D Loss Surface
- torch_secorder.analysis.landscape.compute_loss_surface_2d(model: Module, loss_fn: Callable[[Tensor, Tensor], Tensor], inputs: Tensor, targets: Tensor, direction1: List[Tensor], direction2: List[Tensor], alpha_range: Tuple[float, float] = (-1.0, 1.0), beta_range: Tuple[float, float] = (-1.0, 1.0), num_points: int = 25) -> Tuple[Tensor, Tensor, Tensor]
Computes loss values over a 2D plane in the parameter space.
- Parameters:
model – The PyTorch model.
loss_fn – The loss function.
inputs – Input tensor to the model.
targets – Target tensor for the loss function.
direction1 – First direction vector in parameter space.
direction2 – Second direction vector in parameter space.
alpha_range – A tuple (min_alpha, max_alpha) defining the range for the first direction.
beta_range – A tuple (min_beta, max_beta) defining the range for the second direction.
num_points – Number of points to sample along each dimension of the 2D plane.
- Returns:
A tuple (alphas, betas, losses_surface) where alphas and betas are the grid coordinates and losses_surface is a 2D tensor of corresponding loss values.
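The 2D case can be pictured as evaluating the loss at theta_0 + alpha * direction1 + beta * direction2 on a grid. The sketch below does this by hand in plain PyTorch; it is not the library's implementation, and the helper name loss_on_plane, the toy model, and the row/column convention (rows follow alphas, columns follow betas) are assumptions for illustration. Because num_points applies to each dimension, a grid like this presumably costs num_points squared loss evaluations, so 625 at the default of 25.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def loss_on_plane(model, loss_fn, inputs, targets, d1, d2, alphas, betas):
    """Hypothetical helper: surface[i, j] = loss at theta_0 + alphas[i]*d1 + betas[j]*d2."""
    params = list(model.parameters())
    originals = [p.detach().clone() for p in params]
    surface = torch.empty(len(alphas), len(betas))
    with torch.no_grad():
        try:
            for i, a in enumerate(alphas):
                for j, b in enumerate(betas):
                    for p, p0, u, v in zip(params, originals, d1, d2):
                        p.copy_(p0 + a * u + b * v)
                    surface[i, j] = loss_fn(model(inputs), targets)
        finally:
            # Restore the original parameters.
            for p, p0 in zip(params, originals):
                p.copy_(p0)
    return surface


torch.manual_seed(0)
model = nn.Linear(2, 1)
x, y = torch.randn(8, 2), torch.randn(8, 1)
d1 = [torch.randn_like(p) for p in model.parameters()]
d2 = [torch.randn_like(p) for p in model.parameters()]
alphas = torch.linspace(-1.0, 1.0, 3)
betas = torch.linspace(-1.0, 1.0, 5)

before = [p.detach().clone() for p in model.parameters()]
surface = loss_on_plane(model, F.mse_loss, x, y, d1, d2, alphas, betas)
```

With this convention, torch.meshgrid(alphas, betas, indexing="ij") yields coordinate grids whose shape matches surface.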
Create Random Direction
- torch_secorder.analysis.landscape.create_random_direction(model: Module) -> List[Tensor]
Creates a random normalized direction vector in parameter space.
The direction vector has the same structure and shape as the model’s parameters.
- Parameters:
model – The PyTorch model.
- Returns:
A list of tensors representing the random normalized direction vector.
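As a rough illustration of what a random normalized direction can look like, the sketch below draws Gaussian noise with the shape of each parameter and rescales the whole list to unit Euclidean norm. The documentation above does not say how create_random_direction normalizes (globally, per tensor, or otherwise), so the global normalization here is an assumption, and random_unit_direction is a hypothetical name.

```python
import torch
import torch.nn as nn


def random_unit_direction(model):
    """Hypothetical helper: Gaussian direction with unit norm over all parameters."""
    direction = [torch.randn_like(p) for p in model.parameters()]
    # Norm taken over the concatenation of all tensors (assumed convention).
    norm = torch.sqrt(sum((d ** 2).sum() for d in direction))
    return [d / norm for d in direction]


torch.manual_seed(0)
model = nn.Linear(2, 1)
direction = random_unit_direction(model)
total_norm = torch.sqrt(sum((d ** 2).sum() for d in direction))
```

Each tensor in direction has the same shape as the corresponding entry of model.parameters(), which is the structure the surface functions expect.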
Example
"""Example demonstrating loss landscape visualization.

This script shows how to compute and visualize 1D slices and 2D contours
of the loss landscape for a simple model, using random directions.

The example demonstrates:
1. Computing 1D loss surface along a random direction
2. Computing 2D loss surface using two random directions
3. Visualizing both surfaces using matplotlib

Requirements:
    - torch
    - matplotlib
    - numpy
"""

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import cm

from torch_secorder.analysis.landscape import (
    compute_loss_surface_1d,
    compute_loss_surface_2d,
    create_random_direction,
)


# 1. Define a Simple Model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)

    def forward(self, x):
        return self.linear(x)


# 2. Generate Synthetic Data
torch.manual_seed(42)
X_train = torch.randn(20, 2)
y_train = X_train @ torch.tensor([[2.0], [3.0]]) + 1.0 + torch.randn(20, 1) * 0.5


# 3. Instantiate Model and Loss Function
model = SimpleModel()
loss_fn = F.mse_loss


print("--- Loss Landscape Visualization Example ---")

# 4. Compute 1D Loss Surface
print("\nComputing 1D loss surface...")
direction_1d = create_random_direction(model)
alphas_1d, losses_1d = compute_loss_surface_1d(
    model,
    loss_fn,
    X_train,
    y_train,
    direction_1d,
    alpha_range=(-2.0, 2.0),
    num_points=50,
)

print("1D Loss Surface (first 5 points):\nAlphas: ", alphas_1d[:5].tolist())
print("Losses: ", losses_1d[:5].tolist())

# Plot 1D Loss Surface
plt.figure(figsize=(8, 6))
plt.plot(alphas_1d.numpy(), losses_1d.numpy(), "b-", linewidth=2)
plt.xlabel("Alpha (Direction Scale)")
plt.ylabel("Loss")
plt.title("1D Loss Surface")
plt.grid(True)
plt.savefig("1d_loss_surface.png")
plt.close()


# 5. Compute 2D Loss Surface
print("\nComputing 2D loss surface...")
direction1_2d = create_random_direction(model)
direction2_2d = create_random_direction(model)

# Ensure directions are not collinear (optional, but good for meaningful 2D surface)
# A simple way to get somewhat orthogonal directions: re-randomize if dot product is too high
# This is a heuristic, proper orthogonalization methods might be preferred for robustness
if (
    torch.dot(
        torch.cat([d.flatten() for d in direction1_2d]),
        torch.cat([d.flatten() for d in direction2_2d]),
    ).abs()
    > 0.5
):
    print("Adjusting second random direction for better orthogonality...")
    direction2_2d = create_random_direction(model)

alphas_2d, betas_2d, losses_2d = compute_loss_surface_2d(
    model, loss_fn, X_train, y_train, direction1_2d, direction2_2d, num_points=25
)

print("2D Loss Surface (top-left 3x3 values):\n", losses_2d[:3, :3].tolist())

# Plot 2D Loss Surface
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection="3d")
A, B = torch.meshgrid(alphas_2d, betas_2d, indexing="ij")
surf = ax.plot_surface(
    A.numpy(),
    B.numpy(),
    losses_2d.numpy(),
    cmap=cm.viridis,
    edgecolor="none",
    alpha=0.8,
)
ax.set_xlabel("Alpha (Direction 1)")
ax.set_ylabel("Beta (Direction 2)")
ax.set_zlabel("Loss")
ax.set_title("2D Loss Surface")
fig.colorbar(surf, shrink=0.5, aspect=5)
plt.savefig("2d_loss_surface.png")
plt.close()

print(
    "\nLoss landscape visualization complete. Check '1d_loss_surface.png' and '2d_loss_surface.png' for the plots."
)
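The example's comments note that re-randomizing the second direction is only a heuristic and that explicit orthogonalization may be more robust. One way to do that is a single Gram-Schmidt step, sketched below in plain PyTorch. The helper name orthogonalize is hypothetical and is not part of torch_secorder; the tensor shapes mimic the parameters of nn.Linear(2, 1).

```python
import torch


def orthogonalize(d1, d2):
    """Hypothetical helper: remove the d1 component from d2, then renormalize."""
    flat1 = torch.cat([t.flatten() for t in d1])
    flat2 = torch.cat([t.flatten() for t in d2])
    # Projection coefficient of d2 onto d1 over the full parameter vector.
    coeff = torch.dot(flat1, flat2) / torch.dot(flat1, flat1)
    residual = [b - coeff * a for a, b in zip(d1, d2)]
    # Fails (division by zero) if d2 is exactly parallel to d1.
    norm = torch.sqrt(sum((r ** 2).sum() for r in residual))
    return [r / norm for r in residual]


torch.manual_seed(0)
shapes = [(1, 2), (1,)]  # weight and bias of nn.Linear(2, 1)
d1 = [torch.randn(s) for s in shapes]
d2 = [torch.randn(s) for s in shapes]
d2_orth = orthogonalize(d1, d2)

overlap = torch.dot(
    torch.cat([t.flatten() for t in d1]),
    torch.cat([t.flatten() for t in d2_orth]),
)
```

The returned list keeps the per-parameter structure, so it could be passed as direction2 in place of a second random draw.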
Notes
Parameter Interpolation: The functions compute_loss_surface_1d and compute_loss_surface_2d temporarily modify the model’s parameters to explore the loss surface. The original parameters are restored after computation.
Random Directions: create_random_direction generates a random direction vector. For more advanced analysis, users might want to use specific directions (e.g., principal components of the Hessian, or directions defined by optimization trajectories).
Visualization: This module only computes loss values. Plotting them requires an external library such as matplotlib, as demonstrated in the example.
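As an illustration of a direction defined by an optimization trajectory, the sketch below records the parameters before and after a short training run and takes their difference. It uses only plain PyTorch; the toy data, the SGD settings, and the choice to leave the direction unnormalized are assumptions for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(2, 1)
x = torch.randn(32, 2)
y = x @ torch.tensor([[2.0], [3.0]]) + 1.0

# Snapshot of the initial parameters and loss.
start = [p.detach().clone() for p in model.parameters()]
initial_loss = F.mse_loss(model(x), y).item()

opt = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(50):
    opt.zero_grad()
    F.mse_loss(model(x), y).backward()
    opt.step()

# Snapshot of the trained parameters and loss.
end = [p.detach().clone() for p in model.parameters()]
final_loss = F.mse_loss(model(x), y).item()

# Direction pointing from the trained parameters back to the initial ones.
trajectory_direction = [s - e for s, e in zip(start, end)]
```

Since the model now holds the trained parameters, passing trajectory_direction as direction with alpha_range=(0.0, 1.0) would trace the straight line from the trained point (alpha = 0) to the initial point (alpha = 1).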