mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[Tools] use torchrun instead of torch.distributed.launch (#6304)
This commit is contained in:
@@ -21,7 +21,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
|
||||
master_port=1234,
|
||||
)
|
||||
expected = (
|
||||
"python3.7 -m torch.distributed.launch "
|
||||
"python3.7 -m torch.distributed.run "
|
||||
"--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
|
||||
"--master_port=1234 path/to/some/trainer.py arg1 arg2"
|
||||
)
|
||||
@@ -41,7 +41,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
|
||||
master_port=1234,
|
||||
)
|
||||
expected = (
|
||||
"cd path/to && python3.7 -m torch.distributed.launch "
|
||||
"cd path/to && python3.7 -m torch.distributed.run "
|
||||
"--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
|
||||
"--master_port=1234 path/to/some/trainer.py arg1 arg2"
|
||||
)
|
||||
@@ -68,7 +68,7 @@ class TestWrapUdfInTorchDistLauncher(unittest.TestCase):
|
||||
master_port=1234,
|
||||
)
|
||||
expected = (
|
||||
"{python_bin} -m torch.distributed.launch ".format(
|
||||
"{python_bin} -m torch.distributed.run ".format(
|
||||
python_bin=py_bin
|
||||
)
|
||||
+ "--nproc_per_node=2 --nnodes=2 --node_rank=1 --master_addr=127.0.0.1 "
|
||||
@@ -221,7 +221,7 @@ def test_submit_jobs():
|
||||
assert "DGL_ROLE=client" in cmd
|
||||
assert "DGL_GROUP_ID=0" in cmd
|
||||
assert (
|
||||
f"python3 -m torch.distributed.launch --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
|
||||
f"python3 -m torch.distributed.run --nproc_per_node={args.num_trainers} --nnodes={num_machines}"
|
||||
in cmd
|
||||
)
|
||||
assert "--master_addr=127.0.0" in cmd
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Optional
|
||||
|
||||
def cleanup_proc(get_all_remote_pids, conn):
|
||||
"""This process tries to clean up the remote training tasks."""
|
||||
print("cleanupu process runs")
|
||||
print("cleanup process runs")
|
||||
# This process should not handle SIGINT.
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
@@ -228,7 +228,7 @@ def construct_torch_dist_launcher_cmd(
|
||||
cmd_str.
|
||||
"""
|
||||
torch_cmd_template = (
|
||||
"-m torch.distributed.launch "
|
||||
"-m torch.distributed.run "
|
||||
"--nproc_per_node={nproc_per_node} "
|
||||
"--nnodes={nnodes} "
|
||||
"--node_rank={node_rank} "
|
||||
@@ -252,10 +252,10 @@ def wrap_udf_in_torch_dist_launcher(
|
||||
master_addr: str,
|
||||
master_port: int,
|
||||
) -> str:
|
||||
"""Wraps the user-defined function (udf_command) with the torch.distributed.launch module.
|
||||
"""Wraps the user-defined function (udf_command) with the torch.distributed.run module.
|
||||
|
||||
Example: if udf_command is "python3 run/some/trainer.py arg1 arg2", then new_udf_command becomes:
|
||||
"python3 -m torch.distributed.launch <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
|
||||
"python3 -m torch.distributed.run <TORCH DIST ARGS> run/some/trainer.py arg1 arg2
|
||||
|
||||
udf_command is assumed to consist of pre-commands (optional) followed by the python launcher script (required):
|
||||
Examples:
|
||||
@@ -310,7 +310,7 @@ def wrap_udf_in_torch_dist_launcher(
|
||||
# transforms the udf_command from:
|
||||
# python path/to/dist_trainer.py arg0 arg1
|
||||
# to:
|
||||
# python -m torch.distributed.launch [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
|
||||
# python -m torch.distributed.run [DIST TORCH ARGS] path/to/dist_trainer.py arg0 arg1
|
||||
# Note: if there are multiple python commands in `udf_command`, this may do the Wrong Thing, e.g. launch each
|
||||
# python command within the torch distributed launcher.
|
||||
new_udf_command = udf_command.replace(
|
||||
@@ -593,7 +593,7 @@ def submit_jobs(args, udf_command, dry_run=False):
|
||||
master_port = get_available_port(master_addr)
|
||||
for node_id, host in enumerate(hosts):
|
||||
ip, _ = host
|
||||
# Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.launch ... UDF`
|
||||
# Transform udf_command to follow torch's dist launcher format: `PYTHON_BIN -m torch.distributed.run ... UDF`
|
||||
torch_dist_udf_command = wrap_udf_in_torch_dist_launcher(
|
||||
udf_command=udf_command,
|
||||
num_trainers=args.num_trainers,
|
||||
|
||||
Reference in New Issue
Block a user