38 lines
1.5 KiB
Python
38 lines
1.5 KiB
Python
|
from contextlib import contextmanager
|
||
|
from typing import cast
|
||
|
import logging
|
||
|
from . import api
|
||
|
from . import TensorPipeAgent
|
||
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
@contextmanager
|
||
|
def _group_membership_management(store, name, is_join):
|
||
|
token_key = "RpcGroupManagementToken"
|
||
|
join_or_leave = "join" if is_join else "leave"
|
||
|
my_token = f"Token_for_{name}_{join_or_leave}"
|
||
|
while True:
|
||
|
# Retrieve token from store to signal start of rank join/leave critical section
|
||
|
returned = store.compare_set(token_key, "", my_token).decode()
|
||
|
if returned == my_token:
|
||
|
# Yield to the function this context manager wraps
|
||
|
yield
|
||
|
# Finished, now exit and release token
|
||
|
# Update from store to signal end of rank join/leave critical section
|
||
|
store.set(token_key, "")
|
||
|
# Other will wait for this token to be set before they execute
|
||
|
store.set(my_token, "Done")
|
||
|
break
|
||
|
else:
|
||
|
# Store will wait for the token to be released
|
||
|
try:
|
||
|
store.wait([returned])
|
||
|
except RuntimeError:
|
||
|
logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned)
|
||
|
raise
|
||
|
|
||
|
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group-membership update to the current RPC agent.

    The current agent is treated as a ``TensorPipeAgent``. ``cast`` is only a
    static-typing hint and performs no runtime check. All four arguments are
    passed through unchanged.

    Returns:
        Whatever the agent's ``_update_group_membership`` returns.
    """
    current_agent = api._get_current_rpc_agent()
    tensorpipe_agent = cast(TensorPipeAgent, current_agent)
    return tensorpipe_agent._update_group_membership(
        worker_info, my_devices, reverse_device_map, is_join
    )
|