Files
dgl/tests/python/pytorch/nn/test_sparse_emb.py
Xin Yao ae4a5b7394 [Feature] Add state_dict, load_state_dict, param_groups to dgl.optim.SparseGradOptimizer (#5311)
* init update

* all get/set optm_state

* add unit tests

* add docstring

* fix for multiple embeddings

* move embedding methods to private

* fix lint

* fix unit tests

* resolve comments

* merge master
2023-03-08 16:46:34 +08:00

134 lines
3.7 KiB
Python

import multiprocessing as mp
import os
import unittest
import backend as F
import pytest
import torch as th
from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdam
def initializer(emb):
    """Fill ``emb`` in place with seeded uniform values and return it.

    The global torch RNG is reseeded with 0 on every call, so two calls on
    tensors of the same shape produce identical contents.

    Parameters
    ----------
    emb
        Object exposing an in-place ``uniform_`` method (a torch tensor in
        the tests of this module).

    Returns
    -------
    The same ``emb`` object that was passed in.
    """
    lower, upper = -1.0, 1.0
    # Fixed seed keeps the generated values reproducible.
    th.manual_seed(0)
    emb.uniform_(lower, upper)
    return emb
def check_all_set_all_get_emb(device, init_emb):
    """Round-trip ``init_emb`` through a ``NodeEmbedding`` and compare.

    A ``NodeEmbedding`` named ``"test"`` is created with the same number of
    rows and columns as ``init_emb``, the values are written with
    ``all_set_embedding`` and read back with ``all_get_embedding``.

    Parameters
    ----------
    device
        Device forwarded to the ``NodeEmbedding`` constructor.
    init_emb
        2-D tensor (indexed via ``shape[0]`` / ``shape[1]``) holding the
        values to store.

    Raises
    ------
    AssertionError
        If the values read back are not close to ``init_emb``.
    """
    rows, cols = init_emb.shape[0], init_emb.shape[1]
    embedding = NodeEmbedding(rows, cols, "test", device=device)

    # Write everything, then read everything back.
    embedding.all_set_embedding(init_emb)
    fetched = embedding.all_get_embedding()

    assert F.allclose(init_emb, fetched)
def check_all_set_all_get_optm_state(
    device, state_step, state_mem, state_power
):
    """Round-trip optimizer state through a ``NodeEmbedding`` and compare.

    A ``NodeEmbedding`` named ``"test"`` is created with the shape of
    ``state_mem`` and registered with a ``SparseAdam`` optimizer. The state
    triple is then stored with ``_all_set_optm_state`` and fetched again
    with ``_all_get_optm_state``.

    Parameters
    ----------
    device
        Device forwarded to the ``NodeEmbedding`` constructor.
    state_step, state_mem, state_power
        The three state tensors to store; ``state_mem`` must be 2-D since
        its ``shape[0]`` / ``shape[1]`` size the embedding.

    Raises
    ------
    AssertionError
        If any fetched tensor is not close to the one that was stored.
    """
    rows, cols = state_mem.shape[0], state_mem.shape[1]
    embedding = NodeEmbedding(rows, cols, "test", device=device)

    # The optimizer object is not used directly afterwards.
    # NOTE(review): presumably constructing it is what attaches optimizer
    # state to the embedding -- confirm against dgl.optim.SparseAdam.
    _optimizer = SparseAdam(params=[embedding], lr=0.01)

    stored = (state_step, state_mem, state_power)
    embedding._all_set_optm_state(stored)
    got_step, got_mem, got_power = embedding._all_get_optm_state()

    fetched = (got_step, got_mem, got_power)
    for expected, actual in zip(stored, fetched):
        assert F.allclose(expected, actual)
def start_sparse_worker(rank, world_size, test, args):
    """Entry point of one worker process: join the group and run ``test``.

    The worker initialises a ``gloo`` process group that rendezvouses at
    ``tcp://127.0.0.1:12345``, calls ``test(device, *args)``, waits on a
    barrier and finally destroys the process group.

    Parameters
    ----------
    rank
        Rank of this worker; also used as the CUDA device index when the
        backend context is a CUDA device.
    world_size
        Total number of workers in the process group.
    test
        Callable invoked as ``test(device, *args)``.
    args
        Iterable of extra positional arguments for ``test``.
    """
    print("start sparse worker {}".format(rank))

    # Choose the device: each rank gets its own GPU when running on CUDA,
    # otherwise the backend's default context is used unchanged.
    device = F.ctx()
    if device.type == "cuda":
        device = th.device(rank)
        th.cuda.set_device(device)

    rendezvous = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )
    th.distributed.init_process_group(
        backend="gloo",
        init_method=rendezvous,
        world_size=world_size,
        rank=rank,
    )

    test(device, *args)

    # Wait for every worker to finish before tearing the group down.
    th.distributed.barrier()
    th.distributed.destroy_process_group()
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set(num_workers):
    """Run ``check_all_set_all_get_emb`` in ``num_workers`` spawned processes.

    Every worker receives the same random ``1000 x 8`` tensor. The test is
    skipped when the backend context is CUDA and fewer GPUs than workers
    are available, and it passes only if every worker exits with code 0.
    """
    running_on_gpu = F.ctx().type == "cuda"
    if running_on_gpu and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    init_emb = th.rand([1000, 8])
    spawn_ctx = mp.get_context("spawn")

    processes = []
    for rank in range(num_workers):
        worker_args = (
            rank,
            num_workers,
            check_all_set_all_get_emb,
            (init_emb,),
        )
        proc = spawn_ctx.Process(target=start_sparse_worker, args=worker_args)
        proc.start()
        processes.append(proc)

    # Join every worker first, then inspect the exit codes.
    for proc in processes:
        proc.join()
    for proc in processes:
        assert proc.exitcode == 0
@unittest.skipIf(os.name == "nt", reason="Do not support windows yet")
@pytest.mark.parametrize("num_workers", [1, 2, 3])
def test_multiprocess_sparse_emb_get_set_optm_state(num_workers):
    """Run ``check_all_set_all_get_optm_state`` in spawned worker processes.

    Every worker receives the same randomly generated state triple: an
    integer step vector of length 1000 and two ``1000 x 8`` float tensors.
    The test is skipped when the backend context is CUDA and fewer GPUs
    than workers are available, and it passes only if every worker exits
    with code 0.
    """
    running_on_gpu = F.ctx().type == "cuda"
    if running_on_gpu and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")

    num_embs, emb_dim = 1000, 8
    state_shape = (num_embs, emb_dim)
    # Generation order (step, mem, power) is kept so the RNG stream is
    # consumed exactly as before.
    state_step = th.randint(1000, (num_embs,))
    state_mem = th.rand(state_shape)
    state_power = th.rand(state_shape)
    optm_state = (state_step, state_mem, state_power)

    spawn_ctx = mp.get_context("spawn")
    processes = []
    for rank in range(num_workers):
        worker_args = (
            rank,
            num_workers,
            check_all_set_all_get_optm_state,
            optm_state,
        )
        proc = spawn_ctx.Process(target=start_sparse_worker, args=worker_args)
        proc.start()
        processes.append(proc)

    # Join every worker first, then inspect the exit codes.
    for proc in processes:
        proc.join()
    for proc in processes:
        assert proc.exitcode == 0
if __name__ == "__main__":
    # Manual entry point for running this file as a script. Only the
    # single-worker optimizer-state test is executed here; the other
    # combinations (test_multiprocess_sparse_emb_get_set with 1, 2 or 3
    # workers, and the optimizer-state test with 2 or 3 workers) can be
    # called the same way when debugging.
    test_multiprocess_sparse_emb_get_set_optm_state(1)