mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
* init update * all get/set optm_state * add unit tests * add docstring * fix for multiple embeddings * move embedding methods to private * fix lint * fix unit tests * resolve comments * merge master
699 lines
21 KiB
Python
699 lines
21 KiB
Python
import os
|
|
import unittest
|
|
|
|
import backend as F
|
|
import pytest
|
|
import torch as th
|
|
import torch.multiprocessing as mp
|
|
|
|
from dgl.nn import NodeEmbedding
|
|
from dgl.optim import SparseAdagrad, SparseAdam
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam(emb_dim):
    """Compare one DGL ``SparseAdam`` step with ``torch.optim.SparseAdam``.

    Both tables are initialised from the same seed, see the same lookup
    indices and the same cross-entropy loss, and must agree after one step.
    """
    num_embs = 10
    device = F.ctx()

    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test")
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    # Re-seed before each initialisation so both tables get identical values.
    for table in (torch_emb.weight, dgl_emb.weight):
        th.manual_seed(0)
        th.nn.init.uniform_(table, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # first step on a random batch of 4 rows
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    labels = th.zeros((4,)).long()
    print("dgl_value = {}".format(dgl_value))
    print("labels = {}".format(labels))

    optimizers = (dgl_adam, torch_adam)
    for opt in optimizers:
        opt.zero_grad()
    losses = [
        th.nn.functional.cross_entropy(value, labels)
        for value in (dgl_value, torch_value)
    ]
    for loss in losses:
        loss.backward()
    for opt in optimizers:
        opt.step()

    assert F.allclose(dgl_emb.weight, torch_emb.weight)
    # A second step is not compared: PyTorch's SparseAdam maintains a global
    # step, while DGL's SparseAdam uses a per-embedding step.
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("use_uva", [False, True, None])
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam_uva(use_uva, emb_dim):
    """One-step parity of DGL ``SparseAdam(use_uva=...)`` with PyTorch.

    ``use_uva=True`` is skipped when the test context is the CPU.
    """
    if use_uva is True and F.ctx().type == "cpu":
        # we want to only test values of False and None when not using GPU
        pytest.skip("UVA cannot be used without GPUs.")

    n_rows, dev = 10, F.ctx()
    dgl_emb = NodeEmbedding(n_rows, emb_dim, "test_uva{}".format(use_uva))
    torch_emb = th.nn.Embedding(n_rows, emb_dim, sparse=True)

    def seeded_uniform(tensor):
        # Same seed for both tables so their initial weights match.
        th.manual_seed(0)
        th.nn.init.uniform_(tensor, 0, 1.0)

    seeded_uniform(torch_emb.weight)
    seeded_uniform(dgl_emb.weight)

    dgl_opt = SparseAdam(params=[dgl_emb], lr=0.01, use_uva=use_uva)
    torch_opt = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # first step
    batch = th.randint(0, n_rows, size=(4,))
    dgl_out = dgl_emb(batch, dev).to(th.device("cpu"))
    torch_out = torch_emb(batch)
    targets = th.zeros((4,)).long()

    dgl_opt.zero_grad()
    torch_opt.zero_grad()
    th.nn.functional.cross_entropy(dgl_out, targets).backward()
    th.nn.functional.cross_entropy(torch_out, targets).backward()
    dgl_opt.step()
    torch_opt.step()

    assert F.allclose(dgl_emb.weight, torch_emb.weight)
    # A second step is not compared: PyTorch's SparseAdam maintains a global
    # step, while DGL's SparseAdam uses a per-embedding step.
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("dtype", [th.float32, th.float16])
@pytest.mark.parametrize("emb_dim", [1, 4, 101, 1024])
def test_sparse_adam_dtype(dtype, emb_dim):
    """One-step parity of DGL ``SparseAdam(dtype=...)`` with PyTorch.

    The DGL optimizer is constructed with ``dtype=dtype`` and its result is
    compared with ``torch.optim.SparseAdam`` after a single step.
    """
    rows = 10
    ctx_dev = F.ctx()
    emb_name = "test_dtype{}".format(dtype)
    dgl_emb = NodeEmbedding(rows, emb_dim, emb_name)
    torch_emb = th.nn.Embedding(rows, emb_dim, sparse=True)

    # Identical seeds give identical initial tables.
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)

    optimizers = (
        SparseAdam(params=[dgl_emb], lr=0.01, dtype=dtype),
        th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01),
    )

    # first step
    idx = th.randint(0, rows, size=(4,))
    outputs = (dgl_emb(idx, ctx_dev).to(th.device("cpu")), torch_emb(idx))
    labels = th.zeros((4,)).long()

    for opt in optimizers:
        opt.zero_grad()
    losses = [th.nn.functional.cross_entropy(out, labels) for out in outputs]
    for loss in losses:
        loss.backward()
    for opt in optimizers:
        opt.step()

    assert F.allclose(dgl_emb.weight, torch_emb.weight)
    # A second step is not compared: PyTorch's SparseAdam maintains a global
    # step, while DGL's SparseAdam uses a per-embedding step.
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
def test_sparse_adam_zero_step():
    """Step optimizers that also own an embedding receiving no gradient.

    Only ``dgl_emb``/``torch_emb`` are looked up, so ``dgl_emb_zero`` and
    ``torch_emb_zero`` get no gradient. After one step the used embeddings
    must agree, and the unused embeddings (initialised identically on both
    sides) must still agree with each other.
    """
    num_embs = 10
    emb_dim = 4
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, "test")
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    dgl_emb_zero = NodeEmbedding(num_embs, emb_dim, "test2")
    torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    # Both sides draw two tables in the same order from seed 0, so
    # emb/emb_zero start identical across DGL and PyTorch.
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.nn.init.uniform_(torch_emb_zero.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.weight, 0, 1.0)
    th.nn.init.uniform_(dgl_emb_zero.weight, 0, 1.0)

    dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    torch_adam = th.optim.SparseAdam(
        list(torch_emb.parameters()) + list(torch_emb_zero.parameters()),
        lr=0.01,
    )

    # first step
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device("cpu"))
    torch_value = torch_emb(idx)
    labels = th.ones((4,)).long()

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
    assert F.allclose(dgl_emb.weight, torch_emb.weight)
    # The embeddings that received no gradient must not have diverged.
    assert F.allclose(dgl_emb_zero.weight, torch_emb_zero.weight)
|
|
|
|
|
|
def initializer(emb):
    """Fill ``emb`` in place with U(-1, 1) samples drawn from seed 0.

    The global RNG is reseeded first, so every call on a same-shaped tensor
    produces the same values. Returns ``emb`` itself.
    """
    th.manual_seed(0)
    return emb.uniform_(-1.0, 1.0)
|
|
|
|
|
|
def start_sparse_adam_worker(
    rank,
    device,
    world_size,
    weight,
    tensor_dev="cpu",
    has_zero_grad=False,
    backend="gloo",
    num_embs=128,
    emb_dim=10,
    zero_comm=True,
):
    """Run one distributed DGL ``SparseAdam`` step inside a worker process.

    Parameters
    ----------
    rank : int
        Rank of this worker in the process group.
    device : torch.device
        Device used for the embedding lookup output and the loss.
    world_size : int
        Total number of workers in the group.
    weight : torch.Tensor
        Output buffer from the parent; rank 0 copies the full embedding table
        into it after the step.
    tensor_dev : str or torch.device
        Device holding the ``NodeEmbedding`` storage and the lookup indices.
    has_zero_grad : bool
        If True, a second embedding that is never looked up (and so gets no
        gradient) is also registered with the optimizer.
    backend : str
        ``torch.distributed`` backend name.
    num_embs, emb_dim : int
        Shape of the embedding table.
    zero_comm : bool
        If True, each rank samples indices only from its own contiguous
        slice of rows, so ranks touch disjoint rows.
    """
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )

    if device.type == "cuda":
        th.cuda.set_device(device)

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    # Same seed and distribution as start_torch_adam_worker, so both runs
    # start from an identical table.
    init_weight = th.empty((num_embs, emb_dim))
    th.manual_seed(0)
    th.nn.init.uniform_(init_weight, -1.0, 1.0)
    dgl_emb = NodeEmbedding(
        num_embs, emb_dim, "test", init_func=initializer, device=tensor_dev
    )
    dgl_emb.all_set_embedding(init_weight)

    if has_zero_grad:
        # Never looked up below, so this embedding receives no gradient.
        dgl_emb_zero = NodeEmbedding(
            num_embs, emb_dim, "zero", init_func=initializer, device=tensor_dev
        )
        dgl_adam = SparseAdam(params=[dgl_emb, dgl_emb_zero], lr=0.01)
    else:
        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

    th.manual_seed(rank)
    if zero_comm:
        # Restrict this rank to its own contiguous slice of rows.
        start = (num_embs // world_size) * rank
        end = (num_embs // world_size) * (rank + 1)
        idx = th.randint(start, end, size=(4,)).to(tensor_dev)
    else:
        idx = th.randint(0, num_embs, size=(4,)).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4,)).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_adam.zero_grad()
    dgl_loss.backward()
    dgl_adam.step()
    th.distributed.barrier()
    dgl_weight = dgl_emb.all_get_embedding().detach()
    after_step = dgl_emb(idx, device).cpu()

    if rank == 0:
        dgl_value = dgl_value.detach().cpu()
        # The step must have changed the rows that were looked up.
        assert not F.allclose(dgl_value, after_step)
        # Hand the final table back to the parent through the shared buffer.
        weight[:] = dgl_weight[:]
    th.distributed.barrier()
|
|
|
|
|
|
def start_torch_adam_worker(
    rank,
    world_size,
    weight,
    has_zero_grad=False,
    num_embs=128,
    emb_dim=10,
    zero_comm=True,
):
    """Run one DDP + ``torch.optim.SparseAdam`` step as the reference worker.

    Uses the same seeds and index sampling as ``start_sparse_adam_worker`` so
    the initial table and the looked-up rows match the DGL run. Everything
    stays on the CPU and the gloo backend is used.

    Parameters
    ----------
    rank : int
        Rank of this worker in the process group.
    world_size : int
        Total number of workers in the group.
    weight : torch.Tensor
        Output buffer from the parent; rank 0 copies the final table into it.
    has_zero_grad : bool
        If True, also optimise a second embedding that is never looked up.
    num_embs, emb_dim : int
        Shape of the embedding table.
    zero_comm : bool
        If True, each rank samples indices only from its own slice of rows.
    """
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )
    backend = "gloo"

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, -1.0, 1.0)
    torch_emb = th.nn.parallel.DistributedDataParallel(torch_emb)
    if has_zero_grad:
        # Extra embedding that is never looked up. It stays on the CPU like
        # torch_emb; this worker has no device parameter to move it to.
        torch_emb_zero = th.nn.Embedding(num_embs, emb_dim, sparse=True)
        th.manual_seed(0)
        th.nn.init.uniform_(torch_emb_zero.weight, -1.0, 1.0)
        torch_emb_zero = th.nn.parallel.DistributedDataParallel(torch_emb_zero)
        torch_adam = th.optim.SparseAdam(
            list(torch_emb.module.parameters())
            + list(torch_emb_zero.module.parameters()),
            lr=0.01,
        )
    else:
        torch_adam = th.optim.SparseAdam(
            list(torch_emb.module.parameters()), lr=0.01
        )

    th.manual_seed(rank)
    if zero_comm:
        start = (num_embs // world_size) * rank
        end = (num_embs // world_size) * (rank + 1)
        idx = th.randint(start, end, size=(4,))
    else:
        idx = th.randint(0, num_embs, size=(4,))
    labels = th.ones((4,)).long()
    torch_value = torch_emb(idx)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    torch_adam.zero_grad()
    torch_loss.backward()
    torch_adam.step()
    th.distributed.barrier()

    if rank == 0:
        weight[:] = torch_emb.module.weight.cpu()[:]
    th.distributed.barrier()
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type != "cpu", reason="cpu only test")
@pytest.mark.parametrize("num_workers", [2, 4])
def test_multiprocess_cpu_sparse_adam(num_workers):
    """Compare multi-process CPU DGL SparseAdam with DDP + torch SparseAdam.

    Two groups of ``num_workers`` spawned processes each run one step, DGL
    first and then PyTorch. Rank 0 of each group writes its final table into
    a buffer owned by this process, and the two buffers must match.
    """
    backend = "gloo"
    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    def run_workers(target, args_for_rank):
        # Start every rank before joining any: the workers wait for each
        # other in init_process_group and barrier.
        procs = [
            ctx.Process(target=target, args=args_for_rank(rank))
            for rank in range(num_workers)
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # Exceptions raised in a worker do not propagate here. A crashed
        # group would leave its buffer unwritten and make the comparison
        # below meaningless, so check exit codes explicitly.
        exit_codes = [proc.exitcode for proc in procs]
        assert all(code == 0 for code in exit_codes), exit_codes

    dgl_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_sparse_adam_worker,
        lambda rank: (
            rank,
            F.ctx(),
            num_workers,
            dgl_weight,
            th.device("cpu"),
            True,
            backend,
        ),
    )

    torch_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_torch_adam_worker,
        lambda rank: (rank, num_workers, torch_weight, False),
    )

    assert F.allclose(dgl_weight, torch_weight)
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
@pytest.mark.parametrize("zero_comm", [True, False])
def test_multiprocess_sparse_adam(num_workers, backend, zero_comm):
    """Compare multi-GPU DGL SparseAdam with DDP + torch SparseAdam.

    Each DGL worker gets its own GPU for compute while the embedding storage
    stays on the CPU. Rank 0 of each group writes the final table into a
    buffer owned by this process, and the two buffers must match.
    """
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    def run_workers(target, args_for_rank):
        # Start every rank before joining any: the workers wait for each
        # other in init_process_group and barrier.
        procs = [
            ctx.Process(target=target, args=args_for_rank(rank))
            for rank in range(num_workers)
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # Exceptions raised in a worker do not propagate here; fail loudly
        # instead of comparing an unwritten buffer.
        exit_codes = [proc.exitcode for proc in procs]
        assert all(code == 0 for code in exit_codes), exit_codes

    def dgl_args(rank):
        device = F.ctx()
        if device.type == "cuda":
            # make sure each process has a unique GPU
            device = th.device(rank)
        return (
            rank,
            device,
            num_workers,
            dgl_weight,
            th.device("cpu"),
            True,
            backend,
            num_embs,
            emb_dim,
            zero_comm,
        )

    dgl_weight = th.empty((num_embs, emb_dim))
    run_workers(start_sparse_adam_worker, dgl_args)

    torch_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_torch_adam_worker,
        lambda rank: (
            rank,
            num_workers,
            torch_weight,
            False,
            num_embs,
            emb_dim,
            zero_comm,
        ),
    )

    assert F.allclose(dgl_weight, torch_weight)
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    F.ctx().type == "cpu", reason="cuda tensor is not supported for cpu"
)
@pytest.mark.parametrize("num_workers", [2, 4, 8])
def test_multiprocess_sparse_adam_cuda_tensor(num_workers):
    """Compare DGL SparseAdam with GPU-resident embeddings (nccl) to PyTorch.

    Each DGL worker uses GPU ``rank`` both for compute and as the embedding
    storage device. Rank 0 of each group writes the final table into a buffer
    owned by this process, and the two buffers must match.
    """
    if F.ctx().type == "cpu":
        pytest.skip("Do not test CPU")
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    backend = "nccl"
    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    def run_workers(target, args_for_rank):
        # Start every rank before joining any: the workers wait for each
        # other in init_process_group and barrier.
        procs = [
            ctx.Process(target=target, args=args_for_rank(rank))
            for rank in range(num_workers)
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # Exceptions raised in a worker do not propagate here; fail loudly
        # instead of comparing an unwritten buffer.
        exit_codes = [proc.exitcode for proc in procs]
        assert all(code == 0 for code in exit_codes), exit_codes

    def dgl_args(rank):
        device = th.device(rank)
        return (rank, device, num_workers, dgl_weight, device, False, backend)

    dgl_weight = th.empty((num_embs, emb_dim))
    run_workers(start_sparse_adam_worker, dgl_args)

    torch_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_torch_adam_worker,
        lambda rank: (rank, num_workers, torch_weight, False),
    )

    assert F.allclose(dgl_weight, torch_weight)
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type != "cpu", reason="cpu only test")
@pytest.mark.parametrize("num_workers", [2, 4])
def test_multiprocess_sparse_adam_cpu_zero_step(num_workers):
    """CPU multi-process SparseAdam with an extra zero-gradient embedding.

    DGL workers are started with ``has_zero_grad=True``, so their optimizer
    also owns an embedding that is never looked up. The looked-up table must
    still match the PyTorch reference run.
    """
    backend = "gloo"
    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    def run_workers(target, args_for_rank):
        # Start every rank before joining any: the workers wait for each
        # other in init_process_group and barrier.
        procs = [
            ctx.Process(target=target, args=args_for_rank(rank))
            for rank in range(num_workers)
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # Exceptions raised in a worker do not propagate here; fail loudly
        # instead of comparing an unwritten buffer.
        exit_codes = [proc.exitcode for proc in procs]
        assert all(code == 0 for code in exit_codes), exit_codes

    dgl_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_sparse_adam_worker,
        lambda rank: (
            rank,
            F.ctx(),
            num_workers,
            dgl_weight,
            th.device("cpu"),
            True,
            backend,
        ),
    )

    torch_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_torch_adam_worker,
        lambda rank: (rank, num_workers, torch_weight, False),
    )

    assert F.allclose(dgl_weight, torch_weight)
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
def test_multiprocess_sparse_adam_zero_step(num_workers, backend):
    """Multi-GPU SparseAdam with an extra zero-gradient embedding.

    DGL workers (one GPU each, CPU-resident storage) are started with
    ``has_zero_grad=True``. The looked-up table must still match the PyTorch
    reference run.
    """
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    def run_workers(target, args_for_rank):
        # Start every rank before joining any: the workers wait for each
        # other in init_process_group and barrier.
        procs = [
            ctx.Process(target=target, args=args_for_rank(rank))
            for rank in range(num_workers)
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # Exceptions raised in a worker do not propagate here; fail loudly
        # instead of comparing an unwritten buffer.
        exit_codes = [proc.exitcode for proc in procs]
        assert all(code == 0 for code in exit_codes), exit_codes

    def dgl_args(rank):
        device = F.ctx()
        if device.type == "cuda":
            # make sure each process has a unique GPU
            device = th.device(rank)
        return (
            rank,
            device,
            num_workers,
            dgl_weight,
            th.device("cpu"),
            True,
            backend,
        )

    dgl_weight = th.empty((num_embs, emb_dim))
    run_workers(start_sparse_adam_worker, dgl_args)

    torch_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_torch_adam_worker,
        lambda rank: (rank, num_workers, torch_weight, False),
    )

    assert F.allclose(dgl_weight, torch_weight)
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(
    F.ctx().type == "cpu", reason="cuda tensor is not supported for cpu"
)
@pytest.mark.parametrize("num_workers", [2, 4, 8])
def test_multiprocess_sparse_adam_zero_step_cuda_tensor(num_workers):
    """GPU-resident (nccl) SparseAdam with an extra zero-gradient embedding.

    Each DGL worker uses GPU ``rank`` for compute and storage and is started
    with ``has_zero_grad=True``. The looked-up table must still match the
    PyTorch reference run.
    """
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    backend = "nccl"
    num_embs = 128
    emb_dim = 10
    ctx = mp.get_context("spawn")

    def run_workers(target, args_for_rank):
        # Start every rank before joining any: the workers wait for each
        # other in init_process_group and barrier.
        procs = [
            ctx.Process(target=target, args=args_for_rank(rank))
            for rank in range(num_workers)
        ]
        for proc in procs:
            proc.start()
        for proc in procs:
            proc.join()
        # Exceptions raised in a worker do not propagate here; fail loudly
        # instead of comparing an unwritten buffer.
        exit_codes = [proc.exitcode for proc in procs]
        assert all(code == 0 for code in exit_codes), exit_codes

    def dgl_args(rank):
        device = th.device(rank)
        return (rank, device, num_workers, dgl_weight, device, True, backend)

    dgl_weight = th.empty((num_embs, emb_dim))
    run_workers(start_sparse_adam_worker, dgl_args)

    torch_weight = th.empty((num_embs, emb_dim))
    run_workers(
        start_torch_adam_worker,
        lambda rank: (rank, num_workers, torch_weight, False),
    )

    assert F.allclose(dgl_weight, torch_weight)
|
|
|
|
|
|
def start_sparse_adam_state_dict_worker(
    rank,
    world_size,
    init_weight,
    backend,
    num_embs,
    emb_dim,
):
    """Check that a SparseAdam ``state_dict`` round-trip restores the state.

    Each rank takes one step on its own slice of rows. It then snapshots the
    embedding's optimizer state, zeroes the live state, reloads the saved
    ``state_dict``, and asserts that every state tensor equals its snapshot.

    Parameters
    ----------
    rank : int
        Worker rank; also selects GPU ``cuda:{rank}``.
    world_size : int
        Number of workers.
    init_weight : torch.Tensor
        Initial embedding table, set on every rank.
    backend : str
        ``torch.distributed`` backend. With ``"nccl"`` the embedding storage
        is placed on the worker's GPU; otherwise on the CPU.
    num_embs, emb_dim : int
        Shape of the embedding table.
    """
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )

    device = th.device(f"cuda:{rank}")
    th.cuda.set_device(device)
    tensor_dev = device if backend == "nccl" else th.device("cpu")

    th.distributed.init_process_group(
        backend=backend,
        init_method=dist_init_method,
        world_size=world_size,
        rank=rank,
    )

    th.manual_seed(0)
    dgl_emb = NodeEmbedding(
        num_embs, emb_dim, "test", init_func=initializer, device=tensor_dev
    )
    dgl_emb.all_set_embedding(init_weight)

    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

    # One step on this rank's own contiguous slice of rows.
    start = (num_embs // world_size) * rank
    end = (num_embs // world_size) * (rank + 1)
    th.manual_seed(rank)
    idx = th.randint(start, end, size=(4,)).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4,)).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    dgl_adam.zero_grad()
    dgl_loss.backward()
    dgl_adam.step()
    th.distributed.barrier()

    # Snapshot the state, then wipe it, so a load_state_dict that restored
    # nothing would be caught by the comparison below.
    worker_state_dict = [t.detach().clone() for t in dgl_emb.optm_state]
    state_dict = dgl_adam.state_dict()
    for t in dgl_emb.optm_state:
        t.zero_()
    dgl_adam.load_state_dict(state_dict)

    # allclose's result must be asserted; a bare call checks nothing.
    for expected, restored in zip(worker_state_dict, dgl_emb.optm_state):
        assert F.allclose(expected, restored)

    th.distributed.barrier()
|
|
|
|
|
|
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [1, 2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
def test_multiprocess_sparse_adam_state_dict(num_workers, backend):
    """Round-trip SparseAdam optimizer state through ``state_dict`` on GPUs."""
    too_few_gpus = (
        F.ctx().type == "cuda" and th.cuda.device_count() < num_workers
    )
    if too_few_gpus:
        pytest.skip("Not enough GPUs to run test.")

    num_embs, emb_dim = 128, 10
    init_weight = th.rand((num_embs, emb_dim))
    # The worker takes the rank as its first parameter, which mp.spawn
    # supplies; these are the remaining positional arguments.
    worker_args = (num_workers, init_weight, backend, num_embs, emb_dim)
    mp.spawn(
        start_sparse_adam_state_dict_worker,
        worker_args,
        nprocs=num_workers,
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual entry point; under pytest the parametrize decorators drive these.
    test_sparse_adam(1)
    test_sparse_adam(4)
    test_sparse_adam(101)
    test_sparse_adam(1024)
    test_sparse_adam_zero_step()

    test_multiprocess_cpu_sparse_adam(2)
    test_multiprocess_cpu_sparse_adam(4)
    test_multiprocess_cpu_sparse_adam(8)
    test_multiprocess_sparse_adam_cpu_zero_step(2)

    # zero_comm has no default in test_multiprocess_sparse_adam, so it must
    # be passed explicitly (omitting it raises TypeError).
    test_multiprocess_sparse_adam(2, backend="gloo", zero_comm=True)
    test_multiprocess_sparse_adam(4, backend="gloo", zero_comm=True)
    test_multiprocess_sparse_adam(8, backend="gloo", zero_comm=True)
    test_multiprocess_sparse_adam(2, backend="nccl", zero_comm=True)
    test_multiprocess_sparse_adam(4, backend="nccl", zero_comm=True)
    test_multiprocess_sparse_adam(8, backend="nccl", zero_comm=True)

    test_multiprocess_sparse_adam_zero_step(2, backend="gloo")
    test_multiprocess_sparse_adam_zero_step(4, backend="nccl")

    test_multiprocess_sparse_adam_cuda_tensor(2)
    test_multiprocess_sparse_adam_zero_step_cuda_tensor(4)

    test_multiprocess_sparse_adam_state_dict(2, "nccl")
    test_multiprocess_sparse_adam_state_dict(2, "gloo")
|