ai-content-maker/.venv/Lib/site-packages/trainer/distribute.py

74 lines
2.2 KiB
Python
Raw Normal View History

2024-05-03 04:18:51 +03:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import pathlib
import subprocess
import sys
import time

from trainer import TrainerArgs, logger
def distribute():
    """
    Call 👟Trainer training script in DDP mode.

    Parses the ``TrainerArgs`` fields plus ``--script`` and ``--gpus`` from the
    command line, starts one subprocess of the target script per GPU ID
    returned by ``get_gpus`` and blocks until every subprocess has exited.
    Command-line arguments the parser does not recognise are forwarded to the
    target script unchanged.
    """
    parser = TrainerArgs().init_argparse(arg_prefix="")
    parser.add_argument("--script", type=str, help="Target training script to distibute.")
    parser.add_argument(
        "--gpus",
        type=str,
        help='GPU IDs to be used for distributed training in the format ```"0,1"```. Used if ```CUDA_VISIBLE_DEVICES``` is not set.',
    )
    args, unargs = parser.parse_known_args()

    # Without a script there is nothing to launch; fail with a usage message
    # instead of an opaque TypeError from os.path.join(..., None) below.
    if not args.script:
        parser.error("--script is required: pass the training script to distribute.")

    gpus = get_gpus(args)

    # One shared group id (launch timestamp) is handed to every worker.
    group_id = time.strftime("%Y_%m_%d-%H%M%S")

    # set arguments for train.py
    # Prefer a script located next to this module; otherwise use the path as given.
    folder_path = pathlib.Path(__file__).parent.absolute()
    bundled_script = os.path.join(folder_path, args.script)
    if os.path.exists(bundled_script):
        command = [bundled_script]
    else:
        command = [args.script]

    # Pass all the TrainerArgs fields
    command.append(f"--continue_path={args.continue_path}")
    command.append(f"--restore_path={args.restore_path}")
    command.append(f"--group_id=group_{group_id}")
    command.append("--use_ddp=true")
    command += unargs
    # Placeholder for the per-worker ``--rank=N`` argument, overwritten in the loop.
    command.append("")

    # Launch workers with the interpreter running this launcher so they share
    # its environment (e.g. an active virtualenv). ``sys.executable`` may be
    # empty/None when the path cannot be determined, hence the fallback.
    interpreter = sys.executable or "python3"
    # Every worker sees the full device list; its rank tells it which one to use.
    visible_devices = ",".join(gpus)

    # run processes
    processes = []
    for rank, local_gpu_id in enumerate(gpus):
        my_env = os.environ.copy()
        my_env["PYTHON_EGG_CACHE"] = f"/tmp/tmp{local_gpu_id}"
        my_env["RANK"] = f"{rank}"
        my_env["CUDA_VISIBLE_DEVICES"] = visible_devices
        command[-1] = f"--rank={rank}"

        # stdout=None applies no redirection: every rank writes to this
        # process's stdout.
        p = subprocess.Popen(  # pylint: disable=consider-using-with
            [interpreter] + command, stdout=None, env=my_env
        )
        processes.append(p)
        logger.info(command)

    # NOTE(review): worker exit codes are not inspected; a failed worker does
    # not change this function's outcome.
    for p in processes:
        p.wait()
def get_gpus(args):
    """Return the GPU IDs to launch workers on.

    A non-empty ``CUDA_VISIBLE_DEVICES`` environment variable takes
    precedence; otherwise ``args.gpus`` is used.

    Args:
        args: Parsed arguments exposing a ``gpus`` attribute: a
            comma-separated string of GPU IDs (e.g. ``"0,1"``) or ``None``.

    Returns:
        list: GPU IDs as strings, whitespace-stripped, with empty entries
        (e.g. from a trailing comma) removed.

    Raises:
        ValueError: If neither source yields at least one GPU ID.
    """
    # set active gpus from CUDA_VISIBLE_DEVICES or --gpus
    visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    spec = visible if visible != "" else args.gpus
    if spec is None:
        raise ValueError(
            "No GPUs specified: set CUDA_VISIBLE_DEVICES or pass --gpus (e.g. --gpus 0,1)."
        )

    # Drop empty entries so a stray comma does not become a phantom GPU ID.
    gpus = [gpu_id for gpu_id in map(str.strip, spec.split(",")) if gpu_id]
    if not gpus:
        raise ValueError(
            f"No GPU IDs found in {spec!r}: set CUDA_VISIBLE_DEVICES or pass --gpus (e.g. --gpus 0,1)."
        )
    return gpus
# Run the launcher only when this file is executed as a script, not on import.
if __name__ == "__main__":
    distribute()