mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
Merge remote-tracking branch 'upstream/master' into change_ci_image
This commit is contained in:
@@ -28,6 +28,7 @@ A summary of part of the model accuracy and training speed with the Pytorch back
|
||||
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
|
||||
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
|
||||
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |
|
||||
| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.3(BLEU) | 14.31(BLEU) | [1970s (PyTorch)](https://github.com/rikdz/GraphWriter) | 1192s | 1.65x |
|
||||
|
||||
With the MXNet/Gluon backend, we scaled a graph of 50M nodes and 150M edges on a P3.8xlarge instance,
|
||||
with 160s per epoch, on SSE ([Stochastic Steady-state Embedding](https://www.cc.gatech.edu/~hdai8/pdf/equilibrium_embedding.pdf)),
|
||||
|
||||
@@ -21,3 +21,4 @@ Here is a summary of the model accuracy and training speed. Our testbed is Amazo
|
||||
| [JTNN](https://arxiv.org/abs/1802.04364) | 96.44% | 96.44% | [1826s (Pytorch)](https://github.com/wengong-jin/icml18-jtnn) | 743s | 2.5x |
|
||||
| [LGNN](https://arxiv.org/abs/1705.08415) | 94% | 94% | n/a | 1.45s | n/a |
|
||||
| [DGMG](https://arxiv.org/pdf/1803.03324.pdf) | 84% | 90% | n/a | 238s | n/a |
|
||||
| [GraphWriter](https://www.aclweb.org/anthology/N19-1238.pdf) | 14.3(BLEU) | 14.31(BLEU) | 1970s | 1192s | 1.65x |
|
||||
|
||||
44
examples/pytorch/graphwriter/README.md
Normal file
44
examples/pytorch/graphwriter/README.md
Normal file
@@ -0,0 +1,44 @@
|
||||
# GraphWriter-DGL
|
||||
In this example we implement GraphWriter, from [Text Generation from Knowledge Graphs with Graph Transformers](https://arxiv.org/abs/1904.02342), in DGL. The original implementation is available in the [author's code](https://github.com/rikdz/GraphWriter).
|
||||
|
||||
## Dependencies
|
||||
- PyTorch >= 1.2
|
||||
- tqdm
|
||||
- pycocoevalcap (only for testing)
|
||||
- multi-bleu.perl and other scripts from mosesdecoder (only for testing)
|
||||
|
||||
## Usage
|
||||
```
|
||||
# download data
|
||||
sh prepare_data.sh
|
||||
# training
|
||||
sh run.sh
|
||||
# testing
|
||||
sh test.sh
|
||||
```
|
||||
|
||||
## Result on AGENDA
|
||||
| |BLEU|METEOR| training time per epoch|
|
||||
|-|-|-|-|
|
||||
|Author's implementation|14.3+-1.01| 18.8+-0.28| 1970s|
|
||||
|DGL implementation|14.31+-0.34|19.74+-0.69| 1192s|
|
||||
|
||||
We use the author's code for the speed test, and our testbed is V100 GPU.
|
||||
|
||||
| |BLEU| detok BLEU| METEOR |
|
||||
|-|-|-|-|
|
||||
|greedy, two layers| 13.97 +- 0.40| 13.78 +- 0.46| 18.76 +- 0.36|
|
||||
|beam 4, length penalty 1.0, two layers| 14.66 +- 0.65| 14.53 +- 0.52| 19.50 +- 0.49|
|
||||
|beam 4, length penalty 0.0, two layers| 14.33 +- 0.39| 14.09 +- 0.39| 18.63 +- 0.52|
|
||||
|greedy, six layers| 14.17 +- 0.46| 14.01 +- 0.51| 19.18 +- 0.49|
|
||||
|beam 4, length penalty 1.0, six layers| 14.31 +- 0.34| 14.35 +- 0.36| 19.74 +- 0.69|
|
||||
|beam 4, length penalty 0.0, six layers| 14.40 +- 0.85| 14.15 +- 0.84| 18.86 +- 0.78|
|
||||
|
||||
We repeat the experiment five times.
|
||||
|
||||
### Examples
|
||||
|
||||
We also provide the output of our implementation on test set together with the reference text.
|
||||
- [GraphWriter's output](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_pred.txt)
|
||||
- [Reference text](https://s3.us-east-2.amazonaws.com/dgl.ai/models/graphwriter/tmp_gold.txt)
|
||||
|
||||
182
examples/pytorch/graphwriter/graphwriter.py
Normal file
182
examples/pytorch/graphwriter/graphwriter.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import torch
|
||||
from modules import MSA, BiLSTM, GraphTrans
|
||||
from utlis import *
|
||||
from torch import nn
|
||||
import dgl
|
||||
|
||||
|
||||
class GraphWriter(nn.Module):
    """Graph-to-text model: graph encoder plus an LSTM decoder with a copy gate.

    The encoder embeds entity text (``BiLSTM``), relations and optionally the
    title, then runs ``GraphTrans`` over the batched graph. The decoder is an
    ``nn.LSTMCell`` whose output distribution is the concatenation of a
    vocabulary distribution and a copy distribution over the input entities,
    mixed by a sigmoid gate.

    ``args`` must provide the vocabularies (``text_vocab``, ``ent_text_vocab``,
    ``rel_vocab`` and, when ``args.title`` is set, ``title_vocab``) and the
    sizes ``nhid`` / ``dec_ninp``; decoding additionally reads
    ``beam_max_len``, ``lp`` and ``device``.
    """

    # Added inside every log() of the copy gate so a saturated gate (0 or 1)
    # cannot produce -inf. Training always used it; decoding now does too.
    _EPSI = 1e-6

    def __init__(self, args):
        super(GraphWriter, self).__init__()
        self.args = args
        # NOTE: the creation order below is kept as in the original so that
        # parameter initialisation under a fixed seed is unchanged.
        if args.title:
            self.title_emb = nn.Embedding(len(args.title_vocab), args.nhid, padding_idx=0)
            self.title_enc = BiLSTM(args, enc_type='title')
            self.title_attn = MSA(args)
        self.ent_emb = nn.Embedding(len(args.ent_text_vocab), args.nhid, padding_idx=0)
        self.tar_emb = nn.Embedding(len(args.text_vocab), args.nhid, padding_idx=0)
        if args.title:
            nn.init.xavier_normal_(self.title_emb.weight)
        nn.init.xavier_normal_(self.ent_emb.weight)
        self.rel_emb = nn.Embedding(len(args.rel_vocab), args.nhid, padding_idx=0)
        nn.init.xavier_normal_(self.rel_emb.weight)
        self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)
        self.ent_enc = BiLSTM(args, enc_type='entity')
        self.graph_enc = GraphTrans(args)
        self.ent_attn = MSA(args)
        self.copy_attn = MSA(args, mode='copy')
        self.copy_fc = nn.Linear(args.dec_ninp, 1)
        self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))

    def enc_forward(self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask):
        """Encode title (optional), entities and the graph.

        Returns ``(g_ent, g_root, title_enc, ent_enc)``; ``title_enc`` is
        ``None`` when ``args.title`` is false.
        """
        title_enc = None
        if self.args.title:
            title_enc = self.title_enc(self.title_emb(batch['title']), title_mask)
        ent_enc = self.ent_enc(self.ent_emb(batch['ent_text']), ent_text_mask, ent_len=batch['ent_len'])
        rel_emb = self.rel_emb(batch['rel'])
        g_ent, g_root = self.graph_enc(ent_enc, ent_mask, ent_len, rel_emb, rel_mask, batch['graph'])
        return g_ent, g_root, title_enc, ent_enc

    def _attend(self, _h, g_ent, ent_mask, title_enc, title_mask):
        """Build the decoder context from the hidden state (entity + optional title attention)."""
        ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
        if self.args.title:
            attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
            ctx = torch.cat([ctx, attn], 1)
        return ctx

    def _decode_step(self, xt, ctx, _h, _c, g_ent, ent_mask, title_enc, title_mask):
        """Advance the decoder LSTM by one token and refresh the context."""
        _xt = torch.cat([ctx, xt], 1)
        _h, _c = self.decode_lstm(_xt, (_h, _c))
        ctx = self._attend(_h, g_ent, ent_mask, title_enc, title_mask)
        return _h, _c, ctx

    def _log_probs(self, outs, ent_enc, ent_mask):
        """Gate-mixed log-probabilities over [vocabulary ; copyable entities].

        ``outs`` is expected to carry a time axis (presumably (batch, steps,
        dec_ninp) -- TODO confirm against ``MSA``); single-step callers
        unsqueeze/squeeze around this call.
        """
        copy_gate = torch.sigmoid(self.copy_fc(outs))
        pred_v = torch.log(copy_gate + self._EPSI) + torch.log_softmax(self.pred_v_fc(outs), -1)
        pred_c = torch.log((1. - copy_gate) + self._EPSI) \
            + torch.log_softmax(self.copy_attn(outs, ent_enc, mask=ent_mask), -1)
        return torch.cat([pred_v, pred_c], -1)

    def forward(self, batch, beam_size=-1):
        """Run the model.

        ``beam_size < 1``: teacher-forced training pass, returns the
        log-probability tensor. ``beam_size == 1``: greedy decoding, returns a
        tensor of token ids starting with ``<BOS>``. ``beam_size > 1``: beam
        search, returns a list with the best finished hypothesis per example.
        """
        ent_mask = len2mask(batch['ent_len'], self.args.device)
        ent_text_mask = batch['ent_text'] == 0
        rel_mask = batch['rel'] == 0  # 0 means the <PAD>
        title_mask = batch['title'] == 0
        g_ent, g_root, title_enc, ent_enc = self.enc_forward(
            batch, ent_mask, ent_text_mask, batch['ent_len'], rel_mask, title_mask)

        # the graph root initialises both LSTM states; the cell copy is detached
        _h, _c = g_root, g_root.clone().detach()
        ctx = self._attend(_h, g_ent, ent_mask, title_enc, title_mask)

        if beam_size < 1:
            return self._teacher_forced(batch, _h, _c, ctx, g_ent, ent_enc, ent_mask, title_enc, title_mask)
        if beam_size == 1:
            return self._greedy(batch, _h, _c, ctx, g_ent, ent_enc, ent_mask, title_enc, title_mask)
        return self._beam_search(batch, beam_size, _h, _c, ctx, g_ent, ent_enc, ent_mask, title_enc, title_mask)

    def _teacher_forced(self, batch, _h, _c, ctx, g_ent, ent_enc, ent_mask, title_enc, title_mask):
        """Training pass: feed the gold tokens and score every step at once."""
        outs = []
        tar_inp = self.tar_emb(batch['text'].transpose(0, 1))
        for xt in tar_inp:
            _h, _c, ctx = self._decode_step(xt, ctx, _h, _c, g_ent, ent_mask, title_enc, title_mask)
            outs.append(torch.cat([_h, ctx], 1))
        outs = torch.stack(outs, 1)
        return self._log_probs(outs, ent_enc, ent_mask)

    def _greedy(self, batch, _h, _c, ctx, g_ent, ent_enc, ent_mask, title_enc, title_mask):
        """Greedy decoding for ``args.beam_max_len`` steps."""
        device = g_ent.device
        B = g_ent.shape[0]
        ent_type = batch['ent_type'].view(B, -1)
        seq = (torch.ones(B,).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(1)
        banned = [self.args.text_vocab(ban_item) for ban_item in ['<BOS>', '<PAD>', '<UNK>']]
        for _ in range(self.args.beam_max_len):
            # ids >= vocab size are copied entities; feed their type token instead
            _inp = replace_ent(seq[:, -1], ent_type, len(self.args.text_vocab))
            xt = self.tar_emb(_inp)
            _h, _c, ctx = self._decode_step(xt, ctx, _h, _c, g_ent, ent_mask, title_enc, title_mask)
            _y = torch.cat([_h, ctx], 1)
            pred = self._log_probs(_y.unsqueeze(1), ent_enc, ent_mask).squeeze(1).view(B, -1)
            for ban_idx in banned:
                pred[:, ban_idx] = -1e8
            _, word = pred.max(-1)
            seq = torch.cat([seq, word.unsqueeze(1)], 1)
        return seq

    def _beam_search(self, batch, beam_size, _h, _c, ctx, g_ent, ent_enc, ent_mask, title_enc, title_mask):
        """Beam search with length penalty ``len ** args.lp``."""
        device = g_ent.device
        B = g_ent.shape[0]
        BSZ = B * beam_size
        # replicate every per-example tensor once per beam
        _h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
        _c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
        ent_mask = ent_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
        if self.args.title:
            title_mask = title_mask.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
            title_enc = title_enc.view(B, 1, title_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, title_enc.size(1), -1)
        ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
        ent_type = batch['ent_type'].view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
        g_ent = g_ent.view(B, 1, g_ent.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, g_ent.size(1), -1)
        ent_enc = ent_enc.view(B, 1, ent_enc.size(1), -1).repeat(1, beam_size, 1, 1).view(BSZ, ent_enc.size(1), -1)

        beam_best = torch.zeros(B).to(device) - 1e9
        beam_best_seq = [None] * B
        beam_seq = (torch.ones(B, beam_size).long().to(device) * self.args.text_vocab('<BOS>')).unsqueeze(-1)
        beam_score = torch.zeros(B, beam_size).to(device)
        done_flag = torch.zeros(B, beam_size)
        eos_idx = self.args.text_vocab('<EOS>')
        banned = [self.args.text_vocab(ban_item) for ban_item in ['<BOS>', '<PAD>', '<UNK>']]
        for t in range(self.args.beam_max_len):
            _inp = replace_ent(beam_seq[:, :, -1].view(-1), ent_type, len(self.args.text_vocab))
            xt = self.tar_emb(_inp)
            _h, _c, ctx = self._decode_step(xt, ctx, _h, _c, g_ent, ent_mask, title_enc, title_mask)
            _y = torch.cat([_h, ctx], 1)
            pred = self._log_probs(_y.unsqueeze(1), ent_enc, ent_mask).squeeze(1).view(B, beam_size, -1)
            for ban_idx in banned:
                pred[:, :, ban_idx] = -1e8
            if t == self.args.beam_max_len - 1:  # force ending
                tt = pred[:, :, eos_idx]
                # full_like instead of `pred*0-1e8`: cannot turn an inf into NaN
                pred = torch.full_like(pred, -1e8)
                pred[:, :, eos_idx] = tt
            cum_score = beam_score.view(B, beam_size, 1) + pred
            score, word = cum_score.topk(dim=-1, k=beam_size)  # B, beam_size, beam_size
            score, word = score.view(B, -1), word.view(B, -1)
            if beam_seq.size(2) == 1:
                # first step: all beams are identical, so expand beam 0 only
                new_idx = torch.arange(beam_size).to(word)
                new_idx = new_idx[None, :].repeat(B, 1)
            else:
                _, new_idx = score.topk(dim=-1, k=beam_size)
            new_src, new_score, new_word, new_done = [], [], [], []
            LP = beam_seq.size(2) ** self.args.lp
            for i in range(B):
                for j in range(beam_size):
                    tmp_score = score[i][new_idx[i][j]]
                    tmp_word = word[i][new_idx[i][j]]
                    src_idx = new_idx[i][j] // beam_size
                    new_src.append(src_idx)
                    if tmp_word == eos_idx:
                        # a finished beam must never be extended again
                        new_score.append(-1e8)
                    else:
                        new_score.append(tmp_score)
                    new_word.append(tmp_word)

                    if tmp_word == eos_idx and done_flag[i][src_idx] == 0 and tmp_score / LP > beam_best[i]:
                        beam_best[i] = tmp_score / LP
                        beam_best_seq[i] = beam_seq[i][src_idx]
                    if tmp_word == eos_idx:
                        new_done.append(1)
                    else:
                        new_done.append(done_flag[i][src_idx])
            new_score = torch.Tensor(new_score).view(B, beam_size).to(beam_score)
            new_word = torch.Tensor(new_word).view(B, beam_size).to(beam_seq)
            new_src = torch.LongTensor(new_src).view(B, beam_size).to(device)
            new_done = torch.Tensor(new_done).view(B, beam_size).to(done_flag)
            beam_score = new_score
            done_flag = new_done
            # reorder histories and decoder states to follow the surviving beams
            batch_idx = torch.arange(B)[:, None].to(device)
            beam_seq = beam_seq.view(B, beam_size, -1)[batch_idx, new_src]
            beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)
            _h = _h.view(B, beam_size, -1)[batch_idx, new_src].view(BSZ, -1)
            _c = _c.view(B, beam_size, -1)[batch_idx, new_src].view(BSZ, -1)
            ctx = ctx.view(B, beam_size, -1)[batch_idx, new_src].view(BSZ, -1)

        return beam_best_seq
|
||||
|
||||
164
examples/pytorch/graphwriter/modules.py
Executable file
164
examples/pytorch/graphwriter/modules.py
Executable file
@@ -0,0 +1,164 @@
|
||||
import torch
|
||||
import math
|
||||
import dgl.function as fn
|
||||
from dgl.nn.pytorch import edge_softmax
|
||||
from utlis import *
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pack_padded_sequence,pad_packed_sequence
|
||||
|
||||
|
||||
class MSA(nn.Module):
    """Multi-head dot-product attention with two construction modes.

    * ``mode='normal'``: ``args.nhead`` heads of size ``args.head_dim`` with
      learned Q/K/V projections. The query may be a sequence (3-D input) or a
      single vector per example (2-D input, i.e. attentive pooling).
    * ``mode='copy'``: a single head of size ``args.nhid`` with only a query
      projection; ``forward`` returns the raw masked scores, which the caller
      turns into a copy distribution.
    """

    def __init__(self, args, mode='normal'):
        super(MSA, self).__init__()
        if mode == 'copy':
            nhead, head_dim = 1, args.nhid
            qninp, kninp = args.dec_ninp, args.nhid
        elif mode == 'normal':
            nhead, head_dim = args.nhead, args.head_dim
            qninp, kninp = args.nhid, args.nhid
        else:
            # previously fell through and crashed with an unrelated UnboundLocalError
            raise ValueError("MSA mode must be 'normal' or 'copy', got %r" % (mode,))
        self.attn_drop = nn.Dropout(0.1)
        self.WQ = nn.Linear(qninp, nhead*head_dim, bias=(mode == 'copy'))
        if mode != 'copy':
            self.WK = nn.Linear(kninp, nhead*head_dim, bias=False)
            self.WV = nn.Linear(kninp, nhead*head_dim, bias=False)
        self.args, self.nhead, self.head_dim, self.mode = args, nhead, head_dim, mode

    def forward(self, inp1, inp2, mask=None):
        """Attend from ``inp1`` (queries) over ``inp2`` (keys/values).

        ``inp2`` is 3-D (batch, length, hidden). ``inp1`` is either 3-D or 2-D;
        a 2-D query is treated as length 1 and the result is squeezed back.
        ``mask`` (batch, length) marks positions to exclude; they are filled
        with ``-1e8`` before the softmax.

        Returns the attended values in normal mode, or the raw masked scores of
        shape (batch, query_len, length) in copy mode.
        """
        B, L2, H = inp2.shape
        NH, HD = self.nhead, self.head_dim
        copy_mode = self.mode == 'copy'
        q = self.WQ(inp1)
        k = inp2 if copy_mode else self.WK(inp2)
        L1 = 1 if inp1.ndim == 2 else inp1.shape[1]
        if not copy_mode:
            # scaled by the key input width, not the head size (as in the original)
            q = q / math.sqrt(H)
        q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3)
        k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1)
        pre_attn = torch.matmul(q, k)
        if mask is not None:
            pre_attn = pre_attn.masked_fill(mask[:, None, None, :], -1e8)
        if copy_mode:
            # single head: drop the head axis and hand back unnormalised scores
            return pre_attn.squeeze(1)
        v = self.WV(inp2).view(B, L2, NH, HD).permute(0, 2, 1, 3)
        alpha = self.attn_drop(torch.softmax(pre_attn, -1))
        ret = torch.matmul(alpha, v).permute(0, 2, 1, 3).contiguous().view(B, L1, NH*HD)
        if inp1.ndim == 2:
            return ret.squeeze(1)
        return ret
|
||||
|
||||
|
||||
class BiLSTM(nn.Module):
    """Bidirectional LSTM encoder shared by the title and entity inputs.

    ``enc_type='title'`` returns the padded per-token outputs;
    ``enc_type='entity'`` returns one vector per entity (both directions of
    the top layer), regrouped per example according to ``ent_len``.
    """

    def __init__(self, args, enc_type='title'):
        super(BiLSTM, self).__init__()
        self.enc_type = enc_type
        self.drop = nn.Dropout(args.emb_drop)
        self.bilstm = nn.LSTM(
            args.nhid,
            args.nhid//2,
            bidirectional=True,
            num_layers=args.enc_lstm_layers,
            batch_first=True,
        )

    def forward(self, inp, mask, ent_len=None):
        """Encode ``inp``; ``mask`` is non-zero on padding positions."""
        dropped = self.drop(inp)
        # true sequence lengths = number of unmasked positions per row
        valid_lens = (mask == 0).sum(-1).long().tolist()
        packed = pack_padded_sequence(dropped, valid_lens, batch_first=True, enforce_sorted=False)
        packed_out, (h_n, _c_n) = self.bilstm(packed)
        if self.enc_type == 'title':
            padded_out = pad_packed_sequence(packed_out, batch_first=True)[0]
            return padded_out
        if self.enc_type == 'entity':
            per_seq = h_n.transpose(0, 1).contiguous()
            # two directions of the top-layer
            top_layer = per_seq[:, -2:].view(per_seq.size(0), -1)
            return pad(top_layer.split(ent_len), out_type='tensor')
|
||||
|
||||
|
||||
class GAT(nn.Module):
    # a graph attention network with dot-product attention
    #
    # Each node attends over its incoming edges with multi-head dot-product
    # attention; with trans=True a residual feed-forward block follows.
    #
    # NOTE(review): the `ffn_drop` and `attn_drop` arguments are accepted but
    # never read -- both dropout layers below are hard-coded to 0.1, and
    # `self.attn_drop` / `self.ln2` are created but not used in forward().
    # They are left untouched here; removing them would change the
    # state_dict keys of saved models.
    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 ffn_drop=0.,
                 attn_drop=0.,
                 trans=True):
        super(GAT, self).__init__()
        self._num_heads = num_heads
        self._in_feats = in_feats
        self._out_feats = out_feats
        self.q_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False)
        self.k_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False)
        self.v_proj = nn.Linear(in_feats, num_heads*out_feats, bias=False)
        self.attn_drop = nn.Dropout(0.1)
        self.ln1 = nn.LayerNorm(in_feats)
        self.ln2 = nn.LayerNorm(in_feats)
        if trans:
            self.FFN = nn.Sequential(
                nn.Linear(in_feats, 4*in_feats),
                nn.PReLU(4*in_feats),
                nn.Linear(4*in_feats, in_feats),
                nn.Dropout(0.1),
            )
            # a strange FFN, see the author's code
        self._trans = trans

    def forward(self, graph, feat):
        # work on a local view so the caller's graph keeps its own node/edge data
        graph = graph.local_var()
        # keys and values are projected from a detached copy, so gradients reach
        # `feat` only through the query projection and the residual connection
        feat_c = feat.clone().detach().requires_grad_(False)
        q, k, v = self.q_proj(feat), self.k_proj(feat_c), self.v_proj(feat_c)
        # split the projected width into (heads, per-head size)
        q = q.view(-1, self._num_heads, self._out_feats)
        k = k.view(-1, self._num_heads, self._out_feats)
        v = v.view(-1, self._num_heads, self._out_feats)
        graph.ndata.update({'ft': v, 'el': k, 'er': q}) # k,q instead of q,k, the edge_softmax is applied on incoming edges
        # compute edge attention
        graph.apply_edges(fn.u_dot_v('el', 'er', 'e'))
        # scaled by the full projected width (heads * per-head size)
        e = graph.edata.pop('e') / math.sqrt(self._out_feats * self._num_heads)
        graph.edata['a'] = edge_softmax(graph, e).unsqueeze(-1)
        # message passing
        graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                         fn.sum('m', 'ft2'))
        rst = graph.ndata['ft2']
        # residual
        # (requires num_heads * out_feats == in_feats for the view to succeed)
        rst = rst.view(feat.shape) + feat
        if self._trans:
            rst = self.ln1(rst)
            rst = self.ln1(rst+self.FFN(rst))
            # use the same layer norm, see the author's code
        return rst
|
||||
|
||||
|
||||
class GraphTrans(nn.Module):
    """Stack of ``args.prop`` ``GAT`` layers applied to the batched graph."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        if args.graph_enc == "gat":
            # we only support gtrans, don't use this one (untested)
            layers = [
                GAT(args.nhid, args.nhid//4, 4, attn_drop=args.attn_drop, trans=False)
                for _ in range(args.prop)
            ]
        else:
            layers = [
                GAT(args.nhid, args.nhid//4, 4, attn_drop=args.attn_drop, ffn_drop=args.drop, trans=True)
                for _ in range(args.prop)
            ]
        self.gat = nn.ModuleList(layers)
        self.prop = args.prop

    def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs):
        """Propagate entity/relation features over ``graphs``.

        The incoming masks are non-zero on padding. Returns ``(g_ent, g_root)``:
        the padded entity-node features and the root-node features.
        """
        device = ent.device
        # reverse mask: True where a real (non-padding) item sits
        keep_ent = (ent_mask == 0)
        keep_rel = (rel_mask == 0)
        # node features are laid out per graph as entities first, then relations
        pieces = []
        for b in range(graphs.batch_size):
            pieces.append(ent[b][keep_ent[b]])
            pieces.append(rel[b][keep_rel[b]])
        feats = torch.cat(pieces, 0)
        for layer_idx in range(self.prop):
            feats = self.gat[layer_idx](graphs, feats)
        root_ids = graphs.filter_nodes(lambda x: x.data['type'] == NODE_TYPE['root']).to(device)
        g_root = feats.index_select(0, root_ids)
        ent_ids = graphs.filter_nodes(lambda x: x.data['type'] == NODE_TYPE['entity']).to(device)
        g_ent = pad(feats.index_select(0, ent_ids).split(ent_len), out_type='tensor')
        return g_ent, g_root
|
||||
|
||||
55
examples/pytorch/graphwriter/opts.py
Normal file
55
examples/pytorch/graphwriter/opts.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
|
||||
def fill_config(args):
    """Derive the dependent settings from the parsed options, in place.

    Adds ``device``, ``dec_ninp`` (decoder input width: three hidden blocks
    when the title is used, two otherwise) and ``fnames`` (train/valid/test
    paths, in that order), then returns the same ``args`` object.
    """
    args.device = torch.device(args.gpu)
    hidden_blocks = 3 if args.title else 2
    args.dec_ninp = args.nhid * hidden_blocks
    args.fnames = [args.train_file, args.valid_file, args.test_file]
    return args
|
||||
|
||||
|
||||
def vocab_config(args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
    """Attach the five dataset vocabularies to ``args`` and return it."""
    vocabularies = (
        ('ent_vocab', ent_vocab),
        ('rel_vocab', rel_vocab),
        ('text_vocab', text_vocab),
        ('ent_text_vocab', ent_text_vocab),
        ('title_vocab', title_vocab),
    )
    for attr_name, vocab in vocabularies:
        setattr(args, attr_name, vocab)
    return args
|
||||
|
||||
|
||||
def get_args():
    """Parse the command line and return the options with derived fields filled in."""
    parser = argparse.ArgumentParser(description='Graph Writer in DGL')
    add = parser.add_argument

    # model size
    add('--nhid', default=500, type=int, help='hidden size')
    add('--nhead', default=4, type=int, help='number of heads')
    add('--head_dim', default=125, type=int, help='head dim')
    add('--weight_decay', default=0.0, type=float, help='weight decay')
    add('--prop', default=6, type=int, help='number of layers of gnn')

    # run mode
    add('--title', action='store_true', help='use title input')
    add('--test', action='store_true', help='inference mode')

    # training / decoding schedule
    add('--batch_size', default=32, type=int, help='batch_size')
    add('--beam_size', default=4, type=int, help='beam size, 1 for greedy')
    add('--epoch', default=20, type=int, help='training epoch')
    add('--beam_max_len', default=200, type=int, help='max length of the generated text')
    add('--enc_lstm_layers', default=2, type=int, help='number of layers of lstm')
    add('--lr', default=1e-1, type=float, help='learning rate')
    add('--clip', default=1, type=float, help='gradient clip')

    # regularisation
    add('--emb_drop', default=0.0, type=float, help='embedding dropout')
    add('--attn_drop', default=0.1, type=float, help='attention dropout')
    add('--drop', default=0.1, type=float, help='dropout')
    add('--lp', default=1.0, type=float, help='length penalty')
    add('--graph_enc', default='gtrans', type=str, help='gnn mode, we only support the graph transformer now')

    # files
    add('--train_file', default='data/unprocessed.train.json', type=str, help='training file')
    add('--valid_file', default='data/unprocessed.val.json', type=str, help='validation file')
    add('--test_file', default='data/unprocessed.test.json', type=str, help='test file')
    add('--save_dataset', default='data.pickle', type=str, help='save path of dataset')
    add('--save_model', default='saved_model.pt', type=str, help='save path of model')

    add('--gpu', default=0, type=int, help='gpu mode')

    parsed = parser.parse_args()
    return fill_config(parsed)
|
||||
|
||||
3
examples/pytorch/graphwriter/prepare_data.sh
Normal file
3
examples/pytorch/graphwriter/prepare_data.sh
Normal file
@@ -0,0 +1,3 @@
|
||||
# Download the AGENDA dataset archive and unpack it into ./data.
wget https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/AGENDA.tar.gz
# -p: do not fail when data/ already exists (e.g. when the script is re-run)
mkdir -p data
tar -C data/ -xvzf AGENDA.tar.gz
|
||||
6
examples/pytorch/graphwriter/run.sh
Normal file
6
examples/pytorch/graphwriter/run.sh
Normal file
@@ -0,0 +1,6 @@
|
||||
# Launch training in the background on GPU 0 with the title input enabled.
# Output (stdout and stderr) goes to train_1.log; nohup keeps the job alive
# after the shell exits.
nohup env CUDA_VISIBLE_DEVICES=0 python -u train.py --prop 6 --save_model tmp_model.pt --title > train_1.log 2>&1 &
# Additional runs on other GPUs (six-layer and two-layer variants); uncomment
# the ones you want to start alongside the run above.
#nohup env CUDA_VISIBLE_DEVICES=2 python -u train.py --prop 6 --save_model tmp_model1.pt --title > train_2.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=3 python -u train.py --prop 6 --save_model tmp_model2.pt --title > train_3.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=4 python -u train.py --prop 6 --save_model tmp_model3.pt --title > train_4.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=5 python -u train.py --prop 2 --save_model tmp_model4.pt --title > train_5.log 2>&1 &
#nohup env CUDA_VISIBLE_DEVICES=6 python -u train.py --prop 2 --save_model tmp_model5.pt --title > train_6.log 2>&1 &
|
||||
11
examples/pytorch/graphwriter/test.sh
Normal file
11
examples/pytorch/graphwriter/test.sh
Normal file
@@ -0,0 +1,11 @@
|
||||
# Decode the test set with the best checkpoint; train.py writes
# tmp_gold.txt (references) and tmp_pred.txt (hypotheses).
env CUDA_VISIBLE_DEVICES=0 python -u train.py --save_model tmp_model.ptbest --test --title --lp 1.0 --beam_size 1

# Fetch the Moses scoring scripts if they are not present yet.
# `[ ! -f FILE ]` tests for a missing file; the previous `[ ! FILE ]` only
# tested a non-empty string, was always false, and never downloaded anything.
if [ ! -f detokenizer.perl ]; then
    wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/tokenizer/detokenizer.perl
fi
if [ ! -f multi-bleu.perl ]; then
    wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/generic/multi-bleu.perl
fi
# multi-bleu-detok.perl is invoked below but was never fetched.
# NOTE(review): assumes the script exists at this path in the pinned commit -- confirm.
if [ ! -f multi-bleu-detok.perl ]; then
    wget https://raw.githubusercontent.com/moses-smt/mosesdecoder/8c5eaa1a122236bbf927bde4ec610906fea599e6/scripts/generic/multi-bleu-detok.perl
fi

# Detokenized copies for the detok BLEU score.
perl detokenizer.perl -l en < tmp_gold.txt > tmp_gold.txt.a
perl detokenizer.perl -l en < tmp_pred.txt > tmp_pred.txt.a

# Tokenized BLEU, then detokenized BLEU.
perl multi-bleu.perl tmp_gold.txt < tmp_pred.txt
perl multi-bleu-detok.perl tmp_gold.txt.a < tmp_pred.txt.a
|
||||
125
examples/pytorch/graphwriter/train.py
Normal file
125
examples/pytorch/graphwriter/train.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
from graphwriter import *
|
||||
from utlis import *
|
||||
from opts import *
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append('./pycocoevalcap')
|
||||
from pycocoevalcap.bleu.bleu import Bleu
|
||||
from pycocoevalcap.rouge.rouge import Rouge
|
||||
from pycocoevalcap.meteor.meteor import Meteor
|
||||
|
||||
|
||||
def train_one_epoch(model, dataloader, optimizer, args, epoch):
    """Train for one pass over ``dataloader`` and save an epoch checkpoint.

    The running loss is weighted by ``len(batch['tgt_text'])`` per batch.
    Raises ``ValueError`` as soon as a NaN loss is observed. The whole model is
    saved to ``args.save_model + str(epoch % 100)`` at the end.
    """
    model.train()
    loss_sum = 0.
    seen = 0.
    started = time.time()
    with tqdm(dataloader, desc='Train Ep '+str(epoch), mininterval=60) as progress:
        for batch in progress:
            log_probs = model(batch)
            target = batch['tgt_text']
            # index 0 is <PAD> and does not contribute to the loss
            loss = F.nll_loss(log_probs.view(-1, log_probs.shape[-1]), target.view(-1), ignore_index=0)
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            loss_value = loss.item()
            if loss_value != loss_value:  # NaN is the only value unequal to itself
                raise ValueError('NaN appear')
            n_rows = len(target)
            loss_sum += loss_value * n_rows
            seen += n_rows
            progress.set_postfix({'loss': loss_sum/seen}, refresh=False)
    elapsed = time.time()-started
    peak_gb = torch.cuda.max_memory_cached()/1024.0/1024.0/1024.0
    print('Train Ep ', str(epoch), 'AVG Loss ', loss_sum/seen, 'Steps ', seen, 'Time ', elapsed, 'GPU', peak_gb)
    torch.save(model, args.save_model+str(epoch%100))
|
||||
|
||||
|
||||
# Best validation loss seen so far in this process; eval_it() lowers it.
val_loss = 2**31


def eval_it(model, dataloader, args, epoch):
    """Compute the validation loss and keep the best model.

    When the average loss improves on the module-level ``val_loss``, the whole
    model is saved to ``args.save_model + 'best'`` and ``val_loss`` is updated.
    """
    global val_loss
    model.eval()
    loss_sum = 0.
    seen = 0.
    started = time.time()
    with tqdm(dataloader, desc='Eval Ep '+str(epoch), mininterval=60) as progress:
        for batch in progress:
            with torch.no_grad():
                log_probs = model(batch)
                target = batch['tgt_text']
                # index 0 is <PAD> and does not contribute to the loss
                loss_value = F.nll_loss(
                    log_probs.view(-1, log_probs.shape[-1]), target.view(-1), ignore_index=0
                ).item()
                n_rows = len(target)
                loss_sum += loss_value * n_rows
                seen += n_rows
                progress.set_postfix({'loss': loss_sum/seen}, refresh=False)
    avg_loss = loss_sum/seen
    print('Eval Ep ', str(epoch), 'AVG Loss ', avg_loss, 'Steps ', seen, 'Time ', time.time()-started)
    if avg_loss < val_loss:
        print('Saving best model ', 'Ep ', epoch, ' loss ', avg_loss)
        torch.save(model, args.save_model+'best')
        val_loss = avg_loss
|
||||
|
||||
|
||||
def test(model, dataloader, args):
    """Decode ``dataloader`` and print BLEU, METEOR and ROUGE-L.

    References are written to ``tmp_gold.txt`` and hypotheses to
    ``tmp_pred.txt`` (one line per example) for the external scoring scripts.
    """
    scorer = Bleu(4)
    m_scorer = Meteor()
    r_scorer = Rouge()
    hyp = []
    ref = []
    model.eval()
    # context managers close (and flush) both files even if decoding fails
    with open('tmp_gold.txt', 'w') as gold_file, open('tmp_pred.txt', 'w') as pred_file:
        with tqdm(dataloader, desc='Test ', mininterval=1) as tq:
            for batch in tq:
                with torch.no_grad():
                    seq = model(batch, beam_size=args.beam_size)
                r = write_txt(batch, batch['tgt_text'], gold_file, args)
                h = write_txt(batch, seq, pred_file, args)
                hyp.extend(h)
                ref.extend(r)
    # the scorers expect {example_id: [sentence]} mappings
    hyp = dict(zip(range(len(hyp)), hyp))
    ref = dict(zip(range(len(ref)), ref))
    print(hyp[0], ref[0])
    print('BLEU INP', len(hyp), len(ref))
    print('BLEU', scorer.compute_score(ref, hyp)[0])
    print('METEOR', m_scorer.compute_score(ref, hyp)[0])
    print('ROUGE_L', r_scorer.compute_score(ref, hyp)[0])
|
||||
|
||||
|
||||
def main(args):
    """Load (or build) the datasets, then either test a saved model or train.

    In test mode the model is loaded from ``args.save_model``; otherwise a new
    ``GraphWriter`` is trained with SGD for ``args.epoch`` epochs and evaluated
    on the validation set after every epoch.
    """
    if os.path.exists(args.save_dataset):
        # SECURITY NOTE: pickle executes arbitrary code on load; only point
        # --save_dataset at a cache file you created yourself.
        with open(args.save_dataset, 'rb') as cache_file:
            train_dataset, valid_dataset, test_dataset = pickle.load(cache_file)
    else:
        train_dataset, valid_dataset, test_dataset = get_datasets(args.fnames, device=args.device, save=args.save_dataset)
    args = vocab_config(args, train_dataset.ent_vocab, train_dataset.rel_vocab, train_dataset.text_vocab,
                        train_dataset.ent_text_vocab, train_dataset.title_vocab)
    # all three loaders share the training set's collate function
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_sampler=BucketSampler(train_dataset, batch_size=args.batch_size),
        collate_fn=train_dataset.batch_fn)
    valid_dataloader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=args.batch_size,
        shuffle=False, collate_fn=train_dataset.batch_fn)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size,
        shuffle=False, collate_fn=train_dataset.batch_fn)

    if args.test:
        # No throwaway model is built here: the checkpoint holds the whole
        # module. map_location places it on the requested device even if it
        # was saved from a different one.
        model = torch.load(args.save_model, map_location=args.device)
        model.args = args
        print(model)
        test(model, test_dataloader, args)
    else:
        model = GraphWriter(args)
        model.to(args.device)
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, momentum=0.9)
        print(model)
        for epoch in range(args.epoch):
            train_one_epoch(model, train_dataloader, optimizer, args, epoch)
            eval_it(model, valid_dataloader, args, epoch)
|
||||
|
||||
|
||||
# Script entry point: parse the command-line options, then train or test.
if __name__ == '__main__':
    args = get_args()
    main(args)
|
||||
321
examples/pytorch/graphwriter/utlis.py
Normal file
321
examples/pytorch/graphwriter/utlis.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import torch
|
||||
import dgl
|
||||
import numpy as np
|
||||
import json
|
||||
import pickle
|
||||
import random
|
||||
|
||||
|
||||
NODE_TYPE = {'entity': 0, 'root': 1, 'relation':2}
|
||||
|
||||
|
||||
def write_txt(batch, seqs, w_file, args):
    """Convert token-id sequences to text, write them out, and return them.

    Ids at or above ``len(args.text_vocab)`` denote a copied entity and are
    expanded to that entity's raw tokens from ``batch['raw_ent_text']``
    (dropping ``'<PAD>'``). ``<PAD>``/``<BOS>``/``<EOS>`` are not emitted, and
    a sequence stops at its first ``<EOS>``.

    One line per sequence is written to ``w_file``. Returns a list holding a
    one-element list ``[sentence]`` for every sequence.
    """
    vocab = args.text_vocab
    vocab_size = len(vocab)
    # loop-invariant lookups, hoisted out of the per-token loop
    eos_id = vocab('<EOS>')
    silent_ids = {vocab(x) for x in ['<PAD>', '<BOS>', '<EOS>']}
    ret = []
    for b, seq in enumerate(seqs):
        txt = []
        for token in seq:
            if token >= vocab_size:
                # copy the entity
                ent_text = batch['raw_ent_text'][b][token - vocab_size]
                txt.extend(x for x in ent_text if x != '<PAD>')
            else:
                token_id = int(token)
                if token_id not in silent_ids:
                    txt.append(vocab(token_id))
                if token_id == eos_id:
                    break
        line = ' '.join(str(x) for x in txt)
        w_file.write(line + '\n')
        ret.append([line])
    return ret
|
||||
|
||||
|
||||
def replace_ent(x, ent, V):
    """Replace copied-entity ids in ``x`` by the type id of that entity.

    ``x`` is a 1-D tensor of token ids. An id ``>= V`` means "copy entity
    number ``id - V`` of this row"; it is replaced by ``ent[row, id - V]``.
    Ids below ``V`` are left untouched. ``ent`` is indexed as
    (row, entity) and must have the same dtype as ``x``.

    Returns ``x`` itself when nothing needs replacing, otherwise a new tensor.
    """
    mask = x >= V
    if mask.sum() == 0:
        return x
    # Row index of every masked position, flattened to 1-D so that it pairs
    # element-wise with the column index. The former (k, 1)-shaped index
    # broadcast against the (k,) columns into a (k, k) gather, of which
    # masked_scatter consumed only the first row -- wrong for k > 1.
    rows = mask.nonzero().view(-1)
    fill_ent = ent[rows, x[mask] - V]
    return x.masked_scatter(mask, fill_ent)
|
||||
|
||||
|
||||
def len2mask(lens, device):
    """Build a padding mask from sequence lengths.

    Returns a bool tensor of shape (len(lens), max(lens)) on ``device`` that
    is True at every position at or beyond the row's length.
    """
    longest = max(lens)
    n_rows = len(lens)
    positions = torch.arange(longest, device=device).unsqueeze(0).expand(n_rows, longest)
    lengths = torch.LongTensor(lens).to(positions).unsqueeze(1)
    return positions >= lengths
|
||||
|
||||
|
||||
def pad(var_len_list, out_type='list', flatten=False):
    """Pad variable-length sequences to a common length.

    Parameters
    ----------
    var_len_list : list
        Sequences to pad: lists (``out_type='list'``) or tensors
        (``out_type='tensor'``). With ``flatten=True`` it is a list of such
        lists, which is flattened by one level before padding.
    out_type : {'list', 'tensor'}
        ``'list'`` pads with the string ``'<PAD>'``; ``'tensor'`` pads with
        zeros along dimension 0 and stacks the results.
    flatten : bool
        If True, also return the size of every group that was flattened.

    Returns
    -------
    padded, or (padded, lens) when ``flatten`` is True
        ``padded`` is a list of lists or a stacked tensor; ``lens`` holds the
        number of sequences in each original group.

    Raises
    ------
    ValueError
        If ``out_type`` is not supported, or if there is nothing to pad.
    """
    if out_type not in ('list', 'tensor'):
        # Previously an unknown out_type silently produced None.
        raise ValueError("out_type must be 'list' or 'tensor', got {!r}".format(out_type))

    lens = None
    if flatten:
        lens = [len(x) for x in var_len_list]
        # Single linear pass; sum(list_of_lists, []) re-copies the result at every step.
        var_len_list = [item for group in var_len_list for item in group]

    max_len = max(len(x) for x in var_len_list)
    if out_type == 'list':
        padded = [x + ['<PAD>'] * (max_len - len(x)) for x in var_len_list]
    else:
        # new_zeros keeps the dtype and device of the sequence being padded.
        padded = torch.stack(
            [torch.cat([x, x.new_zeros([max_len - len(x)] + list(x.shape[1:]))], 0)
             for x in var_len_list], 0)

    if flatten:
        return padded, lens
    return padded
|
||||
|
||||
|
||||
class Vocab(object):
    """Bidirectional token <-> id mapping built from token frequencies.

    Typical use: call :meth:`update` with the training tokens, then
    :meth:`build` once; afterwards the instance is callable in both
    directions (``vocab('word') -> id`` and ``vocab(3) -> 'word'``).
    """

    def __init__(self, max_vocab=2**31, min_freq=-1, sp=['<PAD>', '<BOS>', '<EOS>', '<UNK>']):
        """
        Parameters
        ----------
        max_vocab : int
            Upper bound on the vocabulary size (special tokens included).
        min_freq : int
            Tokens seen fewer than ``min_freq`` times are left out.
        sp : list of str
            Special tokens; they get the first ids, in the given order.
        """
        self.i2s = []   # id -> token
        self.s2i = {}   # token -> id
        self.wf = {}    # token -> frequency
        # Copy ``sp`` so that neither the shared default list nor a list owned
        # by the caller is aliased by this instance.
        self.max_vocab, self.min_freq, self.sp = max_vocab, min_freq, list(sp)

    def __len__(self):
        return len(self.i2s)

    def __str__(self):
        return 'Total ' + str(len(self.i2s)) + str(self.i2s[:10])

    def update(self, token):
        """Count ``token``; lists (nested to any depth) are counted element-wise."""
        if isinstance(token, list):
            for t in token:
                self.update(t)
        else:
            self.wf[token] = self.wf.get(token, 0) + 1

    def build(self):
        """Assign ids: special tokens first, then tokens by decreasing frequency.

        The tables are rebuilt from scratch, so calling this method again
        (e.g. after more :meth:`update` calls) does not duplicate entries.
        """
        self.i2s = list(self.sp)
        self.s2i = {}
        specials = set(self.sp)
        sort_kv = sorted(self.wf.items(), key=lambda x: x[1], reverse=True)
        for k, v in sort_kv:
            if len(self.i2s) < self.max_vocab and v >= self.min_freq and k not in specials:
                self.i2s.append(k)
        self.s2i.update(zip(self.i2s, range(len(self.i2s))))

    def __call__(self, x):
        """Map an ``int`` id to its token, or any other key to its id.

        Unknown keys map to the id of ``'<UNK>'``. Only genuine ``int``
        objects are treated as ids, so scalar tensors must be converted with
        ``int()`` first.
        """
        if isinstance(x, int):
            return self.i2s[x]
        else:
            return self.s2i.get(x, self.s2i['<UNK>'])

    def save(self, fname):
        """Placeholder: persistence is not implemented, the call does nothing."""
        pass

    def load(self, fname):
        """Placeholder: persistence is not implemented, the call does nothing."""
        pass
|
||||
|
||||
def at_least(x):
    """Return ``x`` unchanged, or ``['<UNK>']`` when ``x`` is empty."""
    # handling the illegal data
    return x if len(x) != 0 else ['<UNK>']
|
||||
|
||||
class Example(object):
    """One data sample: tokenized title, entities, relations, target text and its graph."""

    def __init__(self, title, ent_text, ent_type, rel, text):
        """Tokenize the raw fields (whitespace split) and build the sample graph.

        ``ent_text`` and ``rel`` are iterated element by element; ``title``,
        ``ent_type`` and ``text`` are single strings.
        """
        # one object corresponds to a data sample
        self.raw_title = title.split()
        # an entity with no words is replaced by ['<UNK>'] (see at_least)
        self.raw_ent_text = [at_least(x.split()) for x in ent_text]
        assert min([len(x) for x in self.raw_ent_text])>0, str(self.raw_ent_text)
        self.raw_ent_type = ent_type.split() # <method> .. <>
        self.raw_rel = []
        for r in rel:
            rel_list = r.split()
            # Find the first token that sits between two '--' tokens and is not
            # all lower-case; it is taken as the relation name. Everything
            # before the leading '--' is the head entity, everything after the
            # trailing '--' is the tail entity. A relation string without such
            # a token contributes nothing to raw_rel.
            for i in range(len(rel_list)):
                if i>0 and i<len(rel_list)-1 and rel_list[i-1]=='--' and rel_list[i]!=rel_list[i].lower() and rel_list[i+1]=='--':
                    self.raw_rel.append([rel_list[:i-1], rel_list[i-1]+rel_list[i]+rel_list[i+1], rel_list[i+2:]])
                    break
        self.raw_text = text.split()
        self.graph = self.build_graph()

    def __str__(self):
        """Dump every instance attribute as ``name:<TAB>value``, one per line."""
        return '\n'.join([str(k)+':\t'+str(v) for k, v in self.__dict__.items()])

    def __len__(self):
        """Number of tokens in the target text."""
        return len(self.raw_text)

    @staticmethod
    def from_json(json_data):
        """Build an Example from a record with the keys 'title', 'entities', 'types', 'relations' and 'abstract'."""
        return Example(json_data['title'], json_data['entities'], json_data['types'], json_data['relations'], json_data['abstract'])

    def build_graph(self):
        """Create the DGL graph of this sample.

        Node ids, in insertion order: entities ``0 .. ent_len-1``, the root
        node ``ent_len``, then two nodes per relation (``ent_len+1+2*i`` and
        ``ent_len+1+2*i+1`` for relation ``i``).
        """
        graph = dgl.DGLGraph()
        ent_len = len(self.raw_ent_text)
        rel_len = len(self.raw_rel) # treat the repeated relation as different nodes, refer to the author's code

        # the 'type' feature is torch.ones(...) scaled by the NODE_TYPE code
        graph.add_nodes(ent_len, {'type': torch.ones(ent_len) * NODE_TYPE['entity']})
        graph.add_nodes(1, {'type': torch.ones(1) * NODE_TYPE['root']})
        graph.add_nodes(rel_len*2, {'type': torch.ones(rel_len*2) * NODE_TYPE['relation']})
        # root -> every entity, every entity -> root, then a self-loop on every node
        graph.add_edges(ent_len, torch.arange(ent_len))
        graph.add_edges(torch.arange(ent_len), ent_len)
        graph.add_edges(torch.arange(ent_len+1+rel_len*2), torch.arange(ent_len+1+rel_len*2))
        adj_edges = []
        for i, r in enumerate(self.raw_rel):
            assert len(r)==3, str(r)
            st, rt, ed = r
            # list.index raises ValueError when the head/tail word list is not one of the entities
            st_ent, ed_ent = self.raw_ent_text.index(st), self.raw_ent_text.index(ed)
            # according to the edge_softmax operator, we need to reverse the graph
            # node ent_len+1+2*i:   edges (relation -> head) and (tail -> relation)
            adj_edges.append([ent_len+1+2*i, st_ent])
            adj_edges.append([ed_ent, ent_len+1+2*i])
            # node ent_len+1+2*i+1: edges (relation -> tail) and (head -> relation)
            adj_edges.append([ent_len+1+2*i+1, ed_ent])
            adj_edges.append([st_ent, ent_len+1+2*i+1])

        if len(adj_edges)>0:
            # transpose the [src, dst] pairs into a src list and a dst list
            graph.add_edges(*list(map(list, zip(*adj_edges))))
        return graph

    def get_tensor(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
        """Return the sample as a dict of id tensors (plus graph and raw entity words).

        The result is cached on the instance after the first call; later calls
        return the cached dict and ignore their arguments. ``ent_vocab`` is not
        read by this method.
        """
        if hasattr(self, '_cached_tensor'):
            return self._cached_tensor
        else:
            title_data = ['<BOS>'] + self.raw_title + ['<EOS>']
            title = [title_vocab(x) for x in title_data]
            ent_text = [[ent_text_vocab(y) for y in x] for x in self.raw_ent_text]
            ent_type = [text_vocab(x) for x in self.raw_ent_type] # for inference
            # one '--root--' entry followed by (name, name+'_INV') for every relation
            rel_data = ['--root--'] + sum([[x[1],x[1]+'_INV'] for x in self.raw_rel], [])
            rel = [rel_vocab(x) for x in rel_data]

            text_data = ['<BOS>'] + self.raw_text + ['<EOS>']
            text = [text_vocab(x) for x in text_data]
            tgt_text = []
            # the input text and decoding target are different since the consideration of the copy mechanism.
            for i, str1 in enumerate(text_data):
                # tokens of the form '<a_b>': the input side gets the id of '<a>',
                # the target side gets the pointer id len(text_vocab)+int(b)
                if str1[0]=='<' and str1[-1]=='>' and '_' in str1:
                    a, b = str1[1:-1].split('_')
                    text[i] = text_vocab('<'+a+'>')
                    tgt_text.append(len(text_vocab)+int(b))
                else:
                    tgt_text.append(text[i])
            # 'text' drops the final token (decoder input), 'tgt_text' drops the first (decoder target)
            self._cached_tensor = {'title': torch.LongTensor(title), 'ent_text': [torch.LongTensor(x) for x in ent_text], \
                'ent_type': torch.LongTensor(ent_type), 'rel': torch.LongTensor(rel), \
                'text': torch.LongTensor(text[:-1]), 'tgt_text': torch.LongTensor(tgt_text[1:]), 'graph': self.graph, 'raw_ent_text': self.raw_ent_text}
            return self._cached_tensor

    def update_vocab(self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab):
        """Add this sample's tokens to the frequency counts of the given vocabularies."""
        ent_vocab.update(self.raw_ent_type)
        ent_text_vocab.update(self.raw_ent_text)
        title_vocab.update(self.raw_title)
        rel_vocab.update(['--root--']+[x[1] for x in self.raw_rel]+[x[1]+'_INV' for x in self.raw_rel])
        # entity type tokens are also counted into the text vocabulary
        text_vocab.update(self.raw_ent_type)
        text_vocab.update(self.raw_text)
|
||||
|
||||
class BucketSampler(torch.utils.data.Sampler):
    """Batch sampler that groups examples of similar length.

    Examples are shuffled, split into three length buckets (< 100,
    101..219, everything else), the buckets are visited in random order and
    batches are cut with a size cap that shrinks as the longest example in the
    batch grows (32 / 24 / 8).

    Parameters
    ----------
    data_source : sequence
        Dataset whose items support ``len()``.
    batch_size : int
        Only used by ``__len__``; the caps applied while iterating are the
        fixed values taken from the reference implementation.
    bucket : int
        Stored but not used.
    """

    def __init__(self, data_source, batch_size=32, bucket=3):
        self.data_source = data_source
        self.bucket = bucket
        self.batch_size = batch_size

    def __iter__(self):
        """Yield lists of dataset indices (plain ints), one list per batch."""
        # the magic number comes from the author's code
        perm = torch.randperm(len(self.data_source)).tolist()
        # Lengths stay indexed by dataset position, so every look-up below uses
        # the same index that is stored in the batch. (Permuting the length
        # table and then indexing it with dataset indices reads the length of
        # a different example.)
        lens = [len(x) for x in self.data_source]

        t1, t2, t3 = [], [], []
        for idx in perm:
            l = lens[idx]
            if l < 100:
                t1.append(idx)
            elif l > 100 and l < 220:
                t2.append(idx)
            else:
                # also receives the boundary length l == 100
                t3.append(idx)
        datas = [t1, t2, t3]
        random.shuffle(datas)

        batch = []
        mlen = 0  # longest example in the current batch (running maximum)
        for members in datas:
            for idx in members:
                batch.append(idx)
                mlen = max(mlen, lens[idx])
                if (mlen<100 and len(batch) == 32) or (mlen>100 and mlen<220 and len(batch) >= 24) or (mlen>220 and len(batch)>=8) or len(batch)==32:
                    yield batch
                    batch = []
                    mlen = 0
        if len(batch) > 0:
            yield batch

    def __len__(self):
        """``ceil(len(data_source) / batch_size)``; the number of batches actually yielded can differ."""
        return (len(self.data_source)+self.batch_size-1)//self.batch_size
|
||||
|
||||
|
||||
class GWdataset(torch.utils.data.Dataset):
    """Dataset of Example objects together with the vocabularies used to tensorize them.

    Parameters
    ----------
    exs : list
        The Example objects of this split.
    ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab : Vocab
        Vocabularies handed to ``Example.get_tensor`` when collating.
    device : torch.device or str, optional
        Target device of the tensors (and graph) produced by ``batch_fn``.
    """

    def __init__(self, exs, ent_vocab=None, rel_vocab=None, text_vocab=None, ent_text_vocab=None, title_vocab=None, device=None):
        super(GWdataset, self).__init__()
        self.exs = exs
        self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab, self.device = \
            ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device

    def __iter__(self):
        return iter(self.exs)

    def __getitem__(self, index):
        return self.exs[index]

    def __len__(self):
        return len(self.exs)

    def batch_fn(self, batch_ex):
        """Collate a list of Example objects into one batch dictionary.

        Returns a dict with the padded tensors 'title', 'ent_text',
        'ent_type', 'rel', 'text' and 'tgt_text' (moved to ``self.device``),
        'ent_len' (number of entities per example), the batched 'graph' and
        the untouched 'raw_ent_text' word lists.
        """
        batch_title, batch_ent_text, batch_ent_type, batch_rel, batch_text, batch_tgt_text, batch_graph = \
            [], [], [], [], [], [], []
        batch_raw_ent_text = []
        for ex in batch_ex:
            ex_data = ex.get_tensor(self.ent_vocab, self.rel_vocab, self.text_vocab, self.ent_text_vocab, self.title_vocab)
            batch_title.append(ex_data['title'])
            batch_ent_text.append(ex_data['ent_text'])
            batch_ent_type.append(ex_data['ent_type'])
            batch_rel.append(ex_data['rel'])
            batch_text.append(ex_data['text'])
            batch_tgt_text.append(ex_data['tgt_text'])
            batch_graph.append(ex_data['graph'])
            batch_raw_ent_text.append(ex_data['raw_ent_text'])
        batch_title = pad(batch_title, out_type='tensor')
        # entities of all examples are flattened into one padded tensor;
        # ent_len remembers how many belong to each example
        batch_ent_text, ent_len = pad(batch_ent_text, out_type='tensor', flatten=True)
        batch_ent_type = pad(batch_ent_type, out_type='tensor')
        batch_rel = pad(batch_rel, out_type='tensor')
        batch_text = pad(batch_text, out_type='tensor')
        batch_tgt_text = pad(batch_tgt_text, out_type='tensor')
        batch_graph = dgl.batch(batch_graph)
        # NOTE(review): depending on the DGL release, ``to`` either moves the
        # graph in place and returns None, or returns the moved graph and
        # leaves the original untouched -- confirm against the installed
        # version. Keeping the returned graph when there is one covers both.
        moved_graph = batch_graph.to(self.device)
        if moved_graph is not None:
            batch_graph = moved_graph
        return {'title': batch_title.to(self.device), 'ent_text': batch_ent_text.to(self.device), 'ent_len': ent_len, \
            'ent_type': batch_ent_type.to(self.device), 'rel': batch_rel.to(self.device), 'text': batch_text.to(self.device), \
            'tgt_text': batch_tgt_text.to(self.device), 'graph': batch_graph, 'raw_ent_text': batch_raw_ent_text}
|
||||
|
||||
|
||||
def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, save='tmp.pickle'):
    """Load the JSON splits, build the vocabularies and return one GWdataset per file.

    Parameters
    ----------
    fnames : list of str
        Paths of the JSON files; the FIRST one is the training split and is
        the only one used to build the vocabularies.
    min_freq : int
        Not supported for now since the final results are very sensitive to
        it; it can be set by passing ``min_freq`` to the Vocab class.
    sep : str
        Not supported for now.
    joint_vocab : bool
        Not supported for now.
    device : torch.device or str, optional
        Forwarded to every GWdataset.
    save : str
        Path the list of datasets is pickled to.

    Returns
    -------
    list of GWdataset
        One dataset per entry of ``fnames``, in the same order.
    """
    ent_vocab = Vocab(sp=['<PAD>', '<UNK>'])
    title_vocab = Vocab(min_freq=5)
    rel_vocab = Vocab(sp=['<PAD>', '<UNK>'])
    text_vocab = Vocab(min_freq=5)
    ent_text_vocab = Vocab(sp=['<PAD>', '<UNK>'])

    datasets = []
    for file_idx, fname in enumerate(fnames):
        # the context manager closes the file as soon as it has been parsed
        with open(fname) as f:
            json_datas = json.load(f)
        exs = []
        for json_data in json_datas:
            # construct one data example
            ex = Example.from_json(json_data)
            # Only the training set feeds the vocabularies. The position is
            # tested rather than the file name, so a path that is listed again
            # for another split is not counted twice.
            if file_idx == 0:
                ex.update_vocab(ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab)
            exs.append(ex)
        datasets.append(exs)

    ent_vocab.build()
    rel_vocab.build()
    text_vocab.build()
    ent_text_vocab.build()
    title_vocab.build()

    datasets = [GWdataset(exs, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab, device) for exs in datasets]
    with open(save, 'wb') as f:
        pickle.dump(datasets, f)
    return datasets
|
||||
|
||||
|
||||
if __name__ == '__main__' :
    # Smoke test: load the splits, then show the first example of the first
    # dataset both in raw form and as tensors.
    ds = get_datasets(['data/unprocessed.val.json', 'data/unprocessed.val.json', 'data/unprocessed.test.json'])
    first_set = ds[0]
    first_example = first_set.exs[0]
    print(first_example)
    print(first_example.get_tensor(
        first_set.ent_vocab,
        first_set.rel_vocab,
        first_set.text_vocab,
        first_set.ent_text_vocab,
        first_set.title_vocab,
    ))
|
||||
|
||||
@@ -63,16 +63,16 @@ as front end and Set2Set for output prediction.
|
||||
### Example Usage of Pre-trained Models
|
||||
|
||||
```python
|
||||
from dgl.data.chem import Tox21
|
||||
from dgl.data.chem import Tox21, smiles_to_bigraph, CanonicalAtomFeaturizer
|
||||
from dgl import model_zoo
|
||||
|
||||
dataset = Tox21()
|
||||
dataset = Tox21(smiles_to_bigraph, CanonicalAtomFeaturizer())
|
||||
model = model_zoo.chem.load_pretrained('GCN_Tox21') # Pretrained model loaded
|
||||
model.eval()
|
||||
|
||||
smiles, g, label, mask = dataset[0]
|
||||
feats = g.ndata.pop('h')
|
||||
label_pred = model(feats, g)
|
||||
label_pred = model(g, feats)
|
||||
print(smiles) # CCOc1ccc2nc(S(N)(=O)=O)sc2c1
|
||||
print(label_pred[:, mask != 0]) # Mask non-existing labels
|
||||
# tensor([[-0.7956, 0.4054, 0.4288, -0.5565, -0.0911,
|
||||
|
||||
@@ -21,7 +21,7 @@ __all__ = [
|
||||
'to_networkx',
|
||||
]
|
||||
|
||||
def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
|
||||
def graph(data, ntype='_N', etype='_E', card=None, validate=False, **kwargs):
|
||||
"""Create a graph with one type of nodes and edges.
|
||||
|
||||
In the sparse matrix perspective, :func:`dgl.graph` creates a graph
|
||||
@@ -45,6 +45,10 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
|
||||
card : int, optional
|
||||
Cardinality (number of nodes in the graph). If None, infer from input data, i.e.
|
||||
the largest node ID plus 1. (Default: None)
|
||||
validate : bool, optional
|
||||
If True, check if node ids are within cardinality, the check process may take
|
||||
some time.
|
||||
If False and card is not None, user would receive a warning. (Default: False)
|
||||
kwargs : key-word arguments, optional
|
||||
Other key word arguments. Only comes into effect when we are using a NetworkX
|
||||
graph. It can consist of:
|
||||
@@ -101,6 +105,16 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
|
||||
['follows']
|
||||
>>> g.canonical_etypes
|
||||
[('user', 'follows', 'user')]
|
||||
|
||||
Check if node ids are within cardinality
|
||||
|
||||
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=2, validate=True)
|
||||
...
|
||||
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
|
||||
>>> g = dgl.graph(([0, 1, 2], [1, 2, 0]), card=3, validate=True)
|
||||
Graph(num_nodes=3, num_edges=3,
|
||||
ndata_schemes={}
|
||||
edata_schemes={})
|
||||
"""
|
||||
if card is not None:
|
||||
urange, vrange = card, card
|
||||
@@ -108,9 +122,9 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
|
||||
urange, vrange = None, None
|
||||
if isinstance(data, tuple):
|
||||
u, v = data
|
||||
return create_from_edges(u, v, ntype, etype, ntype, urange, vrange)
|
||||
return create_from_edges(u, v, ntype, etype, ntype, urange, vrange, validate)
|
||||
elif isinstance(data, list):
|
||||
return create_from_edge_list(data, ntype, etype, ntype, urange, vrange)
|
||||
return create_from_edge_list(data, ntype, etype, ntype, urange, vrange, validate)
|
||||
elif isinstance(data, sp.sparse.spmatrix):
|
||||
return create_from_scipy(data, ntype, etype, ntype)
|
||||
elif isinstance(data, nx.Graph):
|
||||
@@ -118,7 +132,7 @@ def graph(data, ntype='_N', etype='_E', card=None, **kwargs):
|
||||
else:
|
||||
raise DGLError('Unsupported graph data type:', type(data))
|
||||
|
||||
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
|
||||
def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, validate=False, **kwargs):
|
||||
"""Create a bipartite graph.
|
||||
|
||||
The result graph is directed and edges must be from ``utype`` nodes
|
||||
@@ -147,6 +161,10 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
|
||||
card : pair of int, optional
|
||||
Cardinality (number of nodes in the source and destination group). If None,
|
||||
infer from input data, i.e. the largest node ID plus 1 for each type. (Default: None)
|
||||
validate : bool, optional
|
||||
If True, check if node ids are within cardinality, the check process may take
|
||||
some time.
|
||||
If False and card is not None, user would receive a warning. (Default: False)
|
||||
kwargs : key-word arguments, optional
|
||||
Other key word arguments. Only comes into effect when we are using a NetworkX
|
||||
graph. It can consist of:
|
||||
@@ -215,6 +233,16 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
|
||||
4
|
||||
>>> g.edges()
|
||||
(tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]), tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]))
|
||||
|
||||
Check if node ids are within cardinality
|
||||
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(2, 4), validate=True)
|
||||
...
|
||||
dgl._ffi.base.DGLError: Invalid node id 2 (should be less than cardinality 2).
|
||||
>>> g = dgl.bipartite(([0, 1, 2], [1, 2, 3]), card=(3, 4), validate=True)
|
||||
>>> g
|
||||
Graph(num_nodes={'_U': 3, '_V': 4},
|
||||
num_edges={('_U', '_E', '_V'): 3},
|
||||
metagraph=[('_U', '_V')])
|
||||
"""
|
||||
if utype == vtype:
|
||||
raise DGLError('utype should not be equal to vtype. Use ``dgl.graph`` instead.')
|
||||
@@ -224,9 +252,9 @@ def bipartite(data, utype='_U', etype='_E', vtype='_V', card=None, **kwargs):
|
||||
urange, vrange = None, None
|
||||
if isinstance(data, tuple):
|
||||
u, v = data
|
||||
return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
|
||||
return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)
|
||||
elif isinstance(data, list):
|
||||
return create_from_edge_list(data, utype, etype, vtype, urange, vrange)
|
||||
return create_from_edge_list(data, utype, etype, vtype, urange, vrange, validate)
|
||||
elif isinstance(data, sp.sparse.spmatrix):
|
||||
return create_from_scipy(data, utype, etype, vtype)
|
||||
elif isinstance(data, nx.Graph):
|
||||
@@ -667,7 +695,7 @@ def to_homo(G):
|
||||
# Internal APIs
|
||||
############################################################
|
||||
|
||||
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
|
||||
def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None, validate=False):
|
||||
"""Internal function to create a graph from incident nodes with types.
|
||||
|
||||
utype could be equal to vtype
|
||||
@@ -690,6 +718,8 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
|
||||
vrange : int, optional
|
||||
The destination node ID range. If None, the value is the
|
||||
maximum of the destination node IDs in the edge list plus 1. (Default: None)
|
||||
validate : bool, optional
|
||||
If True, checks if node IDs are within range.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -697,6 +727,13 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
|
||||
"""
|
||||
u = utils.toindex(u)
|
||||
v = utils.toindex(v)
|
||||
if validate:
|
||||
if urange is not None and urange <= int(F.asnumpy(F.max(u.tousertensor(), dim=0))):
|
||||
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
|
||||
urange, int(F.asnumpy(F.max(u.tousertensor(), dim=0)))))
|
||||
if vrange is not None and vrange <= int(F.asnumpy(F.max(v.tousertensor(), dim=0))):
|
||||
raise DGLError('Invalid node id {} (should be less than cardinality {}).'.format(
|
||||
vrange, int(F.asnumpy(F.max(v.tousertensor(), dim=0)))))
|
||||
urange = urange or (int(F.asnumpy(F.max(u.tousertensor(), dim=0))) + 1)
|
||||
vrange = vrange or (int(F.asnumpy(F.max(v.tousertensor(), dim=0))) + 1)
|
||||
if utype == vtype:
|
||||
@@ -710,7 +747,7 @@ def create_from_edges(u, v, utype, etype, vtype, urange=None, vrange=None):
|
||||
else:
|
||||
return DGLHeteroGraph(hgidx, [utype, vtype], [etype])
|
||||
|
||||
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
|
||||
def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None, validate=False):
|
||||
"""Internal function to create a heterograph from a list of edge tuples with types.
|
||||
|
||||
utype could be equal to vtype
|
||||
@@ -731,6 +768,9 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
|
||||
vrange : int, optional
|
||||
The destination node ID range. If None, the value is the
|
||||
maximum of the destination node IDs in the edge list plus 1. (Default: None)
|
||||
validate : bool, optional
|
||||
If True, checks if node IDs are within range.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -742,7 +782,7 @@ def create_from_edge_list(elist, utype, etype, vtype, urange=None, vrange=None):
|
||||
u, v = zip(*elist)
|
||||
u = list(u)
|
||||
v = list(v)
|
||||
return create_from_edges(u, v, utype, etype, vtype, urange, vrange)
|
||||
return create_from_edges(u, v, utype, etype, vtype, urange, vrange, validate)
|
||||
|
||||
def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
|
||||
"""Internal function to create a heterograph from a scipy sparse matrix with types.
|
||||
@@ -762,6 +802,9 @@ def create_from_scipy(spmat, utype, etype, vtype, with_edge_id=False):
|
||||
If True, the entries in the sparse matrix are treated as edge IDs.
|
||||
Otherwise, the entries are ignored and edges will be added in
|
||||
(source, destination) order.
|
||||
validate : bool, optional
|
||||
If True, checks if node IDs are within range.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -7,6 +7,7 @@ import itertools
|
||||
import backend as F
|
||||
import networkx as nx
|
||||
import unittest
|
||||
from dgl import DGLError
|
||||
|
||||
def create_test_heterograph():
|
||||
# test heterograph from the docstring, plus a user -- wishes -- game relation
|
||||
@@ -94,6 +95,36 @@ def test_create():
|
||||
assert g.number_of_nodes('l1') == 3
|
||||
assert g.number_of_nodes('l2') == 4
|
||||
|
||||
# test if validate flag works
|
||||
# homo graph
|
||||
fail = False
|
||||
try:
|
||||
g = dgl.graph(
|
||||
([0, 0, 0, 1, 1, 2], [0, 1, 2, 0, 1, 2]),
|
||||
card=2,
|
||||
validate=True
|
||||
)
|
||||
except DGLError:
|
||||
fail = True
|
||||
finally:
|
||||
assert fail, "should catch a DGLError because node ID is out of bound."
|
||||
# bipartite graph
|
||||
def _test_validate_bipartite(card):
|
||||
fail = False
|
||||
try:
|
||||
g = dgl.bipartite(
|
||||
([0, 0, 1, 1, 2], [1, 1, 2, 2, 3]),
|
||||
card=card,
|
||||
validate=True
|
||||
)
|
||||
except DGLError:
|
||||
fail = True
|
||||
finally:
|
||||
assert fail, "should catch a DGLError because node ID is out of bound."
|
||||
|
||||
_test_validate_bipartite((3, 3))
|
||||
_test_validate_bipartite((2, 4))
|
||||
|
||||
def test_query():
|
||||
g = create_test_heterograph()
|
||||
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
.. _tutorials2-index:
|
||||
|
||||
Dealing with many small graphs
|
||||
Batching many small graphs
|
||||
==============================
|
||||
|
||||
* **Tree-LSTM** `[paper] <https://arxiv.org/abs/1503.00075>`__ `[tutorial]
|
||||
<2_small_graph/3_tree-lstm.html>`__ `[code]
|
||||
<2_small_graph/3_tree-lstm.html>`__ `[PyTorch code]
|
||||
<https://github.com/dmlc/dgl/blob/master/examples/pytorch/tree_lstm>`__:
|
||||
sentences of natural languages have inherent structures, which are thrown
|
||||
Sentences have inherent structures that are thrown
|
||||
away by treating them simply as sequences. Tree-LSTM is a powerful model
|
||||
that learns the representation by leveraging prior syntactic structures
|
||||
(e.g. parse-tree). The challenge to train it well is that simply by padding
|
||||
a sentence to the maximum length no longer works, since trees of different
|
||||
that learns the representation by using prior syntactic structures such as a parse-tree.
|
||||
The challenge in training is that simply padding
|
||||
a sentence to the maximum length no longer works. Trees of different
|
||||
sentences have different sizes and topologies. DGL solves this problem by
|
||||
throwing the trees into a bigger "container" graph, and use message-passing
|
||||
to explore maximum parallelism. The key API we use is batching.
|
||||
adding the trees to a bigger container graph, and then using message-passing
|
||||
to explore maximum parallelism. Batching is a key API for this.
|
||||
|
||||
Reference in New Issue
Block a user