280 lines
8.2 KiB
Python
280 lines
8.2 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import ipaddress
|
|
import random
|
|
import re
|
|
import socket
|
|
import time
|
|
import weakref
|
|
from datetime import timedelta
|
|
from threading import Event, Thread
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|
|
|
__all__ = ['parse_rendezvous_endpoint']
|
|
|
|
def _parse_rendezvous_config(config_str: str) -> Dict[str, str]:
    """Parse a rendezvous configuration string into a dictionary.

    Args:
        config_str:
            A string in format <key1>=<value1>,...,<keyN>=<valueN>.
            Whitespace around the whole string, the keys, and the values is
            ignored.

    Returns:
        A dictionary mapping each key to its value. A blank ``config_str``
        yields an empty dictionary.

    Raises:
        ValueError:
            If an entry has an empty key, or if a key has no value.
    """
    parsed: Dict[str, str] = {}

    stripped = config_str.strip()
    if not stripped:
        return parsed

    for entry in stripped.split(","):
        # Only the first "=" separates key from value, so a value may itself
        # contain "=" characters.
        raw_name, _, raw_setting = entry.partition("=")

        name = raw_name.strip()
        if not name:
            raise ValueError(
                "The rendezvous configuration string must be in format "
                "<key1>=<value1>,...,<keyN>=<valueN>."
            )

        # A missing "=" and an empty right-hand side are both treated as
        # "no value specified".
        setting = raw_setting.strip()
        if not setting:
            raise ValueError(
                f"The rendezvous configuration option '{name}' must have a value specified."
            )

        parsed[name] = setting

    return parsed
|
|
|
|
|
|
def _try_parse_port(port_str: str) -> Optional[int]:
    """Try to extract the port number from ``port_str``.

    Args:
        port_str:
            The candidate port; it must consist solely of one to five ASCII
            digits.

    Returns:
        The port as an integer, or ``None`` if ``port_str`` is empty or is not
        made up of one to five digits. No range check is done here, so values
        up to 99999 can be returned.
    """
    # ``fullmatch`` is used instead of ``match`` with ``^...$`` because ``$``
    # also matches right before a trailing newline, which would let a value
    # such as "80\n" pass as a valid port.
    if port_str and re.fullmatch(r"[0-9]{1,5}", port_str):
        return int(port_str)
    return None
|
|
|
|
|
|
def parse_rendezvous_endpoint(endpoint: Optional[str], default_port: int) -> Tuple[str, int]:
    """Extract the hostname and the port number from a rendezvous endpoint.

    Args:
        endpoint:
            A string in format <hostname>[:<port>]. An IPv6 address is written
            in brackets, e.g. ``[::1]`` or ``[::1]:29400``; the brackets are
            removed from the returned hostname.
        default_port:
            The port number to use if the endpoint does not include one.

    Returns:
        A tuple of hostname and port number. If ``endpoint`` is ``None`` or
        blank, ``("localhost", default_port)`` is returned.

    Raises:
        ValueError:
            If the port is not a decimal integer smaller than 2 ** 16, or if
            the hostname contains anything other than word characters, dots,
            colons, and hyphens.
    """
    if endpoint is not None:
        endpoint = endpoint.strip()

    if not endpoint:
        return ("localhost", default_port)

    # An endpoint that starts and ends with brackets represents an IPv6 address
    # without a port; it must not be split on its last colon.
    if endpoint[0] == "[" and endpoint[-1] == "]":
        host, rest = endpoint, []
    else:
        host, *rest = endpoint.rsplit(":", 1)

    # Sanitize the IPv6 address.
    if len(host) > 1 and host[0] == "[" and host[-1] == "]":
        host = host[1:-1]

    if len(rest) == 1:
        port = _try_parse_port(rest[0])
        if port is None or port >= 2 ** 16:
            # NOTE: the largest port accepted above is 65535; the message text
            # is kept as-is because callers may match on it.
            raise ValueError(
                f"The port number of the rendezvous endpoint '{endpoint}' must be an integer "
                "between 0 and 65536."
            )
    else:
        port = default_port

    # ``fullmatch`` is used instead of ``match`` with ``^...$`` because ``$``
    # also matches right before a trailing newline, which would let a hostname
    # ending in "\n" pass validation.
    if not re.fullmatch(r"[\w\.:-]+", host):
        raise ValueError(
            f"The hostname of the rendezvous endpoint '{endpoint}' must be a dot-separated list of "
            "labels, an IPv4 address, or an IPv6 address."
        )

    return host, port
|
|
|
|
|
|
def _matches_machine_hostname(host: str) -> bool:
    """Indicate whether ``host`` matches the hostname of this machine.

    ``host`` is checked against ``"localhost"``, the loopback addresses, the
    name returned by ``socket.gethostname()``, and the canonical names and IP
    addresses that ``socket.getaddrinfo()`` reports for that name. Note that
    it may return a false negative if this machine has CNAME records beyond
    its FQDN or IP addresses assigned to secondary NICs.

    Args:
        host:
            A hostname or an IP address in textual form.

    Returns:
        ``True`` if one of the checks matches, ``False`` otherwise.
    """
    if host == "localhost":
        return True

    # ``host`` may be a literal IP address; if it is not, ``ip`` stays None.
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        ip = None

    if ip is not None and ip.is_loopback:
        return True

    lookup_opts = {"proto": socket.IPPROTO_TCP, "flags": socket.AI_CANONNAME}

    # Resolve ``host`` itself; a failed lookup simply yields no addresses.
    try:
        resolved = socket.getaddrinfo(host, None, **lookup_opts)
    except (ValueError, socket.gaierror):
        resolved = []

    host_ips = {info[4][0] for info in resolved}

    local_name = socket.gethostname()
    if host == local_name:
        return True

    ip_text = None if ip is None else str(ip)

    # A failure to resolve this machine's own name is not caught here and
    # propagates to the caller.
    for _family, _type, _proto, canonname, sockaddr in socket.getaddrinfo(
        local_name, None, **lookup_opts
    ):
        local_ip = sockaddr[0]

        # If we have an FQDN in the entry, compare it to `host`.
        if canonname and canonname == host:
            return True

        # Otherwise if `host` represents an IP address, compare it to our IP
        # address.
        if ip_text is not None and local_ip == ip_text:
            return True

        # Finally, check whether our IP address is one of the addresses that
        # `host` resolved to.
        if local_ip in host_ips:
            return True

    return False
|
|
|
|
|
|
def _delay(seconds: Union[float, Tuple[float, float]]) -> None:
    """Suspend the current thread for ``seconds``.

    Args:
        seconds:
            Either the delay, in seconds, or a tuple of a lower and an upper
            bound within which a random delay will be picked.
    """
    # A tuple is a (lower, upper) range; draw the actual delay from it.
    duration = random.uniform(*seconds) if isinstance(seconds, tuple) else seconds

    # Delay requests that are less than 10 milliseconds are ignored.
    if duration >= 0.01:
        time.sleep(duration)
|
|
|
|
|
|
class _PeriodicTimer:
    """Represent a timer that periodically runs a specified function.

    The function is run on a daemon background thread, once per elapsed
    interval, until the timer is cancelled or garbage collected.

    Args:
        interval:
            The interval, as a ``timedelta``, between each run.
        function:
            The function to run.
        *args:
            Positional arguments passed to ``function`` on every run.
        **kwargs:
            Keyword arguments passed to ``function`` on every run.
    """

    # The state of the timer is held in a separate context object to avoid a
    # reference cycle between the timer and the background thread.
    class _Context:
        interval: float
        function: Callable[..., None]
        args: Tuple[Any, ...]
        kwargs: Dict[str, Any]
        stop_event: Event

    _name: Optional[str]
    _thread: Optional[Thread]
    _finalizer: Optional[weakref.finalize]

    # The context that is shared between the timer and the background thread.
    _ctx: _Context

    def __init__(
        self,
        interval: timedelta,
        function: Callable[..., None],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        ctx = self._Context()
        ctx.interval = interval.total_seconds()
        ctx.function = function  # type: ignore[assignment]
        ctx.args = args
        ctx.kwargs = kwargs
        ctx.stop_event = Event()

        self._name = None
        self._ctx = ctx
        self._thread = None
        self._finalizer = None

    @property
    def name(self) -> Optional[str]:
        """Get the name of the timer."""
        return self._name

    def set_name(self, name: str) -> None:
        """Set the name of the timer.

        The specified name will be assigned to the background thread and serves
        for debugging and troubleshooting purposes.

        Raises:
            RuntimeError:
                If the timer has already been started.
        """
        if self._thread is not None:
            raise RuntimeError("The timer has already started.")

        self._name = name

    def start(self) -> None:
        """Start the timer.

        Raises:
            RuntimeError:
                If the timer has already been started.
        """
        if self._thread is not None:
            raise RuntimeError("The timer has already started.")

        # Only the context, never the timer itself, is handed to the thread.
        worker = Thread(
            target=self._run,
            name=self._name or "PeriodicTimer",
            args=(self._ctx,),
            daemon=True,
        )
        self._thread = worker

        # We avoid using a regular finalizer (a.k.a. __del__) for stopping the
        # timer as joining a daemon thread during the interpreter shutdown can
        # cause deadlocks. The weakref.finalize is a superior alternative that
        # provides a consistent behavior regardless of the GC implementation.
        finalizer = weakref.finalize(self, self._stop_thread, worker, self._ctx.stop_event)

        # We do not attempt to stop our background thread during the interpreter
        # shutdown. At that point we do not even know whether it still exists.
        finalizer.atexit = False
        self._finalizer = finalizer

        worker.start()

    def cancel(self) -> None:
        """Stop the timer at the next opportunity.

        This is a no-op if the timer has never been started. Otherwise it
        signals the background thread to stop and waits for it to exit.
        """
        finalizer = self._finalizer
        if finalizer is not None:
            # Calling the finalizer runs `_stop_thread` at most once.
            finalizer()

    @staticmethod
    def _run(ctx: "_PeriodicTimer._Context") -> None:
        """Run ``ctx.function`` after every interval until the stop event is set."""
        while True:
            # `wait` returns True as soon as the stop event is set, and False
            # if the interval elapses without that happening.
            if ctx.stop_event.wait(ctx.interval):
                break
            ctx.function(*ctx.args, **ctx.kwargs)

    @staticmethod
    def _stop_thread(thread: Thread, stop_event: Event) -> None:
        """Signal the background thread to stop and wait until it exits."""
        stop_event.set()

        thread.join()
|