import abc

import torch.nn as nn


class Sharder(abc.ABC):
    """
    Interface for user-defined, advanced sharding strategies.

    Implement this interface to express sharding strategies that cannot
    easily be composed from a `ShardingSpec`.
    :class:`torch.distributed._shard.sharding_plan.ShardingPlan` can take a
    `Sharder` object and call `shard` to shard a module, then replace the
    original module with the sharded module that is returned.
    """

    @abc.abstractmethod
    def shard(self, module: nn.Module) -> nn.Module:
        """
        Shard a module based on the implementation of this method.

        Subclasses must override this method; the base implementation has no
        behavior of its own.

        Args:
            module (:class:`torch.nn.Module`):
                The module to apply sharding to.

        Returns:
            A :class:`torch.nn.Module` object that represents a module
            that's already been sharded.
        """