import torch

from typing import Iterable, Optional
def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Flatten an iterable of parameters into a single vector.

    Args:
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector
    """
    # Device of the first parameter seen; every later parameter must match it.
    device_flag: Optional[int] = None

    chunks = []
    for p in parameters:
        # Raises TypeError if the parameters span multiple devices.
        device_flag = _check_param_device(p, device_flag)
        chunks.append(p.view(-1))
    return torch.cat(chunks)
def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
    r"""Copy slices of a vector into an iterable of parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    # Guard clause: only real Tensors can be sliced below.
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f'expected torch.Tensor, but got: {torch.typename(vec)}')

    # Device of the first parameter seen; every later parameter must match it.
    device_flag: Optional[int] = None

    # Running offset into `vec` marking the start of the next parameter's slice.
    offset = 0
    for p in parameters:
        # Raises TypeError if the parameters span multiple devices.
        device_flag = _check_param_device(p, device_flag)

        count = p.numel()
        # Carve out this parameter's slice, reshape it, and install it as the
        # parameter's new data.
        p.data = vec[offset:offset + count].view_as(p).data
        offset += count
def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
|
||
|
r"""Check if the parameters are located on the same device.
|
||
|
|
||
|
Currently, the conversion between model parameters and single vector form is not supported
|
||
|
for multiple allocations, e.g. parameters in different GPUs/PrivateUse1s, or mixture of CPU/GPU/PrivateUse1.
|
||
|
|
||
|
Args:
|
||
|
param ([Tensor]): a Tensor of a parameter of a model
|
||
|
old_param_device (int): the device where the first parameter of a
|
||
|
model is allocated.
|
||
|
|
||
|
Returns:
|
||
|
old_param_device (int): report device for the first time
|
||
|
"""
|
||
|
# Meet the first parameter
|
||
|
support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
|
||
|
if old_param_device is None:
|
||
|
old_param_device = param.get_device() if param.device.type in support_device_types else -1
|
||
|
else:
|
||
|
warn = False
|
||
|
if param.device.type in support_device_types: # Check if in same GPU/PrivateUse1
|
||
|
warn = (param.get_device() != old_param_device)
|
||
|
else: # Check if in CPU
|
||
|
warn = (old_param_device != -1)
|
||
|
if warn:
|
||
|
raise TypeError('Found two parameters on different devices, '
|
||
|
'this is currently not supported.')
|
||
|
return old_param_device
|