54 lines
2.1 KiB
Python
54 lines
2.1 KiB
Python
|
from .modules import * # noqa: F403
|
||
|
from .parameter import (
|
||
|
Parameter as Parameter,
|
||
|
UninitializedParameter as UninitializedParameter,
|
||
|
UninitializedBuffer as UninitializedBuffer,
|
||
|
)
|
||
|
from .parallel import DataParallel as DataParallel
|
||
|
from . import init
|
||
|
from . import functional
|
||
|
from . import utils
|
||
|
from . import attention
|
||
|
|
||
|
|
||
|
def factory_kwargs(kwargs):
    r"""Return a canonicalized dict of factory kwargs.

    Given kwargs, returns a canonicalized dict of factory kwargs that can be directly passed
    to factory functions like torch.empty, or errors if unrecognized kwargs are present.

    This function makes it simple to write code like this::

        class MyModule(nn.Module):
            def __init__(self, **kwargs):
                factory_kwargs = torch.nn.factory_kwargs(kwargs)
                self.weight = Parameter(torch.empty(10, **factory_kwargs))

    Why should you use this function instead of just passing `kwargs` along directly?

    1. This function does error validation, so if there are unexpected kwargs we will
       immediately report an error, instead of deferring it to the factory call
    2. This function supports a special `factory_kwargs` argument, which can be used to
       explicitly specify a kwarg to be used for factory functions, in the event one of the
       factory kwargs conflicts with an already existing argument in the signature (e.g.
       in the signature ``def f(dtype, **kwargs)``, you can specify ``dtype`` for factory
       functions, as distinct from the dtype argument, by saying
       ``f(dtype1, factory_kwargs={"dtype": dtype2})``)

    Args:
        kwargs: ``None`` or a mapping whose keys are limited to ``device``,
            ``dtype``, ``memory_format`` and ``factory_kwargs``. The value of
            ``factory_kwargs``, if present and not ``None``, must be a mapping.
            Its contents are copied into the result as-is, without key validation.

    Returns:
        A new dict containing the entries of ``factory_kwargs`` plus any of
        ``device``/``dtype``/``memory_format`` given directly in ``kwargs``.
        The input mappings are never mutated.

    Raises:
        TypeError: if ``kwargs`` contains a key outside the accepted set, or if
            one of ``device``/``dtype``/``memory_format`` is given both directly
            and inside ``factory_kwargs``.
    """
    if kwargs is None:
        return {}
    # A tuple (not a set) so iteration order, and thus the result's key order and
    # which duplicate gets reported first, is independent of string hash randomization.
    simple_keys = ("device", "dtype", "memory_format")
    expected_keys = {*simple_keys, "factory_kwargs"}
    if not kwargs.keys() <= expected_keys:
        raise TypeError(f"unexpected kwargs {kwargs.keys() - expected_keys}")

    # Copy so the caller's kwargs (and its nested factory_kwargs mapping) stay untouched.
    # An explicit ``factory_kwargs=None`` is treated the same as omitting it.
    nested = kwargs.get("factory_kwargs")
    r = dict(nested) if nested is not None else {}
    for k in simple_keys:
        if k in kwargs:
            if k in r:
                raise TypeError(f"{k} specified twice, in **kwargs and in factory_kwargs")
            r[k] = kwargs[k]

    return r
|