diff --git a/README.md b/README.md index 9913d829b0..75945bb911 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ A summary of part of the model accuracy and training speed with the Pytorch back | [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x | | [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a | | [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a | +| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.3(BLEU) | 14.31(BLEU) | [1970s (PyTorch)](https://github.com/rikdz/GraphWriter) | 1192s | 1.65x | With the MXNet/Gluon backend , we scaled a graph of 50M nodes and 150M edges on a P3.8xlarge instance, with 160s per epoch, on SSE ([Stochastic Steady-state Embedding](https://www.cc.gatech.edu/~hdai8/pdf/equilibrium_embedding.pdf)), diff --git a/examples/pytorch/README.md b/examples/pytorch/README.md index 8691691873..ce851c6d5f 100644 --- a/examples/pytorch/README.md +++ b/examples/pytorch/README.md @@ -21,3 +21,4 @@ Here is a summary of the model accuracy and training speed. Our testbed is Amazo | [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x | | [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a | | [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a | +| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.3(BLEU) | 14.31(BLEU) | [1970s (PyTorch)](https://github.com/rikdz/GraphWriter) | 1192s | 1.65x | diff --git a/examples/pytorch/graphwriter/README.md b/examples/pytorch/graphwriter/README.md new file mode 100644 index 0000000000..98e0beb525 --- /dev/null +++ b/examples/pytorch/graphwriter/README.md @@ -0,0 +1,44 @@ +# GraphWriter-DGL +In this example we implement GraphWriter, the model from [Text Generation from Knowledge Graphs with Graph Transformers](https://arxiv.org/abs/1904.02342), in DGL. 
See also the [author's code](https://github.com/rikdz/GraphWriter). + +## Dependencies +- PyTorch >= 1.2 +- tqdm +- pycocoevalcap (only for testing) +- multi-bleu.perl and other scripts from mosesdecoder (only for testing) + +## Usage +``` + # download data + sh prepare_data.sh + # training + sh run.sh + # testing + sh test.sh +``` + +## Results on AGENDA +| |BLEU|METEOR| training time per epoch| +|-|-|-|-| +|Author's implementation|14.3+-1.01| 18.8+-0.28| 1970s| +|DGL implementation|14.31+-0.34|19.74+-0.69| 1192s| + +We use the author's code for the speed test; our testbed is a V100 GPU. + +| |BLEU| detok BLEU| METEOR | +|-|-|-|-| +|greedy, two layers| 13.97 +- 0.40| 13.78 +- 0.46| 18.76 +- 0.36| +|beam 4, length penalty 1.0, two layers| 14.66 +- 0.65| 14.53 +- 0.52| 19.50 +- 0.49| +|beam 4, length penalty 0.0, two layers| 14.33 +- 0.39| 14.09 +- 0.39| 18.63 +- 0.52| +|greedy, six layers| 14.17 +- 0.46| 14.01 +- 0.51| 19.18 +- 0.49| +|beam 4, length penalty 1.0, six layers| 14.31 +- 0.34| 14.35 +- 0.36| 19.74 +- 0.69| +|beam 4, length penalty 0.0, six layers| 14.40 +- 0.85| 14.15 +- 0.84| 18.86 +- 0.78| + +We repeat each experiment five times. + +### Examples + +We also provide the output of our implementation on the test set, together with the reference text. 
+- [GraphWriter's output](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_pred.txt) +- [Reference text](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_gold.txt) + diff --git a/examples/pytorch/graphwriter/graphwriter.py b/examples/pytorch/graphwriter/graphwriter.py new file mode 100644 index 0000000000..f87a374522 --- /dev/null +++ b/examples/pytorch/graphwriter/graphwriter.py @@ -0,0 +1,182 @@ +import torch +from modules import MSA, BiLSTM, GraphTrans +from utlis import * +from torch import nn +import dgl + + +class GraphWriter(nn.Module): + def __init__(self, args): + super(GraphWriter, self).__init__() + self.args = args + if args.title: + self.title_emb = nn.Embedding(len(args.title_vocab), args.nhid, padding_idx=0) + self.title_enc = BiLSTM(args, enc_type='title') + self.title_attn = MSA(args) + self.ent_emb = nn.Embedding(len(args.ent_text_vocab), args.nhid, padding_idx=0) + self.tar_emb = nn.Embedding(len(args.text_vocab), args.nhid, padding_idx=0) + if args.title: + nn.init.xavier_normal_(self.title_emb.weight) + nn.init.xavier_normal_(self.ent_emb.weight) + self.rel_emb = nn.Embedding(len(args.rel_vocab), args.nhid, padding_idx=0) + nn.init.xavier_normal_(self.rel_emb.weight) + self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid) + self.ent_enc = BiLSTM(args, enc_type='entity') + self.graph_enc = GraphTrans(args) + self.ent_attn = MSA(args) + self.copy_attn = MSA(args, mode='copy') + self.copy_fc = nn.Linear(args.dec_ninp, 1) + self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab)) + + def enc_forward(self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask): + title_enc = None + if self.args.title: + title_enc = self.title_enc(self.title_emb(batch['title']), title_mask) + ent_enc = self.ent_enc(self.ent_emb(batch['ent_text']), ent_text_mask, ent_len = batch['ent_len']) + rel_emb = self.rel_emb(batch['rel']) + g_ent, g_root = self.graph_enc(ent_enc, ent_mask, ent_len, rel_emb, rel_mask, 
batch['graph']) + return g_ent, g_root, title_enc, ent_enc + + def forward(self, batch, beam_size=-1): + ent_mask = len2mask(batch['ent_len'], self.args.device) + ent_text_mask = batch['ent_text']==0 + rel_mask = batch['rel']==0 # 0 means the + title_mask = batch['title']==0 + g_ent, g_root, title_enc, ent_enc = self.enc_forward(batch, ent_mask, ent_text_mask, batch['ent_len'], rel_mask, title_mask) + + _h, _c = g_root, g_root.clone().detach() + ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask) + if self.args.title: + attn = _h + self.title_attn(_h, title_enc, mask=title_mask) + ctx = torch.cat([ctx, attn], 1) + if beam_size<1: + # training + outs = [] + tar_inp = self.tar_emb(batch['text'].transpose(0,1)) + for t, xt in enumerate(tar_inp): + _xt = torch.cat([ctx, xt], 1) + _h, _c = self.decode_lstm(_xt, (_h, _c)) + ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask) + if self.args.title: + attn = _h + self.title_attn(_h, title_enc, mask=title_mask) + ctx = torch.cat([ctx, attn], 1) + outs.append(torch.cat([_h, ctx], 1)) + outs = torch.stack(outs, 1) + copy_gate = torch.sigmoid(self.copy_fc(outs)) + EPSI = 1e-6 + # copy + pred_v = torch.log(copy_gate+EPSI) + torch.log_softmax(self.pred_v_fc(outs), -1) + pred_c = torch.log((1. 
- copy_gate)+EPSI) + torch.log_softmax(self.copy_attn(outs, ent_enc, mask=ent_mask), -1) + pred = torch.cat([pred_v, pred_c], -1) + return pred + else: + if beam_size==1: + # greedy + device = g_ent.device + B = g_ent.shape[0] + ent_type = batch['ent_type'].view(B, -1) + seq = (torch.ones(B,).long().to(device) * self.args.text_vocab('')).unsqueeze(1) + for t in range(self.args.beam_max_len): + _inp = replace_ent(seq[:,-1], ent_type, len(self.args.text_vocab)) + xt = self.tar_emb(_inp) + _xt = torch.cat([ctx, xt], 1) + _h, _c = self.decode_lstm(_xt, (_h, _c)) + ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask) + if self.args.title: + attn = _h + self.title_attn(_h, title_enc, mask=title_mask) + ctx = torch.cat([ctx, attn], 1) + _y = torch.cat([_h, ctx], 1) + copy_gate = torch.sigmoid(self.copy_fc(_y)) + pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1) + pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1) + pred = torch.cat([pred_v, pred_c], -1).view(B,-1) + for ban_item in ['', '', '']: + pred[:, self.args.text_vocab(ban_item)] = -1e8 + _, word = pred.max(-1) + seq = torch.cat([seq, word.unsqueeze(1)], 1) + return seq + else: + # beam search + device = g_ent.device + B = g_ent.shape[0] + BSZ = B * beam_size + _h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1) + _c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1) + ent_mask = ent_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1) + if self.args.title: + title_mask = title_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1) + title_enc = title_enc.view(B, 1, title_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, title_enc.size(1), -1) + ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1) + ent_type = batch['ent_type'].view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1) + g_ent = g_ent.view(B, 1, g_ent.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, g_ent.size(1), -1) 
+ ent_enc = ent_enc.view(B, 1, ent_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, ent_enc.size(1), -1) + + beam_best = torch.zeros(B).to(device) - 1e9 + beam_best_seq = [None] * B + beam_seq = (torch.ones(B, beam_size).long().to(device) * self.args.text_vocab('')).unsqueeze(-1) + beam_score = torch.zeros(B, beam_size).to(device) + done_flag = torch.zeros(B, beam_size) + for t in range(self.args.beam_max_len): + _inp = replace_ent(beam_seq[:,:,-1].view(-1), ent_type, len(self.args.text_vocab)) + xt = self.tar_emb(_inp) + _xt = torch.cat([ctx, xt], 1) + _h, _c = self.decode_lstm(_xt, (_h, _c)) + ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask) + if self.args.title: + attn = _h + self.title_attn(_h, title_enc, mask=title_mask) + ctx = torch.cat([ctx, attn], 1) + _y = torch.cat([_h, ctx], 1) + copy_gate = torch.sigmoid(self.copy_fc(_y)) + pred_v = torch.log(copy_gate) + torch.log_softmax(self.pred_v_fc(_y), -1) + pred_c = torch.log((1. - copy_gate)) + torch.log_softmax(self.copy_attn(_y.unsqueeze(1), ent_enc, mask=ent_mask).squeeze(1), -1) + pred = torch.cat([pred_v, pred_c], -1).view(B, beam_size, -1) + for ban_item in ['', '', '']: + pred[:, :, self.args.text_vocab(ban_item)] = -1e8 + if t==self.args.beam_max_len-1: # force ending + tt = pred[:, :, self.args.text_vocab('')] + pred = pred*0-1e8 + pred[:, :, self.args.text_vocab('')] = tt + cum_score = beam_score.view(B,beam_size,1) + pred + score, word = cum_score.topk(dim=-1, k=beam_size) # B, beam_size, beam_size + score, word = score.view(B,-1), word.view(B,-1) + eos_idx = self.args.text_vocab('') + if beam_seq.size(2)==1: + new_idx = torch.arange(beam_size).to(word) + new_idx = new_idx[None,:].repeat(B,1) + else: + _, new_idx = score.topk(dim=-1, k=beam_size) + new_src, new_score, new_word, new_done = [], [], [], [] + LP = beam_seq.size(2) ** self.args.lp + for i in range(B): + for j in range(beam_size): + tmp_score = score[i][new_idx[i][j]] + tmp_word = word[i][new_idx[i][j]] + src_idx = 
new_idx[i][j]//beam_size + new_src.append(src_idx) + if tmp_word == eos_idx: + new_score.append(-1e8) + else: + new_score.append(tmp_score) + new_word.append(tmp_word) + + if tmp_word == eos_idx and done_flag[i][src_idx]==0 and tmp_score/LP>beam_best[i]: + beam_best[i] = tmp_score/LP + beam_best_seq[i] = beam_seq[i][src_idx] + if tmp_word == eos_idx: + new_done.append(1) + else: + new_done.append(done_flag[i][src_idx]) + new_score = torch.Tensor(new_score).view(B,beam_size).to(beam_score) + new_word = torch.Tensor(new_word).view(B,beam_size).to(beam_seq) + new_src = torch.LongTensor(new_src).view(B,beam_size).to(device) + new_done = torch.Tensor(new_done).view(B,beam_size).to(done_flag) + beam_score = new_score + done_flag = new_done + beam_seq = beam_seq.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src] + beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2) + _h = _h.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src].view(BSZ,-1) + _c = _c.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src].view(BSZ,-1) + ctx = ctx.view(B,beam_size,-1)[torch.arange(B)[:,None].to(device), new_src].view(BSZ,-1) + + return beam_best_seq + diff --git a/examples/pytorch/graphwriter/modules.py b/examples/pytorch/graphwriter/modules.py new file mode 100755 index 0000000000..36efc3fa9f --- /dev/null +++ b/examples/pytorch/graphwriter/modules.py @@ -0,0 +1,164 @@ +import torch +import math +import dgl.function as fn +from dgl.nn.pytorch import edge_softmax +from utlis import * +from torch import nn +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence + + +class MSA(nn.Module): + # multi-head self-attention, three modes + # the first is the copy, determining which entity should be copied. + # the second is the normal attention with two sequence inputs + # the third is the attention but with one token and a sequence. 
(gather, attentive pooling) + + def __init__(self, args, mode='normal'): + super(MSA, self).__init__() + if mode=='copy': + nhead, head_dim = 1, args.nhid + qninp, kninp = args.dec_ninp, args.nhid + if mode=='normal': + nhead, head_dim = args.nhead, args.head_dim + qninp, kninp = args.nhid, args.nhid + self.attn_drop = nn.Dropout(0.1) + self.WQ = nn.Linear(qninp, nhead*head_dim, bias=True if mode=='copy' else False) + if mode!='copy': + self.WK = nn.Linear(kninp, nhead*head_dim, bias=False) + self.WV = nn.Linear(kninp, nhead*head_dim, bias=False) + self.args, self.nhead, self.head_dim, self.mode = args, nhead, head_dim, mode + + def forward(self, inp1, inp2, mask=None): + B, L2, H = inp2.shape + NH, HD = self.nhead, self.head_dim + if self.mode=='copy': + q, k, v = self.WQ(inp1), inp2, inp2 + else: + q, k, v = self.WQ(inp1), self.WK(inp2), self.WV(inp2) + L1 = 1 if inp1.ndim==2 else inp1.shape[1] + if self.mode!='copy': + q = q / math.sqrt(H) + q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3) + k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1) + v = v.view(B, L2, NH, HD).permute(0, 2, 1, 3) + pre_attn = torch.matmul(q,k) + if mask is not None: + pre_attn = pre_attn.masked_fill(mask[:,None,None,:], -1e8) + if self.mode=='copy': + return pre_attn.squeeze(1) + else: + alpha = self.attn_drop(torch.softmax(pre_attn, -1)) + attn = torch.matmul(alpha, v).permute(0, 2, 1, 3).contiguous().view(B,L1,NH*HD) + ret = attn + if inp1.ndim==2: + return ret.squeeze(1) + else: + return ret + + +class BiLSTM(nn.Module): + # for entity encoding or the title encoding + def __init__(self, args, enc_type='title'): + super(BiLSTM, self).__init__() + self.enc_type = enc_type + self.drop = nn.Dropout(args.emb_drop) + self.bilstm = nn.LSTM(args.nhid, args.nhid//2, bidirectional=True, \ + num_layers=args.enc_lstm_layers, batch_first=True) + + def forward(self, inp, mask, ent_len=None): + inp = self.drop(inp) + lens = (mask==0).sum(-1).long().tolist() + pad_seq = pack_padded_sequence(inp, lens, 
batch_first=True, enforce_sorted=False) + y, (_h, _c) = self.bilstm(pad_seq) + if self.enc_type=='title': + y = pad_packed_sequence(y, batch_first=True)[0] + return y + if self.enc_type=='entity': + _h = _h.transpose(0,1).contiguous() + _h = _h[:,-2:].view(_h.size(0), -1) # two directions of the top-layer + ret = pad(_h.split(ent_len), out_type='tensor') + return ret + + +class GAT(nn.Module): + # a graph attention network with dot-product attention + def __init__(self, + in_feats, + out_feats, + num_heads, + ffn_drop=0., + attn_drop=0., + trans=True): + super(GAT, self).__init__() + self._num_heads = num_heads + self._in_feats = in_feats + self._out_feats = out_feats + self.q_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False) + self.k_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False) + self.v_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False) + self.attn_drop = nn.Dropout(0.1) + self.ln1 = nn.LayerNorm(in_feats) + self.ln2 = nn.LayerNorm(in_feats) + if trans: + self.FFN = nn.Sequential( + nn.Linear(in_feats, 4*in_feats), + nn.PReLU(4*in_feats), + nn.Linear(4*in_feats, in_feats), + nn.Dropout(0.1), + ) + # a strange FFN, see the author's code + self._trans = trans + + def forward(self, graph, feat): + graph = graph.local_var() + feat_c = feat.clone().detach().requires_grad_(False) + q, k, v = self.q_proj(feat), self.k_proj(feat_c), self.v_proj(feat_c) + q = q.view(-1, self._num_heads, self._out_feats) + k = k.view(-1, self._num_heads, self._out_feats) + v = v.view(-1, self._num_heads, self._out_feats) + graph.ndata.update({'ft': v, 'el': k, 'er': q}) # k,q instead of q,k, the edge_softmax is applied on incoming edges + # compute edge attention + graph.apply_edges(fn.u_dot_v('el', 'er', 'e')) + e = graph.edata.pop('e') / math.sqrt(self._out_feats * self._num_heads) + graph.edata['a'] = edge_softmax(graph, e).unsqueeze(-1) + # message passing + graph.update_all(fn.u_mul_e('ft', 'a', 'm'), + fn.sum('m', 'ft2')) + rst = graph.ndata['ft2'] + 
# residual + rst = rst.view(feat.shape) + feat + if self._trans: + rst = self.ln1(rst) + rst = self.ln1(rst+self.FFN(rst)) + # use the same layer norm, see the author's code + return rst + + +class GraphTrans(nn.Module): + def __init__(self,args): + super().__init__() + self.args = args + if args.graph_enc == "gat": + # we only support gtrans, don't use this one + self.gat = nn.ModuleList([GAT(args.nhid, args.nhid//4, 4, attn_drop=args.attn_drop, trans=False) for _ in range(args.prop)]) #untested + else: + self.gat = nn.ModuleList([GAT(args.nhid, args.nhid//4, 4, attn_drop=args.attn_drop, ffn_drop=args.drop, trans=True) for _ in range(args.prop)]) + self.prop = args.prop + + def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs): + device = ent.device + ent_mask = (ent_mask==0) # reverse mask + rel_mask = (rel_mask==0) + init_h = [] + for i in range(graphs.batch_size): + init_h.append(ent[i][ent_mask[i]]) + init_h.append(rel[i][rel_mask[i]]) + init_h = torch.cat(init_h, 0) + feats = init_h + for i in range(self.prop): + feats = self.gat[i](graphs, feats) + g_root = feats.index_select(0, graphs.filter_nodes(lambda x: x.data['type']==NODE_TYPE['root']).to(device)) + g_ent = pad(feats.index_select(0, graphs.filter_nodes(lambda x: x.data['type']==NODE_TYPE['entity']).to(device)).split(ent_len), out_type='tensor') + return g_ent, g_root + diff --git a/examples/pytorch/graphwriter/opts.py b/examples/pytorch/graphwriter/opts.py new file mode 100644 index 0000000000..5309e0b306 --- /dev/null +++ b/examples/pytorch/graphwriter/opts.py @@ -0,0 +1,55 @@ +import torch +import argparse + + +def fill_config(args): + # dirty work + args.device = torch.device(args.gpu) + args.dec_ninp = args.nhid * 3 if args.title else args.nhid * 2 + args.fnames = [args.train_file, args.valid_file, args.test_file] + return args + + +def vocab_config(args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab): + # dirty work + args.ent_vocab = ent_vocab + args.rel_vocab = 
rel_vocab + args.text_vocab = text_vocab + args.ent_text_vocab = ent_text_vocab + args.title_vocab = title_vocab + return args + + +def get_args(): + args = argparse.ArgumentParser(description='Graph Writer in DGL') + args.add_argument('--nhid', default=500, type=int, help='hidden size') + args.add_argument('--nhead', default=4, type=int, help='number of heads') + args.add_argument('--head_dim', default=125, type=int, help='head dim') + args.add_argument('--weight_decay', default=0.0, type=float, help='weight decay') + args.add_argument('--prop', default=6, type=int, help='number of layers of gnn') + args.add_argument('--title', action='store_true', help='use title input') + args.add_argument('--test', action='store_true', help='inference mode') + args.add_argument('--batch_size', default=32, type=int, help='batch_size') + args.add_argument('--beam_size', default=4, type=int, help='beam size, 1 for greedy') + args.add_argument('--epoch', default=20, type=int, help='training epoch') + args.add_argument('--beam_max_len', default=200, type=int, help='max length of the generated text') + args.add_argument('--enc_lstm_layers', default=2, type=int, help='number of layers of lstm') + args.add_argument('--lr', default=1e-1, type=float, help='learning rate') + #args.add_argument('--lr_decay', default=1e-8, type=float, help='') + args.add_argument('--clip', default=1, type=float, help='gradient clip') + args.add_argument('--emb_drop', default=0.0, type=float, help='embedding dropout') + args.add_argument('--attn_drop', default=0.1, type=float, help='attention dropout') + args.add_argument('--drop', default=0.1, type=float, help='dropout') + args.add_argument('--lp', default=1.0, type=float, help='length penalty') + args.add_argument('--graph_enc', default='gtrans', type=str, help='gnn mode, we only support the graph transformer now') + args.add_argument('--train_file', default='data/unprocessed.train.json', type=str, help='training file') + args.add_argument('--valid_file', 
default='data/unprocessed.val.json', type=str, help='validation file') + args.add_argument('--test_file', default='data/unprocessed.test.json', type=str, help='test file') + args.add_argument('--save_dataset', default='data.pickle', type=str, help='save path of dataset') + args.add_argument('--save_model', default='saved_model.pt', type=str, help='save path of model') + + args.add_argument('--gpu', default=0, type=int, help='gpu mode') + args = args.parse_args() + args = fill_config(args) + return args + diff --git a/examples/pytorch/graphwriter/prepare_data.sh b/examples/pytorch/graphwriter/prepare_data.sh new file mode 100644 index 0000000000..3e03aa059e --- /dev/null +++ b/examples/pytorch/graphwriter/prepare_data.sh @@ -0,0 +1,3 @@ +wget https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/AGENDA.tar.gz +mkdir data +tar -C data/ -xvzf AGENDA.tar.gz diff --git a/examples/pytorch/graphwriter/run.sh b/examples/pytorch/graphwriter/run.sh new file mode 100644 index 0000000000..ad54b1caf6 --- /dev/null +++ b/examples/pytorch/graphwriter/run.sh @@ -0,0 +1,6 @@ +nohup env CUDA_VISIBLE_DEVICES=0 python -u train.py --prop 6 --save_model tmp_model.pt --title > train_1.log 2>&1 & +#nohup env CUDA_VISIBLE_DEVICES=2 python -u train.py --prop 6 --save_model tmp_model1.pt --title > train_2.log 2>&1 & +#nohup env CUDA_VISIBLE_DEVICES=3 python -u train.py --prop 6 --save_model tmp_model2.pt --title > train_3.log 2>&1 & +#nohup env CUDA_VISIBLE_DEVICES=4 python -u train.py --prop 6 --save_model tmp_model3.pt --title > train_4.log 2>&1 & +#nohup env CUDA_VISIBLE_DEVICES=5 python -u train.py --prop 2 --save_model tmp_model4.pt --title > train_5.log 2>&1 & +#nohup env CUDA_VISIBLE_DEVICES=6 python -u train.py --prop 2 --save_model tmp_model5.pt --title > train_6.log 2>&1 & diff --git a/examples/pytorch/graphwriter/test.sh b/examples/pytorch/graphwriter/test.sh new file mode 100644 index 0000000000..3b316ab3e0 --- /dev/null +++ b/examples/pytorch/graphwriter/test.sh @@ -0,0 +1,11 @@ 
+env CUDA_VISIBLE_DEVICES=0 python -u train.py --save_model tmp_model.ptbest --test --title --lp 1.0 --beam_size 1 +# Fetch each Moses evaluation script used below unless it is already present (-f tests for an existing file). +for script in tokenizer/detokenizer.perl generic/multi-bleu.perl generic/multi-bleu-detok.perl; do + if [ ! -f "${script##*/}" ]; then + wget "https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/$script" + fi +done +perl detokenizer.perl -l en < tmp_gold.txt > tmp_gold.txt.a +perl detokenizer.perl -l en < tmp_pred.txt > tmp_pred.txt.a +perl multi-bleu.perl tmp_gold.txt < tmp_pred.txt +perl multi-bleu-detok.perl tmp_gold.txt.a < tmp_pred.txt.a diff --git a/examples/pytorch/graphwriter/train.py b/examples/pytorch/graphwriter/train.py new file mode 100644 index 0000000000..2d2291db51 --- /dev/null +++ b/examples/pytorch/graphwriter/train.py @@ -0,0 +1,125 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import numpy as np +import time +from tqdm import tqdm +from graphwriter import * +from utlis import * +from opts import * +import os +import sys + +sys.path.append('./pycocoevalcap') +from pycocoevalcap.bleu.bleu import Bleu +from pycocoevalcap.rouge.rouge import Rouge +from pycocoevalcap.meteor.meteor import Meteor + + +def train_one_epoch(model, dataloader, optimizer, args, epoch): + model.train() + tloss = 0. + tcnt = 0. 
+ st_time = time.time() + with tqdm(dataloader, desc='Train Ep '+str(epoch), mininterval=60) as tq: + for batch in tq: + pred = model(batch) + nll_loss = F.nll_loss(pred.view(-1, pred.shape[-1]), batch['tgt_text'].view(-1), ignore_index=0) + loss = nll_loss + optimizer.zero_grad() + loss.backward() + nn.utils.clip_grad_norm_(model.parameters(), args.clip) + optimizer.step() + loss = loss.item() + if loss!=loss: + raise ValueError('NaN appear') + tloss += loss * len(batch['tgt_text']) + tcnt += len(batch['tgt_text']) + tq.set_postfix({'loss': tloss/tcnt}, refresh=False) + print('Train Ep ', str(epoch), 'AVG Loss ', tloss/tcnt, 'Steps ', tcnt, 'Time ', time.time()-st_time, 'GPU', torch.cuda.max_memory_cached()/1024.0/1024.0/1024.0) + torch.save(model, args.save_model+str(epoch%100)) + + +val_loss = 2**31 +def eval_it(model, dataloader, args, epoch): + global val_loss + model.eval() + tloss = 0. + tcnt = 0. + st_time = time.time() + with tqdm(dataloader, desc='Eval Ep '+str(epoch), mininterval=60) as tq: + for batch in tq: + with torch.no_grad(): + pred = model(batch) + nll_loss = F.nll_loss(pred.view(-1, pred.shape[-1]), batch['tgt_text'].view(-1), ignore_index=0) + loss = nll_loss + loss = loss.item() + tloss += loss * len(batch['tgt_text']) + tcnt += len(batch['tgt_text']) + tq.set_postfix({'loss': tloss/tcnt}, refresh=False) + print('Eval Ep ', str(epoch), 'AVG Loss ', tloss/tcnt, 'Steps ', tcnt, 'Time ', time.time()-st_time) + if tloss/tcnt < val_loss: + print('Saving best model ', 'Ep ', epoch, ' loss ', tloss/tcnt) + torch.save(model, args.save_model+'best') + val_loss = tloss/tcnt + + +def test(model, dataloader, args): + scorer = Bleu(4) + m_scorer = Meteor() + r_scorer = Rouge() + hyp = [] + ref = [] + model.eval() + gold_file = open('tmp_gold.txt', 'w') + pred_file = open('tmp_pred.txt', 'w') + with tqdm(dataloader, desc='Test ', mininterval=1) as tq: + for batch in tq: + with torch.no_grad(): + seq = model(batch, beam_size=args.beam_size) + r = 
write_txt(batch, batch['tgt_text'], gold_file, args) + h = write_txt(batch, seq, pred_file, args) + hyp.extend(h) + ref.extend(r) + hyp = dict(zip(range(len(hyp)), hyp)) + ref = dict(zip(range(len(ref)), ref)) + print(hyp[0], ref[0]) + print('BLEU INP', len(hyp), len(ref)) + print('BLEU', scorer.compute_score(ref, hyp)[0]) + print('METEOR', m_scorer.compute_score(ref, hyp)[0]) + print('ROUGE_L', r_scorer.compute_score(ref, hyp)[0]) + gold_file.close() + pred_file.close() + + +def main(args): + if os.path.exists(args.save_dataset): + train_dataset, valid_dataset, test_dataset = pickle.load(open(args.save_dataset, 'rb')) + else: + train_dataset, valid_dataset, test_dataset = get_datasets(args.fnames, device=args.device, save=args.save_dataset) + args = vocab_config(args, train_dataset.ent_vocab, train_dataset.rel_vocab, train_dataset.text_vocab, train_dataset.ent_text_vocab, train_dataset.title_vocab) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler = BucketSampler(train_dataset, batch_size=args.batch_size), \ + collate_fn=train_dataset.batch_fn) + valid_dataloader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, \ + shuffle=False, collate_fn=train_dataset.batch_fn) + test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, \ + shuffle=False, collate_fn=train_dataset.batch_fn) + + model = GraphWriter(args) + model.to(args.device) + if args.test: + model = torch.load(args.save_model) + model.args = args + print(model) + test(model, test_dataloader, args) + else: + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9) + print(model) + for epoch in range(args.epoch): + train_one_epoch(model, train_dataloader, optimizer, args, epoch) + eval_it(model, valid_dataloader, args, epoch) + + +if __name__ == '__main__': + args = get_args() + main(args) diff --git a/examples/pytorch/graphwriter/utlis.py b/examples/pytorch/graphwriter/utlis.py new 
file mode 100644 index 0000000000..62701c8e14 --- /dev/null +++ b/examples/pytorch/graphwriter/utlis.py @@ -0,0 +1,321 @@ +import torch +import dgl +import numpy as np +import json +import pickle +import random + + +NODE_TYPE = {'entity': 0, 'root': 1, 'relation':2} + + +def write_txt(batch, seqs, w_file, args): + # converting the prediction to real text. + ret = [] + for b, seq in enumerate(seqs): + txt = [] + for token in seq: + # copy the entity + if token>=len(args.text_vocab): + ent_text = batch['raw_ent_text'][b][token-len(args.text_vocab)] + ent_text = filter(lambda x:x!='', ent_text) + txt.extend(ent_text) + else: + if int(token) not in [args.text_vocab(x) for x in ['', '', '']]: + txt.append(args.text_vocab(int(token))) + if int(token) == args.text_vocab(''): + break + w_file.write(' '.join([str(x) for x in txt])+'\n') + ret.append([' '.join([str(x) for x in txt])]) + return ret + + +def replace_ent(x, ent, V): + # replace the entity + mask = x>=V + if mask.sum()==0: + return x + nz = mask.nonzero() + fill_ent = ent[nz, x[mask]-V] + x = x.masked_scatter(mask, fill_ent) + return x + + +def len2mask(lens, device): + max_len = max(lens) + mask = torch.arange(max_len, device=device).unsqueeze(0).expand(len(lens), max_len) + mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1) + return mask + + +def pad(var_len_list, out_type='list', flatten=False): + if flatten: + lens = [len(x) for x in var_len_list] + var_len_list = sum(var_len_list, []) + max_len = max([len(x) for x in var_len_list]) + if out_type=='list': + if flatten: + return [x+['']*(max_len-len(x)) for x in var_len_list], lens + else: + return [x+['']*(max_len-len(x)) for x in var_len_list] + if out_type=='tensor': + if flatten: + return torch.stack([torch.cat([x, \ + torch.zeros([max_len-len(x)]+list(x.shape[1:])).type_as(x)], 0) for x in var_len_list], 0), lens + else: + return torch.stack([torch.cat([x, \ + torch.zeros([max_len-len(x)]+list(x.shape[1:])).type_as(x)], 0) for x in 
var_len_list], 0) + + +class Vocab(object): + def __init__(self, max_vocab=2**31, min_freq=-1, sp=['', '', '', '']): + self.i2s = [] + self.s2i = {} + self.wf = {} + self.max_vocab, self.min_freq, self.sp = max_vocab, min_freq, sp + + def __len__(self): + return len(self.i2s) + + def __str__(self): + return 'Total ' + str(len(self.i2s)) + str(self.i2s[:10]) + + def update(self, token): + if isinstance(token, list): + for t in token: + self.update(t) + else: + self.wf[token] = self.wf.get(token, 0) + 1 + + def build(self): + self.i2s.extend(self.sp) + sort_kv = sorted(self.wf.items(), key=lambda x:x[1], reverse=True) + for k,v in sort_kv: + if len(self.i2s)=self.min_freq and k not in self.sp: + self.i2s.append(k) + self.s2i.update(list(zip(self.i2s, range(len(self.i2s))))) + + def __call__(self, x): + if isinstance(x, int): + return self.i2s[x] + else: + return self.s2i.get(x, self.s2i['']) + + def save(self, fname): + pass + + def load(self, fname): + pass + +def at_least(x): + # handling the illegal data + if len(x) == 0: + return [''] + else: + return x + +class Example(object): + def __init__(self, title, ent_text, ent_type, rel, text): + # one object corresponds to a data sample + self.raw_title = title.split() + self.raw_ent_text = [at_least(x.split()) for x in ent_text] + assert min([len(x) for x in self.raw_ent_text])>0, str(self.raw_ent_text) + self.raw_ent_type = ent_type.split() # .. 
<> + self.raw_rel = [] + for r in rel: + rel_list = r.split() + for i in range(len(rel_list)): + if i>0 and i0: + graph.add_edges(*list(map(list, zip(*adj_edges)))) + return graph + + def get_tensor(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab): + if hasattr(self, '_cached_tensor'): + return self._cached_tensor + else: + title_data = [''] + self.raw_title + [''] + title = [title_vocab(x) for x in title_data] + ent_text = [[ent_text_vocab(y) for y in x] for x in self.raw_ent_text] + ent_type = [text_vocab(x) for x in self.raw_ent_type] # for inference + rel_data = ['--root--'] + sum([[x[1],x[1]+'_INV'] for x in self.raw_rel], []) + rel = [rel_vocab(x) for x in rel_data] + + text_data = [''] + self.raw_text + [''] + text = [text_vocab(x) for x in text_data] + tgt_text = [] + # the input text and decoding target are different since the consideration of the copy mechanism. + for i, str1 in enumerate(text_data): + if str1[0]=='<' and str1[-1]=='>' and '_' in str1: + a, b = str1[1:-1].split('_') + text[i] = text_vocab('<'+a+'>') + tgt_text.append(len(text_vocab)+int(b)) + else: + tgt_text.append(text[i]) + self._cached_tensor = {'title': torch.LongTensor(title), 'ent_text': [torch.LongTensor(x) for x in ent_text], \ + 'ent_type': torch.LongTensor(ent_type), 'rel': torch.LongTensor(rel), \ + 'text': torch.LongTensor(text[:-1]), 'tgt_text': torch.LongTensor(tgt_text[1:]), 'graph': self.graph, 'raw_ent_text': self.raw_ent_text} + return self._cached_tensor + + def update_vocab(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab): + ent_vocab.update(self.raw_ent_type) + ent_text_vocab.update(self.raw_ent_text) + title_vocab.update(self.raw_title) + rel_vocab.update(['--root--']+[x[1] for x in self.raw_rel]+[x[1]+'_INV' for x in self.raw_rel]) + text_vocab.update(self.raw_ent_type) + text_vocab.update(self.raw_text) + +class BucketSampler(torch.utils.data.Sampler): + def __init__(self, data_source, batch_size=32, bucket=3): + 
self.data_source = data_source + self.bucket = bucket + self.batch_size = batch_size + + def __iter__(self): + # the magic number comes from the author's code + perm = torch.randperm(len(self.data_source)) + lens = torch.Tensor([len(x) for x in self.data_source]) + lens = lens[perm] + t1 = [] + t2 = [] + t3 = [] + for i, l in enumerate(lens): + if (l<100): + t1.append(perm[i]) + elif (l>100 and l<220): + t2.append(perm[i]) + else: + t3.append(perm[i]) + datas = [t1,t2,t3] + random.shuffle(datas) + idxs = sum(datas, []) + batch = [] + for idx in idxs: + batch.append(idx) + mlen = max([0]+[lens[x] for x in batch]) + if (mlen<100 and len(batch) == 32) or (mlen>100 and mlen<220 and len(batch) >= 24) or (mlen>220 and len(batch)>=8) or len(batch)==32: + yield batch + batch = [] + if len(batch) > 0: + yield batch + + def __len__(self): + return (len(self.data_source)+self.batch_size-1)//self.batch_size + + +class GWdataset(torch.utils.data.Dataset): + def __init__(self, exs, ent_vocab=None, rel_vocab=None, text_vocab=None, ent_text_vocab=None, title_vocab=None, device=None): + super(GWdataset, self).__init__() + self.exs = exs + self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab, self.device = \ + ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device + + def __iter__(self): + return iter(self.exs) + + def __getitem__(self, index): + return self.exs[index] + + def __len__(self): + return len(self.exs) + + def batch_fn(self, batch_ex): + batch_title, batch_ent_text, batch_ent_type, batch_rel, batch_text, batch_tgt_text, batch_graph = \ + [], [], [], [], [], [], [] + batch_raw_ent_text = [] + for ex in batch_ex: + ex_data = ex.get_tensor(self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab) + batch_title.append(ex_data['title']) + batch_ent_text.append(ex_data['ent_text']) + batch_ent_type.append(ex_data['ent_type']) + batch_rel.append(ex_data['rel']) + batch_text.append(ex_data['text']) + 
batch_tgt_text.append(ex_data['tgt_text']) + batch_graph.append(ex_data['graph']) + batch_raw_ent_text.append(ex_data['raw_ent_text']) + batch_title = pad(batch_title, out_type='tensor') + batch_ent_text, ent_len = pad(batch_ent_text, out_type='tensor', flatten=True) + batch_ent_type = pad(batch_ent_type, out_type='tensor') + batch_rel = pad(batch_rel, out_type='tensor') + batch_text = pad(batch_text, out_type='tensor') + batch_tgt_text = pad(batch_tgt_text, out_type='tensor') + batch_graph = dgl.batch(batch_graph) + batch_graph.to(self.device) + return {'title': batch_title.to(self.device), 'ent_text': batch_ent_text.to(self.device), 'ent_len': ent_len, \ + 'ent_type': batch_ent_type.to(self.device), 'rel': batch_rel.to(self.device), 'text': batch_text.to(self.device), \ + 'tgt_text': batch_tgt_text.to(self.device), 'graph': batch_graph, 'raw_ent_text': batch_raw_ent_text} + + +def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, save='tmp.pickle'): + # min_freq : not support now since it's very sensitive to the final results, but you can set it via passing min_freq to the Vocab class. 
+ # sep : not support now + # joint_vocab : not support now + ent_vocab = Vocab(sp=['', '']) + title_vocab = Vocab(min_freq=5) + rel_vocab = Vocab(sp=['', '']) + text_vocab = Vocab(min_freq=5) + ent_text_vocab = Vocab(sp=['', '']) + datasets = [] + for fname in fnames: + exs = [] + json_datas = json.loads(open(fname).read()) + for json_data in json_datas: + # construct one data example + ex = Example.from_json(json_data) + if fname == fnames[0]: # only training set + ex.update_vocab(ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab) + exs.append(ex) + datasets.append(exs) + ent_vocab.build() + rel_vocab.build() + text_vocab.build() + ent_text_vocab.build() + title_vocab.build() + datasets = [GWdataset(exs, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device) for exs in datasets] + with open(save, 'wb') as f: + pickle.dump(datasets, f) + return datasets + + +if __name__ == '__main__' : + ds = get_datasets(['data/unprocessed.val.json', 'data/unprocessed.val.json', 'data/unprocessed.test.json']) + print(ds[0].exs[0]) + print(ds[0].exs[0].get_tensor(ds[0].ent_vocab, ds[0].rel_vocab, ds[0].text_vocab, ds[0].ent_text_vocab, ds[0].title_vocab)) + diff --git a/examples/pytorch/model_zoo/chem/README.md b/examples/pytorch/model_zoo/chem/README.md index 1f7e545123..66f09de5ca 100644 --- a/examples/pytorch/model_zoo/chem/README.md +++ b/examples/pytorch/model_zoo/chem/README.md @@ -63,16 +63,16 @@ as front end and Set2Set for output prediction. 
### Example Usage of Pre-trained Models ```python -from dgl.data.chem import Tox21 +from dgl.data.chem import Tox21, smiles_to_bigraph, CanonicalAtomFeaturizer from dgl import model_zoo -dataset = Tox21() +dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer()) model = model_zoo.chem.load_pretrained('GCN_Tox21') # Pretrained model loaded model.eval() smiles, g, label, mask = dataset[0] feats = g.ndata.pop('h') -label_pred = model(feats, g) +label_pred = model(g, feats) print(smiles) # CCOc1ccc2nc(S(N)(=O)=O)sc2c1 print(label_pred[:, mask != 0]) # Mask non-existing labels # tensor([[-0.7956, 0.4054, 0.4288, -0.5565, -0.0911, diff --git a/python/dgl/convert.py b/python/dgl/convert.py index b7a8e2dff8..bcb5a6cf1e 100644 --- a/python/dgl/convert.py +++ b/python/dgl/convert.py @@ -21,7 +21,7 @@ __all__ = [ 'to_networkx', ] -def graph(data, ntype='_N', etype='_E', card=None, **kwargs): +def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs): """Create a graph with one type of nodes and edges. In the sparse matrix perspective, :func:`dgl.graph` creates a graph @@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): card : int, optional Cardinality (number of nodes in the graph). If None, infer from input data, i.e. the largest node ID plus 1. (Default: None) + validate : bool, optional + If True, check if node ids are within cardinality, the check process may take + some time. + If False and card is not None, user would receive a warning. (Default: False) kwargs : key-word arguments, optional Other key word arguments. Only comes into effect when we are using a NetworkX graph. It can consist of: @@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): ['follows'] >>> g.canonical_etypes [('user', 'follows', 'user')] + + Check if node ids are within cardinality + + >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True) + ... 
+ dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2). + >>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True) + Graph(num_nodes=3, num_edges=3, + ndata_schemes={} + edata_schemes={}) """ if card is not None: urange, vrange = card, card @@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): urange, vrange = None, None if isinstance(data, tuple): u, v = data - return create_from_edges(u, v, ntype, etype, ntype, urange, vrange) + return create_from_edges(u, v, ntype, etype, ntype, urange, vrange, validate) elif isinstance(data, list): - return create_from_edge_list(data, ntype, etype, ntype, urange, vrange) + return create_from_edge_list(data, ntype, etype, ntype, urange, vrange, validate) elif isinstance(data, sp.sparse.spmatrix): return create_from_scipy(data, ntype, etype, ntype) elif isinstance(data, nx.Graph): @@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs): else: raise DGLError('Unsupported graph data type:', type(data)) -def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): +def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=False, **kwargs): """Create a bipartite graph. The result graph is directed and edges must be from ``utype`` nodes @@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): card : pair of int, optional Cardinality (number of nodes in the source and destination group). If None, infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None) + validate : bool, optional + If True, check if node ids are within cardinality, the check process may take + some time. + If False and card is not None, user would receive a warning. (Default: False) kwargs : key-word arguments, optional Other key word arguments. Only comes into effect when we are using a NetworkX graph. 
It can consist of: @@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): 4 >>> g.edges() (tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])) + + Check if node ids are within cardinality + >>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True) + ... + dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2). + >>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True) + >>> g + Graph(num_nodes={'_U': 3, '_V': 4}, + num_edges={('_U', '_E', '_V'): 3}, + metagraph=[('_U', '_V')]) """ if utype == vtype: raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.') @@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs): urange, vrange = None, None if isinstance(data, tuple): u, v = data - return create_from_edges(u, v, utype, etype, vtype, urange, vrange) + return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate) elif isinstance(data, list): - return create_from_edge_list(data, utype, etype, vtype, urange, vrange) + return create_from_edge_list(data, utype, etype, vtype, urange, vrange, validate) elif isinstance(data, sp.sparse.spmatrix): return create_from_scipy(data, utype, etype, vtype) elif isinstance(data, nx.Graph): @@ -667,7 +695,7 @@ def to_homo(G): # Internal APIs ############################################################ -def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): +def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=False): """Internal function to create a graph from incident nodes with types. utype could be equal to vtype @@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): vrange : int, optional The destination node ID range. If None, the value is the maximum of the destination node IDs in the edge list plus 1. 
(Default: None) + validate : bool, optional + If True, checks if node IDs are within range. Returns ------- @@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): """ u = utils.toindex(u) v = utils.toindex(v) + if validate: + if urange is not None and urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))): + raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format( + urange, int(F.asnumpy(F.max(u.tousertensor(), dim=0))))) + if vrange is not None and vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))): + raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format( + vrange, int(F.asnumpy(F.max(v.tousertensor(), dim=0))))) urange = urange or (int(F.asnumpy(F.max(u.tousertensor(), dim=0))) + 1) vrange = vrange or (int(F.asnumpy(F.max(v.tousertensor(), dim=0))) + 1) if utype == vtype: @@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None): else: return DGLHeteroGraph(hgidx, [utype, vtype], [etype]) -def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None): +def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=False): """Internal function to create a heterograph from a list of edge tuples with types. utype could be equal to vtype @@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None): vrange : int, optional The destination node ID range. If None, the value is the maximum of the destination node IDs in the edge list plus 1. (Default: None) + validate : bool, optional + If True, checks if node IDs are within range. 
+ Returns ------- @@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None): u, v = zip(*elist) u = list(u) v = list(v) - return create_from_edges(u, v, utype, etype, vtype, urange, vrange) + return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate) def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False): """Internal function to create a heterograph from a scipy sparse matrix with types. @@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False): If True, the entries in the sparse matrix are treated as edge IDs. Otherwise, the entries are ignored and edges will be added in (source, destination) order. + validate : bool, optional + If True, checks if node IDs are within range. + Returns ------- diff --git a/tests/compute/test_heterograph.py b/tests/compute/test_heterograph.py index c5db100d19..4e0809c5c3 100644 --- a/tests/compute/test_heterograph.py +++ b/tests/compute/test_heterograph.py @@ -7,6 +7,7 @@ import itertools import backend as F import networkx as nx import unittest +from dgl import DGLError def create_test_heterograph(): # test heterograph from the docstring, plus a user -- wishes -- game relation @@ -94,6 +95,36 @@ def test_create(): assert g.number_of_nodes('l1') == 3 assert g.number_of_nodes('l2') == 4 + # test if validate flag works + # homo graph + fail = False + try: + g = dgl.graph( + ([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]), + card=2, + validate=True + ) + except DGLError: + fail = True + finally: + assert fail, "should catch a DGLError because node ID is out of bound." + # bipartite graph + def _test_validate_bipartite(card): + fail = False + try: + g = dgl.bipartite( + ([0, 0, 1, 1, 2], [1, 1, 2, 2, 3]), + card=card, + validate=True + ) + except DGLError: + fail = True + finally: + assert fail, "should catch a DGLError because node ID is out of bound." 
+ + _test_validate_bipartite((3, 3)) + _test_validate_bipartite((2, 4)) + def test_query(): g = create_test_heterograph() diff --git a/tutorials/models/2_small_graph/README.txt b/tutorials/models/2_small_graph/README.txt index 357613da69..0981ca8328 100644 --- a/tutorials/models/2_small_graph/README.txt +++ b/tutorials/models/2_small_graph/README.txt @@ -1,16 +1,16 @@ .. _tutorials2-index: -Dealing with many small graphs +Batching many small graphs ============================== * **Tree-LSTM** `[paper] `__ `[tutorial] - <2_small_graph/3_tree-lstm.html>`__ `[code] + <2_small_graph/3_tree-lstm.html>`__ `[PyTorch code] `__: - sentences of natural languages have inherent structures, which are thrown + Sentences have inherent structures that are thrown away by treating them simply as sequences. Tree-LSTM is a powerful model - that learns the representation by leveraging prior syntactic structures - (e.g. parse-tree). The challenge to train it well is that simply by padding - a sentence to the maximum length no longer works, since trees of different + that learns the representation by using prior syntactic structures such as a parse-tree. + The challenge in training is that simply by padding + a sentence to the maximum length no longer works. Trees of different sentences have different sizes and topologies. DGL solves this problem by - throwing the trees into a bigger "container" graph, and use message-passing - to explore maximum parallelism. The key API we use is batching. + adding the trees to a bigger container graph, and then using message-passing + to explore maximum parallelism. Batching is a key API for this.