145 lines
4.5 KiB
Python
145 lines
4.5 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
# All rights reserved.
|
||
|
#
|
||
|
# This source code is licensed under the BSD-style license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
import datetime
|
||
|
import socket
|
||
|
from contextlib import closing
|
||
|
|
||
|
import torch.distributed as dist
|
||
|
from torch.distributed.elastic.utils.logging import get_logger
|
||
|
|
||
|
|
||
|
# Module-level logger, obtained from the torch.distributed.elastic logging helper.
log = get_logger(__name__)
|
||
|
|
||
|
# Exact RuntimeError messages that this module matches against ``str(e)`` to
# tell a port conflict apart from a socket timeout.
_ADDRESS_IN_USE = "Address already in use"
|
||
|
_SOCKET_TIMEOUT = "Socket Timeout"
|
||
|
|
||
|
# Store keys used by ``_check_full_rank``: a counter incremented by every member,
# and a marker key set by the member whose increment reaches ``world_size``.
_MEMBER_CHECKIN = "_tcp_store/num_members"
|
||
|
_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"
|
||
|
|
||
|
|
||
|
def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    wait_for_workers: bool = True,
    retries: int = 3,
):
    """
    Create a ``dist.TCPStore`` for ``server_addr`` and return it.

    Args:
        is_server: forwarded to ``TCPStore`` as ``is_master``.
        server_addr: forwarded to ``TCPStore`` as ``host_name``.
        server_port: port to use. ``-1`` means "pick a free port with
            ``get_free_port()``", which is only accepted when
            ``world_size`` is 1.
        world_size: forwarded to ``TCPStore``; also the member count that
            ``_check_full_rank`` waits for.
        timeout: store timeout in seconds (converted to a ``timedelta``).
        wait_for_workers: forwarded to ``TCPStore``. When true, the full
            rank check is also run before the store is returned.
        retries: maximum number of attempts when the port is picked
            dynamically and turns out to be in use. Ignored when
            ``server_port`` is given explicitly, since retrying the same
            static port cannot help.

    Returns:
        The created ``dist.TCPStore``.

    Raises:
        ValueError: ``server_port`` is ``-1`` while ``world_size > 1``.
        RuntimeError: the port is in use and no attempts remain; any other
            ``RuntimeError`` from store creation is re-raised unchanged.
    """
    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        log.info("sever_port: %s, specified, ignoring retries", server_port)

    # only retry when server_port is NOT static: a static port gets exactly
    # one attempt, a dynamically chosen port gets up to ``retries`` attempts
    max_attempts = retries if server_port == -1 else 1
    attempt = 1
    while True:
        if server_port != -1:
            port = server_port
        else:
            port = get_free_port()

        log.info(
            "Creating c10d store on %s:%s\n"
            "  world_size  : %s\n"
            "  is_server   : %s\n"
            "  timeout(sec): %s\n",
            server_addr, port, world_size, is_server, timeout
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
            )
            # skips full rank check when we don't have to wait for all workers
            if wait_for_workers:
                _check_full_rank(store, world_size)
            log.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) == _ADDRESS_IN_USE:  # this will only happen on the server
                if attempt < max_attempts:
                    log.warning(
                        "port: %s already in use, attempt: [%s/%s]", port, attempt, max_attempts
                    )
                    attempt += 1
                else:
                    raise RuntimeError(
                        f"on {server_addr}, port: {port} already in use"
                    ) from e
            else:
                raise
|
||
|
|
||
|
|
||
|
def _check_full_rank(store, world_size):
    """
    Check this member in on ``store`` and wait for the last member's marker.

    Each caller increments the ``_MEMBER_CHECKIN`` counter. The caller whose
    increment returns exactly ``world_size`` sets ``_LAST_MEMBER_CHECKIN``.
    Every caller then reads that key (presumably blocking until it is set or
    the store timeout elapses -- confirm against the store implementation).

    Raises:
        TimeoutError: the read failed with a ``RuntimeError`` whose message
            equals ``_SOCKET_TIMEOUT``.
        RuntimeError: any other ``RuntimeError`` from the read, re-raised
            unchanged.
    """
    checked_in = store.add(_MEMBER_CHECKIN, 1)
    is_last_member = checked_in == world_size
    if is_last_member:
        store.set(_LAST_MEMBER_CHECKIN, "<val_ignored>")

    try:
        store.get(_LAST_MEMBER_CHECKIN)
    except RuntimeError as err:
        # anything that is not the known timeout message is passed through as-is
        if str(err) != _SOCKET_TIMEOUT:
            raise
        raise TimeoutError(
            f"timed out waiting for all {world_size} members to join"
        ) from err
|
||
|
|
||
|
|
||
|
def get_free_port():
    """
    Return the port number of a socket obtained from ``get_socket_with_port()``.

    The socket is closed before returning, so the port is only free at the
    moment of the call; nothing reserves it afterwards.
    """
    with closing(get_socket_with_port()) as reserved:
        bound_address = reserved.getsockname()
        return bound_address[1]
|
||
|
|
||
|
|
||
|
def get_socket_with_port() -> socket.socket:
    """
    Returns a listening socket bound to a free port on localhost. The port
    stays "reserved" for as long as the socket is open. Close the socket
    before passing the port to the entity that requires it. Usage example

    ::

     sock = get_socket_with_port()
     with closing(sock):
         port = sock.getsockname()[1]
     # there is still a race-condition that some other process
     # may grab this port before func() runs
     func(port)

    Raises:
        RuntimeError: none of the address candidates that ``getaddrinfo``
            returned for localhost could be created, bound and listened on.
    """

    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    for family, sock_type, proto, _, _ in addrs:
        sock = None
        try:
            # Creating the socket can itself fail (e.g. unsupported address
            # family), so it lives inside the try to let the next candidate
            # be attempted instead of aborting the whole search.
            sock = socket.socket(family, sock_type, proto)
            sock.bind(("localhost", 0))
            sock.listen(0)
            return sock
        except OSError as e:
            if sock is not None:
                sock.close()
            log.info("Socket creation attempt failed.", exc_info=e)
    raise RuntimeError("Failed to create a socket")
|