# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import ipaddress
import random
import re
import socket
import time
import weakref
from datetime import timedelta
from threading import Event, Thread, current_thread
from typing import Any, Callable, Dict, Optional, Tuple, Union

__all__ = ['parse_rendezvous_endpoint']


def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
    """Extract key-value pairs from a rendezvous configuration string.

    Args:
        config_str:
            A string in format <key1>=<value1>,...,<keyN>=<valueN>.

    Returns:
        A dictionary mapping each key to its value. Whitespace around keys and
        values is stripped; if a key occurs more than once, the last occurrence
        wins. An empty or whitespace-only string yields an empty dictionary.

    Raises:
        ValueError: If an entry has an empty key, or a key has no value.
    """
    config: Dict[str, str] = {}

    config_str = config_str.strip()
    if not config_str:
        return config

    key_values = config_str.split(",")
    for kv in key_values:
        # Split on the first "=" only, so that a value may itself contain "=".
        key, *values = kv.split("=", 1)

        key = key.strip()
        if not key:
            raise ValueError(
                "The rendezvous configuration string must be in format "
                "<key1>=<value1>,...,<keyN>=<valueN>."
            )

        value: Optional[str]
        if values:
            value = values[0].strip()
        else:
            value = None
        if not value:
            raise ValueError(
                f"The rendezvous configuration option '{key}' must have a value specified."
            )

        config[key] = value
    return config


def _try_parse_port(port_str: str) -> Optional[int]:
    """Try to extract the port number from ``port_str``.

    Returns:
        The port as an ``int`` if ``port_str`` consists solely of one to five
        ASCII digits, otherwise ``None``. Enforcing the upper bound of the port
        range is left to the caller.
    """
    # ``re.fullmatch`` is used instead of ``^...$`` because ``$`` also matches
    # right before a trailing newline.
    if port_str and re.fullmatch(r"[0-9]{1,5}", port_str):
        return int(port_str)
    return None


def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]:
    """Extract the hostname and the port number from a rendezvous endpoint.

    Args:
        endpoint:
            A string in format <hostname>[:<port>]. An IPv6 address must be
            enclosed in brackets, e.g. ``[::1]:29400``.
        default_port:
            The port number to use if the endpoint does not include one.

    Returns:
        A tuple of hostname and port number. A ``None`` or blank endpoint
        yields ``("localhost", default_port)``; brackets around an IPv6
        address are removed from the returned hostname.

    Raises:
        ValueError: If the port is not an integer in the range [0, 65535], or
            the hostname is empty or contains characters other than word
            characters, dots, colons, and hyphens.
    """
    if endpoint is not None:
        endpoint = endpoint.strip()

    if not endpoint:
        return ("localhost", default_port)

    # An endpoint that starts and ends with brackets represents an IPv6 address.
    if endpoint[0] == "[" and endpoint[-1] == "]":
        host, *rest = (endpoint,)
    else:
        host, *rest = endpoint.rsplit(":", 1)

    # Sanitize the IPv6 address.
    if len(host) > 1 and host[0] == "[" and host[-1] == "]":
        host = host[1:-1]

    if len(rest) == 1:
        port = _try_parse_port(rest[0])
        if port is None or port >= 2 ** 16:
            raise ValueError(
                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                "between 0 and 65536."
            )
    else:
        port = default_port

    # ``re.fullmatch`` rejects a hostname with a trailing newline (e.g. from the
    # endpoint "host\n:80"), which the anchored pattern ``^...$`` let through.
    if not re.fullmatch(r"[\w\.:-]+", host):
        raise ValueError(
            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
            "labels, an IPv4 address, or an IPv6 address."
        )

    return host, port


def _matches_machine_hostname(host: str) -> bool:
    """Indicate whether ``host`` matches the hostname of this machine.

    This function compares ``host`` to the hostname as well as to the IP
    addresses of this machine. Note that it may return a false negative if this
    machine has CNAME records beyond its FQDN or IP addresses assigned to
    secondary NICs.

    Raises:
        socket.gaierror: Propagated from ``socket.getaddrinfo`` if this
            machine's own hostname cannot be resolved. A failure to resolve
            ``host`` is not an error; it just leaves fewer addresses to compare.
    """
    if host == "localhost":
        return True

    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        addr = None

    if addr is not None and addr.is_loopback:
        return True

    # Compare against the local hostname before resolving ``host`` so that the
    # common case does not pay for a potentially slow DNS lookup.
    this_host = socket.gethostname()
    if host == this_host:
        return True

    try:
        host_addr_list = socket.getaddrinfo(
            host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
        )
    except (ValueError, socket.gaierror):
        host_addr_list = []

    # Element 4 of a getaddrinfo entry is the sockaddr; its first item is the IP.
    host_ip_set = {host_addr_info[4][0] for host_addr_info in host_addr_list}

    addr_list = socket.getaddrinfo(
        this_host, None, proto=socket.IPPROTO_TCP, flags=socket.AI_CANONNAME
    )
    for addr_info in addr_list:
        # If we have an FQDN in the addr_info, compare it to `host`.
        if addr_info[3] and addr_info[3] == host:
            return True

        # Otherwise if `host` represents an IP address, compare it to our IP
        # address.
        if addr is not None and addr_info[4][0] == str(addr):
            return True

        # If the IP address matches one of the provided host's IP addresses
        if addr_info[4][0] in host_ip_set:
            return True

    return False


def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
    """Suspend the current thread for ``seconds``.

    Args:
        seconds:
            Either the delay, in seconds, or a tuple of a lower and an upper
            bound within which a random delay will be picked.
    """
    if isinstance(seconds, tuple):
        seconds = random.uniform(*seconds)
    # Ignore delay requests that are less than 10 milliseconds.
    if seconds >= 0.01:
        time.sleep(seconds)


class _PeriodicTimer:
    """Represent a timer that periodically runs a specified function.

    Args:
        interval:
            The interval between each run.
        function:
            The function to run.
        *args, **kwargs:
            Positional and keyword arguments passed to ``function`` on each run.
    """

    # The state of the timer is held in a separate context object to avoid a
    # reference cycle between the timer and the background thread.
    class _Context:
        interval: float
        function: Callable[..., None]
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
        stop_event: Event

    _name: Optional[str]
    _thread: Optional[Thread]
    _finalizer: Optional[weakref.finalize]

    # The context that is shared between the timer and the background thread.
    _ctx: _Context

    def __init__(
        self,
        interval: timedelta,
        function: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        self._name = None

        self._ctx = self._Context()
        self._ctx.interval = interval.total_seconds()
        self._ctx.function = function  # type: ignore[assignment]
        self._ctx.args = args or ()
        self._ctx.kwargs = kwargs or {}
        self._ctx.stop_event = Event()

        self._thread = None
        self._finalizer = None

    @property
    def name(self) -> Optional[str]:
        """Get the name of the timer."""
        return self._name

    def set_name(self, name: str) -> None:
        """Set the name of the timer.

        The specified name will be assigned to the background thread and serves
        for debugging and troubleshooting purposes.

        Raises:
            RuntimeError: If the timer has already been started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._name = name

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError: If the timer has already been started.
        """
        if self._thread:
            raise RuntimeError("The timer has already started.")

        self._thread = Thread(
            target=self._run, name=self._name or "PeriodicTimer", args=(self._ctx,), daemon=True
        )

        # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
        # timer as joining a daemon thread during the interpreter shutdown can
        # cause deadlocks. The weakref.finalize is a superior alternative that
        # provides a consistent behavior regardless of the GC implementation.
        self._finalizer = weakref.finalize(
            self, self._stop_thread, self._thread, self._ctx.stop_event
        )

        # We do not attempt to stop our background thread during the interpreter
        # shutdown. At that point we do not even know whether it still exists.
        self._finalizer.atexit = False

        self._thread.start()

    def cancel(self) -> None:
        """Stop the timer at the next opportunity.

        Signals the background thread to stop and, unless called from that
        thread itself (e.g. from within ``function``), waits for it to exit.
        Calling this before ``start`` or more than once has no effect.
        """
        if self._finalizer:
            self._finalizer()

    @staticmethod
    def _run(ctx) -> None:
        # ``Event.wait`` returns False on timeout and True once the event is
        # set, so ``function`` runs once per interval until the timer is
        # stopped. An exception raised by ``function`` ends this loop.
        while not ctx.stop_event.wait(ctx.interval):
            ctx.function(*ctx.args, **ctx.kwargs)

    @staticmethod
    def _stop_thread(thread, stop_event):
        stop_event.set()
        # A thread cannot join itself (``Thread.join`` raises RuntimeError).
        # That happens when ``cancel`` is called from within the timer function,
        # or when the timer is garbage collected on its own background thread.
        # Setting the event is enough there: the run loop exits after the
        # current call returns.
        if thread is not current_thread():
            thread.join()