ai-content-maker/.venv/Lib/site-packages/torch/distributed/elastic/rendezvous/utils.py

280 lines
8.2 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import ipaddress
import random
import re
import socket
import time
import weakref
from datetime import timedelta
from threading import Event, Thread
from typing import Any, Callable, Dict, Optional, Tuple, Union
__all__ = ['parse_rendezvous_endpoint']
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
    """Turn a rendezvous configuration string into a dictionary.

    Args:
        config_str:
            A string in format <key1>=<value1>,...,<keyN>=<valueN>. Whitespace
            around the whole string, the keys, and the values is ignored.

    Returns:
        A dictionary mapping each key to its value. A blank ``config_str``
        yields an empty dictionary; a repeated key keeps its last value.

    Raises:
        ValueError:
            If an entry has an empty key, or if an entry has no value (either
            the ``=`` is missing or nothing but whitespace follows it).
    """
    parsed: Dict[str, str] = {}

    stripped = config_str.strip()
    if not stripped:
        return parsed

    for entry in stripped.split(","):
        # Only the first "=" separates the key from the value, so a value may
        # itself contain "=" characters.
        raw_key, _, raw_value = entry.partition("=")

        key = raw_key.strip()
        if not key:
            raise ValueError(
                "The rendezvous configuration string must be in format "
                "<key1>=<value1>,...,<keyN>=<valueN>."
            )

        # A missing "=" leaves `raw_value` empty, which is rejected the same
        # way as an explicitly empty value.
        value = raw_value.strip()
        if not value:
            raise ValueError(
                f"The rendezvous configuration option '{key}' must have a value specified."
            )

        parsed[key] = value

    return parsed
def _try_parse_port(port_str: str) -> Optional[int]:
    """Convert ``port_str`` to an integer if it looks like a port number.

    Returns:
        The integer value when ``port_str`` consists of one to five ASCII
        digits, otherwise ``None``. No range check is applied here, so values
        up to 99999 are returned as-is.
    """
    if not port_str:
        return None

    if re.match(r"^[0-9]{1,5}$", port_str) is None:
        return None

    return int(port_str)
def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]:
    """Split a rendezvous endpoint into its hostname and port number.

    Args:
        endpoint:
            A string in format <hostname>[:<port>]. Surrounding whitespace is
            ignored. ``None`` or a blank string selects ``localhost`` with
            ``default_port``.
        default_port:
            The port number to use if the endpoint does not include one.

    Returns:
        A tuple of hostname and port number. Brackets around an IPv6 address
        are removed from the returned hostname.

    Raises:
        ValueError:
            If the port part is not a number below 65536, or if the hostname
            contains characters other than word characters, dots, colons, and
            hyphens.
    """
    if endpoint is not None:
        endpoint = endpoint.strip()

    if not endpoint:
        return ("localhost", default_port)

    # An endpoint that is bracketed from end to end is a bare IPv6 address
    # with no port; anything else is split at its last colon.
    fully_bracketed = endpoint.startswith("[") and endpoint.endswith("]")
    if fully_bracketed:
        host, port_fields = endpoint, []
    else:
        host, *port_fields = endpoint.rsplit(":", 1)

    # Drop the brackets that wrap an IPv6 address.
    if len(host) > 1 and host.startswith("[") and host.endswith("]"):
        host = host[1:-1]

    if port_fields:
        port = _try_parse_port(port_fields[0])
        if port is None or port >= 2 ** 16:
            raise ValueError(
                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                "between 0 and 65536."
            )
    else:
        port = default_port

    # The hostname is validated only after the port so that a malformed port
    # is reported first.
    if re.match(r"^[\w\.:-]+$", host) is None:
        raise ValueError(
            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
            "labels, an IPv4 address, or an IPv6 address."
        )

    return host, port
def _matches_machine_hostname(host: str) -> bool:
    """Tell whether ``host`` refers to the machine running this code.

    ``host`` is compared to the hostname as well as to the IP addresses of
    this machine. Note that a false negative is possible if this machine has
    CNAME records beyond its FQDN or IP addresses assigned to secondary NICs.

    Args:
        host:
            A hostname or a textual IPv4/IPv6 address.

    Returns:
        ``True`` if ``host`` is ``localhost``, a loopback address, this
        machine's hostname, or resolves to / equals one of the addresses
        reported for this machine's hostname; ``False`` otherwise.
    """
    if host == "localhost":
        return True

    # `ip` stays None when `host` is a name rather than a literal address.
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        ip = None

    if ip is not None and ip.is_loopback:
        return True

    # Resolve `host`; a name that cannot be resolved simply contributes no
    # candidate addresses.
    try:
        resolved = socket.getaddrinfo(
            host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
        )
    except (ValueError, socket.gaierror):
        resolved = []

    host_ips = [sockaddr[0] for *_, sockaddr in resolved]

    local_name = socket.gethostname()
    if host == local_name:
        return True

    local_infos = socket.getaddrinfo(
        local_name, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
    )

    for _family, _kind, _proto, canonical_name, sockaddr in local_infos:
        local_ip = sockaddr[0]

        # An FQDN reported for this machine that equals `host`.
        if canonical_name and canonical_name == host:
            return True

        # `host` is itself an IP address that belongs to this machine.
        if ip is not None and local_ip == str(ip):
            return True

        # `host` resolves to an IP address that belongs to this machine.
        if local_ip in host_ips:
            return True

    return False
def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
    """Put the current thread to sleep for ``seconds``.

    Args:
        seconds:
            Either the delay, in seconds, or a tuple of a lower and an upper
            bound from which a random delay is drawn uniformly.
    """
    duration = random.uniform(*seconds) if isinstance(seconds, tuple) else seconds

    # Delays shorter than 10 milliseconds are not worth a sleep call and are
    # skipped entirely.
    if duration >= 0.01:
        time.sleep(duration)
class _PeriodicTimer:
    """Represent a timer that periodically runs a specified function.

    The function runs on a daemon background thread. The first run happens one
    full interval after :meth:`start` is called.

    Args:
        interval:
            The interval, as a ``timedelta``, between each run.
        function:
            The function to run.
        *args:
            Positional arguments passed to ``function`` on every run.
        **kwargs:
            Keyword arguments passed to ``function`` on every run.
    """

    # The state of the timer is held in a separate context object to avoid a
    # reference cycle between the timer and the background thread.
    class _Context:
        interval: float
        function: Callable[..., None]
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
        stop_event: Event

    _name: Optional[str]
    _thread: Optional[Thread]
    _finalizer: Optional[weakref.finalize]

    # The context that is shared between the timer and the background thread.
    _ctx: _Context

    def __init__(
        self,
        interval: timedelta,
        function: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._name = None

        self._ctx = self._Context()
        self._ctx.interval = interval.total_seconds()
        self._ctx.function = function  # type: ignore[assignment]
        self._ctx.args = args or ()
        self._ctx.kwargs = kwargs or {}
        self._ctx.stop_event = Event()

        self._thread = None
        self._finalizer = None

    @property
    def name(self) -> Optional[str]:
        """Get the name of the timer."""
        return self._name

    def set_name(self, name: str) -> None:
        """Set the name of the timer.

        The specified name will be assigned to the background thread and serves
        for debugging and troubleshooting purposes.

        Raises:
            RuntimeError: If the timer has already been started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._name = name

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError: If the timer has already been started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._thread = Thread(
            target=self._run, name=self._name or "PeriodicTimer", args=(self._ctx,), daemon=True
        )

        # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
        # timer as joining a daemon thread during the interpreter shutdown can
        # cause deadlocks. The weakref.finalize is a superior alternative that
        # provides a consistent behavior regardless of the GC implementation.
        self._finalizer = weakref.finalize(
            self, self._stop_thread, self._thread, self._ctx.stop_event
        )

        # We do not attempt to stop our background thread during the interpreter
        # shutdown. At that point we do not even know whether it still exists.
        self._finalizer.atexit = False

        self._thread.start()

    def cancel(self) -> None:
        """Stop the timer at the next opportunity.

        Calling this before :meth:`start`, or more than once, is a no-op. When
        called from any thread other than the timer's own background thread,
        the call blocks until the background thread has exited.
        """
        if self._finalizer:
            self._finalizer()

    @staticmethod
    def _run(ctx: _Context) -> None:
        """Run ``ctx.function`` once per interval until the stop event is set."""
        # `Event.wait` returns True as soon as the stop event is set, which
        # ends the loop; a timeout (False) means it is time for the next run.
        while not ctx.stop_event.wait(ctx.interval):
            ctx.function(*ctx.args, **ctx.kwargs)

    @staticmethod
    def _stop_thread(thread: Thread, stop_event: Event) -> None:
        """Signal the background thread to stop and wait for it to exit."""
        from threading import current_thread

        stop_event.set()

        # This may run on the background thread itself: either the periodic
        # function called `cancel()`, or the garbage collector triggered the
        # finalizer there. A thread cannot join itself (`Thread.join` raises
        # RuntimeError), and the loop in `_run` exits on its own once the
        # event is set, so the join is skipped in that case.
        if thread is not current_thread():
            thread.join()