#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import socket
from contextlib import closing

import torch.distributed as dist
from torch.distributed.elastic.utils.logging import get_logger


log = get_logger(__name__)

# Error messages matched verbatim against ``str(RuntimeError)`` below.
_ADDRESS_IN_USE = "Address already in use"
_SOCKET_TIMEOUT = "Socket Timeout"

# Store keys used by ``_check_full_rank`` for the member check-in handshake.
_MEMBER_CHECKIN = "_tcp_store/num_members"
_LAST_MEMBER_CHECKIN = "_tcp_store/last_member"


def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    wait_for_workers: bool = True,
    retries=3,
):
    """
    Create a ``dist.TCPStore`` and, optionally, wait for every member to check in.

    Args:
        is_server: passed to ``TCPStore`` as ``is_master``.
        server_addr: host name of the store server.
        server_port: port of the store server. ``-1`` means "pick a free
            port with ``get_free_port()``"; this is only accepted when
            ``world_size`` is 1.
        world_size: number of members expected to join the store.
        timeout: store timeout in seconds.
        wait_for_workers: forwarded to ``TCPStore``; when true, the full
            rank check (``_check_full_rank``) also runs before returning.
        retries: maximum number of creation attempts when the port is
            picked dynamically. Ignored when ``server_port`` is given.

    Returns:
        The created ``dist.TCPStore``.

    Raises:
        ValueError: if ``server_port`` is ``-1`` while ``world_size > 1``.
        RuntimeError: if the port is already in use and no attempts are
            left; any other ``RuntimeError`` from the store is re-raised
            unchanged.
        TimeoutError: raised by ``_check_full_rank`` when the members do
            not all join in time.
    """
    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        log.info("sever_port: %s, specified, ignoring retries", server_port)

    # only retry when server_port is NOT static: a static port starts with
    # the attempt budget already exhausted, so the first port conflict is
    # fatal, while a dynamic port starts at attempt 1 and may be retried
    # (each retry picks a fresh free port).
    attempt = retries if server_port != -1 else 1
    while True:
        if server_port != -1:
            port = server_port
        else:
            port = get_free_port()

        log.info(
            "Creating c10d store on %s:%s\n"
            " world_size : %s\n"
            " is_server : %s\n"
            " timeout(sec): %s\n",
            server_addr,
            port,
            world_size,
            is_server,
            timeout,
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
            )
            # skips full rank check when we don't have to wait for all workers
            if wait_for_workers:
                _check_full_rank(store, world_size)
            log.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) == _ADDRESS_IN_USE:  # this will only happen on the server
                if attempt < retries:
                    log.warning(
                        "port: %s already in use, attempt: [%s/%s]",
                        port,
                        attempt,
                        retries,
                    )
                    attempt += 1
                else:
                    raise RuntimeError(
                        f"on {server_addr}, port: {port} already in use"
                    ) from e
            else:
                raise


def _check_full_rank(store, world_size):
    """
    Block until all ``world_size`` members have checked in on ``store``.

    Every caller increments the ``_MEMBER_CHECKIN`` counter; the caller
    that brings it to ``world_size`` sets ``_LAST_MEMBER_CHECKIN``. All
    callers then read that key (presumably ``store.get`` waits for the key
    up to the store timeout -- confirm against the TCPStore docs).

    Raises:
        TimeoutError: if the store reports a socket timeout while waiting.
        RuntimeError: any other store error is re-raised unchanged.
    """
    idx = store.add(_MEMBER_CHECKIN, 1)
    if idx == world_size:
        store.set(_LAST_MEMBER_CHECKIN, "")

    try:
        store.get(_LAST_MEMBER_CHECKIN)
    except RuntimeError as e:
        if str(e) == _SOCKET_TIMEOUT:
            raise TimeoutError(
                f"timed out waiting for all {world_size} members to join"
            ) from e
        else:
            raise


def get_free_port() -> int:
    """
    Return a port number that was free on localhost at the time of the call.

    The temporary socket that reserved the port is closed before returning,
    so another process may still grab the port before the caller uses it.
    """
    sock = get_socket_with_port()
    with closing(sock):
        return sock.getsockname()[1]


def get_socket_with_port() -> socket.socket:
    """
    Returns a listening socket bound to a free port on localhost.

    The port stays "reserved" for as long as the socket is open. Close the
    socket before passing the port to the entity that requires it.

    Usage example

    ::

        sock = get_socket_with_port()
        with closing(sock):
            port = sock.getsockname()[1]
        # there is still a race-condition that some other process
        # may grab this port before func() runs
        func(port)

    Raises:
        RuntimeError: if no address returned for localhost could be bound.
    """
    addrs = socket.getaddrinfo(
        host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
    )
    # Try each candidate address in turn and return the first one that
    # binds and listens successfully.
    for family, sock_type, proto, _, _ in addrs:
        s = socket.socket(family, sock_type, proto)
        try:
            s.bind(("localhost", 0))
            s.listen(0)
            return s
        except OSError as e:
            s.close()
            log.info("Socket creation attempt failed.", exc_info=e)
    raise RuntimeError("Failed to create a socket")