from torch import nn
from typing import List, Optional

__all__ = ["partition_model"]


def partition_model(
    module: nn.Sequential,
    balance: List[int],
    devices: Optional[List[int]] = None,
) -> nn.Sequential:
    """
    Partitions the model across multiple GPU devices.

    Given an :class:`nn.Sequential` module, partitions the model across
    multiple GPU devices according to the provided ``balance`` and
    ``devices``.

    The layers of ``module`` are reused (not copied): each partition wraps the
    original layer objects in a new :class:`nn.Sequential` and calls ``.to()``
    on it with the partition's device.

    Args:
        module (:class:`nn.Sequential`):
            Sequential model representing the pipe.
        balance (List[int]):
            List indicating the number of layers in each partition. The
            entries must be non-negative and sum to ``len(module)``.
        devices (List[int], optional):
            List indicating the device to use for each partition. Must have
            at least ``len(balance)`` entries; extra entries are ignored.
            Defaults to ``range(len(balance))``

    Returns:
        An :class:`nn.Sequential` with one child :class:`nn.Sequential` per
        entry of ``balance``, in order.

    Raises:
        ValueError: If ``balance`` contains a negative entry, if its sum does
            not equal the number of layers in ``module``, or if ``devices``
            has fewer entries than ``balance``.
    """
    layers = list(module)

    # Validate everything up front so that a bad argument cannot leave the
    # model half-moved across devices or silently drop trailing layers.
    if any(num_layers < 0 for num_layers in balance):
        raise ValueError(
            f"balance must contain non-negative layer counts, got {list(balance)}"
        )
    total_layers = sum(balance)
    if total_layers != len(layers):
        raise ValueError(
            f"sum of balance ({total_layers}) does not match the number of "
            f"layers in module ({len(layers)})"
        )
    if devices is None:
        targets = range(len(balance))
    elif len(devices) < len(balance):
        raise ValueError(
            f"devices has {len(devices)} entries but balance defines "
            f"{len(balance)} partitions"
        )
    else:
        targets = devices

    partitions = []
    start = 0
    # zip stops at the shorter sequence, so surplus devices are ignored.
    for num_layers, device in zip(balance, targets):
        stop = start + num_layers
        partitions.append(nn.Sequential(*layers[start:stop]).to(device))
        start = stop
    return nn.Sequential(*partitions)