38 lines
1.5 KiB
Python
38 lines
1.5 KiB
Python
from contextlib import contextmanager
|
|
from typing import cast
|
|
import logging
|
|
from . import api
|
|
from . import TensorPipeAgent
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
@contextmanager
def _group_membership_management(store, name, is_join):
    """Serialize RPC group join/leave operations through a shared store token.

    Acquires a store-wide token by compare-and-setting the key
    ``"RpcGroupManagementToken"`` from ``""`` to this caller's token, runs the
    wrapped body, then releases the token and sets this caller's token key to
    ``"Done"`` so any waiters wake up and retry.

    Args:
        store: Object providing ``compare_set``, ``set`` and ``wait``
            (presumably a ``torch.distributed`` store -- confirm against callers).
        name: Worker name embedded in this caller's token.
        is_join: ``True`` for a join, ``False`` for a leave; also embedded in
            the token.

    Raises:
        RuntimeError: Propagated (after logging) from ``store.wait`` when
            waiting for another caller's token fails.
        Any exception raised by the wrapped body is propagated unchanged,
        after the token has been released.
    """
    token_key = "RpcGroupManagementToken"
    join_or_leave = "join" if is_join else "leave"
    my_token = f"Token_for_{name}_{join_or_leave}"
    while True:
        # Try to claim the token while it is free (""). The returned value is
        # treated as the current holder: our own token means we acquired the
        # critical section; anything else is the other holder's token.
        returned = store.compare_set(token_key, "", my_token).decode()
        if returned == my_token:
            try:
                # Yield to the function this context manager wraps.
                yield
            finally:
                # Release the token even if the wrapped body raised; otherwise
                # every other caller stays blocked on our token until its
                # store wait times out.
                store.set(token_key, "")
                # Other callers wait for this key to be set before retrying.
                store.set(my_token, "Done")
            break
        # Token is held by another caller: block until that caller's token key
        # is set (it finished), then loop and retry the claim.
        try:
            store.wait([returned])
        except RuntimeError:
            logger.error(
                "Group membership token %s timed out waiting for %s to be released.",
                my_token,
                returned,
            )
            raise
|
|
|
|
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group-membership update to the current RPC agent.

    The agent returned by ``api._get_current_rpc_agent()`` is cast to
    ``TensorPipeAgent`` (a static-typing cast only; no runtime check is
    performed). Its ``_update_group_membership`` method is then called with
    the arguments passed through positionally, in the same order.

    Returns:
        Whatever the agent's ``_update_group_membership`` returns.
    """
    current_agent = api._get_current_rpc_agent()
    tensorpipe_agent = cast(TensorPipeAgent, current_agent)
    return tensorpipe_agent._update_group_membership(
        worker_info, my_devices, reverse_device_map, is_join
    )
|