Source code for torch_secorder.core.utils
"""Utility functions for handling model parameters."""
from typing import Dict, Iterable, List, Tuple, Union
import torch
from torch.nn import Module
[docs]
def flatten_params(params: Iterable[torch.Tensor]) -> torch.Tensor:
    """Flatten an iterable of parameter tensors into one concatenated 1D tensor.

    Each tensor is flattened to 1D and the pieces are concatenated in the
    order in which ``params`` yields them.

    Args:
        params: An iterable of PyTorch parameter tensors.

    Returns:
        A single 1D tensor containing all flattened parameters.
    """
    # ``reshape`` instead of ``view``: ``view(-1)`` raises on non-contiguous
    # tensors (e.g. a transposed weight), whereas ``reshape`` returns a view
    # when the memory layout allows it and only copies when it has to.
    return torch.cat([p.reshape(-1) for p in params])
[docs]
def unflatten_params(
    flat_params: torch.Tensor, param_shapes: List[torch.Size]
) -> List[torch.Tensor]:
    """Split a flat parameter tensor back into tensors with the given shapes.

    The flat tensor is consumed front to back: each shape takes the next
    ``numel`` elements, which are then viewed with that shape. Elements of
    ``flat_params`` beyond the total required by ``param_shapes`` are ignored.

    Args:
        flat_params: A 1D tensor containing all flattened parameters.
        param_shapes: The original shapes of the parameters, in order. Each
            entry may be a `torch.Size` or any sequence of ints (tuple/list).

    Returns:
        A list of PyTorch tensors with their original shapes. Each one is a
        view into ``flat_params`` (slice + ``view``), not a copy.

    Raises:
        RuntimeError: Raised by ``view`` when ``flat_params`` has too few
            elements left to fill one of the requested shapes.
    """
    unflattened_params = []
    offset = 0
    for shape in param_shapes:
        # Normalise to torch.Size so plain tuples/lists are accepted too;
        # a bare tuple has no ``numel()`` method.
        size = torch.Size(shape)
        num_elements = size.numel()
        param = flat_params[offset : offset + num_elements].view(size)
        unflattened_params.append(param)
        offset += num_elements
    return unflattened_params
[docs]
def get_param_shapes(params: Iterable[torch.Tensor]) -> List[torch.Size]:
    """Collect the shape of every tensor in ``params``.

    Args:
        params: An iterable of PyTorch parameter tensors.

    Returns:
        A list with one `torch.Size` per tensor, in iteration order.
    """
    shapes: List[torch.Size] = []
    for tensor in params:
        shapes.append(tensor.shape)
    return shapes
[docs]
def get_params_by_module_type(
    model: Module, module_type: Union[type, List[type], Tuple[type, ...]]
) -> Dict[str, List[torch.Tensor]]:
    """Group parameters by the submodules that are instances of given types.

    Walks ``model.named_modules()`` and, for every module that passes an
    ``isinstance`` check against ``module_type``, records
    ``list(module.parameters())`` under that module's name.

    Args:
        model: The PyTorch model to inspect.
        module_type: The `torch.nn.Module` type(s) to select. Either a single
            type (e.g., `torch.nn.Linear`) or a list/tuple of types.

    Returns:
        Dictionary mapping module names to lists of their parameter tensors.

    Raises:
        TypeError: If ``model`` is not a `torch.nn.Module`.
    """
    if not isinstance(model, Module):
        raise TypeError("model must be a torch.nn.Module")

    # ``isinstance`` wants a type or a tuple of types, so normalise the
    # accepted input forms to a tuple up front.
    if isinstance(module_type, tuple):
        wanted_types = module_type
    elif isinstance(module_type, list):
        wanted_types = tuple(module_type)
    else:
        wanted_types = (module_type,)

    return {
        qualified_name: list(submodule.parameters())
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, wanted_types)
    }
[docs]
def get_params_by_name_pattern(model: Module, pattern: str) -> List[torch.Tensor]:
    """Select the parameters whose names match a regular expression.

    Names are those produced by ``model.named_parameters()`` (e.g.,
    "layer1.0.weight"). Matching uses ``re.search``, so the pattern may match
    anywhere in the name unless it is anchored explicitly.

    Args:
        model: The PyTorch model to inspect.
        pattern: A regex pattern string to match against parameter names.

    Returns:
        A list of parameter tensors whose names match the pattern, in the
        order ``named_parameters()`` yields them.

    Raises:
        TypeError: If ``model`` is not a `torch.nn.Module`.
    """
    if not isinstance(model, Module):
        raise TypeError("model must be a torch.nn.Module")

    import re

    return [
        tensor
        for param_name, tensor in model.named_parameters()
        if re.search(pattern, param_name)
    ]