mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
figuring out symmetric configurations
This commit is contained in:
@@ -7,7 +7,7 @@ from data_loader import get_train_valid_set, Dataset, DatasetNAComplex, DatasetR
|
||||
from kinematics import xyz_to_c6d, xyz_to_t2d
|
||||
from chemical import num2aa, aa2elt, aa2num
|
||||
from loss import compute_general_FAPE, resolve_equiv_natives, calc_str_loss
|
||||
from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz
|
||||
from util import get_frames, frame_indices, is_atom, xyz_to_frame_xyz, xyz_t_to_frame_xyz, long2alt
|
||||
|
||||
class LossTestCase(unittest.TestCase):
|
||||
|
||||
@@ -244,6 +244,7 @@ class LossTestCase(unittest.TestCase):
|
||||
break
|
||||
|
||||
def test_res_mask(self):
|
||||
"""updated res_mask to not mask "atom" nodes that only have one backbone atom filled in """
|
||||
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
|
||||
true_crds, atom_mask = resolve_equiv_natives(true_crds[0, 0].unsqueeze(0), true_crds, atom_mask)
|
||||
res_mask = ~((atom_mask[:,:,:3].sum(dim=-1) < 3.0) * ~(is_atom(msa[:,0,0])))
|
||||
@@ -252,6 +253,12 @@ class LossTestCase(unittest.TestCase):
|
||||
self.assertEqual(res_mask.shape[1], L)
|
||||
break
|
||||
|
||||
def test_resolve_equiv_natives(self):
|
||||
""" test that resolve_equiv_natives works"""
|
||||
for seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, atom_frames, bond_feats in self.valid_sm_compl_loader:
|
||||
true_crds, atom_mask = resolve_equiv_natives(torch.randn(true_crds[0, 0].unsqueeze(0).shape), true_crds, atom_mask)
|
||||
print(true_crds)
|
||||
break
|
||||
|
||||
class DataLoaderTestCase(unittest.TestCase):
|
||||
|
||||
@@ -283,6 +290,13 @@ class DataLoaderTestCase(unittest.TestCase):
|
||||
i = torch.randint(2, L-3,(1,))
|
||||
residues_atomize = msa[:, 0, 0, i-2:i+2]
|
||||
residues_atomize = [aa2elt[num][:14] for num in residues_atomize[0]]
|
||||
|
||||
true_alt = torch.zeros_like(true_crds)
|
||||
true_alt.scatter_(2, long2alt[msa[:, 0, 0],:,None].repeat(1,1,1,3), true_crds)
|
||||
print(true_crds.shape)
|
||||
print((true_crds[:, i-2:i+2] == true_alt[:, i-2:i+2]).all(dim=2).all(dim=2).squeeze())
|
||||
print(torch.nonzero(~(true_crds[:, i-2:i+2] == true_alt[:, i-2:i+2]).all(dim=2).all(dim=2).squeeze()).squeeze())
|
||||
|
||||
lig_seq = []
|
||||
ra = []
|
||||
for idx in range(len(residues_atomize)):
|
||||
@@ -310,21 +324,16 @@ class DataLoaderTestCase(unittest.TestCase):
|
||||
atom_mask = torch.cat((atom_mask[ :, :i-2], atom_mask[ :, i+2:]), dim=1)
|
||||
|
||||
idx_pdb = torch.cat((idx_pdb[ :, :i-2], idx_pdb[ :, i+2:]), dim=1)
|
||||
print(idx_pdb.shape)
|
||||
xyz_t = torch.cat((xyz_t[ :, :, :i-2], xyz_t[:, :, i+2:]), dim=2)
|
||||
print(xyz_t.shape)
|
||||
t1d = torch.cat((t1d[ :, :, :i-2], t1d[:, :, i+2:]), dim=2)
|
||||
print(t1d.shape)
|
||||
xyz_prev = torch.cat((xyz_prev[ :, :i-2], xyz_prev[:, i+2:]), dim=1)
|
||||
print(xyz_prev.shape)
|
||||
same_chain = torch.cat((same_chain[ :, :i-2], same_chain[:, i+2:]), dim=1)
|
||||
same_chain = torch.cat((same_chain[ :, :, :i-2], same_chain[:, :, i+2:]), dim=2)
|
||||
print(same_chain.shape)
|
||||
|
||||
print(bond_feats.shape)
|
||||
bond_feats = torch.cat((bond_feats[ :, :i-2], bond_feats[:, i+2:]), dim=1)
|
||||
bond_feats = torch.cat((bond_feats[ :, :, :i-2], bond_feats[:, :, i+2:]), dim=2)
|
||||
print(bond_feats.shape)
|
||||
|
||||
|
||||
break
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -22,7 +22,7 @@ from scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedu
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
#torch.autograd.set_detect_anomaly(True)
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
torch.manual_seed(5924)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@@ -175,7 +175,7 @@ class Trainer():
|
||||
|
||||
loss_s = list()
|
||||
tot_loss = 0.0
|
||||
|
||||
|
||||
# c6d loss
|
||||
for i in range(4):
|
||||
loss = self.loss_fn(logit_s[i], label_s[...,i]) # (B, L, L)
|
||||
@@ -243,6 +243,7 @@ class Trainer():
|
||||
# get alternative coordinates for ground-truth
|
||||
true_alt = torch.zeros_like(true)
|
||||
true_alt.scatter_(2, self.l2a[seq,:,None].repeat(1,1,1,3), true)
|
||||
print(true_alt)
|
||||
natRs_all, _n0 = self.compute_allatom_coords(seq, true[...,:3,:], true_tors)
|
||||
natRs_all_alt, _n1 = self.compute_allatom_coords(seq, true_alt[...,:3,:], true_tors_alt)
|
||||
predTs = pred[-1,...]
|
||||
@@ -348,10 +349,10 @@ class Trainer():
|
||||
chain2 = torch.zeros_like(same_chain, dtype=bool)
|
||||
chain2[:,L0:,L0:] = True
|
||||
_, allatom_lddt_c2 = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True)
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, chain2, negative=True, bin_scaling=0.5)
|
||||
loss_s.append(allatom_lddt_c2.detach())
|
||||
_, allatom_lddt_inter = calc_allatom_lddt_loss(
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True, bin_scaling=0.5)
|
||||
pred_allatom.detach(), nat_symm, pred_lddt, idx, mask_crds, mask_2d, same_chain, interface=True)
|
||||
loss_s.append(allatom_lddt_inter.detach())
|
||||
# hbond [use all atoms not just those in native]
|
||||
#hb_loss = calc_hb(
|
||||
@@ -733,7 +734,7 @@ class Trainer():
|
||||
|
||||
# load model
|
||||
loaded_epoch, best_valid_loss = self.load_model(ddp_model, optimizer, scheduler, scaler,
|
||||
self.model_name, gpu, resume_train=True)
|
||||
self.model_name, gpu, suffix="best", resume_train=True)
|
||||
|
||||
if (self.eval):
|
||||
# run protein/NA prediction (TEMPLATED)
|
||||
@@ -742,13 +743,13 @@ class Trainer():
|
||||
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
|
||||
|
||||
# run protein/NA prediction (NON-TEMPLATED)
|
||||
_, _, _ = self.valid_ppi_cycle(
|
||||
ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
|
||||
rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
|
||||
#_, _, _ = self.valid_ppi_cycle(
|
||||
# ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
|
||||
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
|
||||
|
||||
# run RNA prediction
|
||||
#_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, 0, verbose=True)
|
||||
|
||||
_, _, _ = self.valid_pdb_cycle(ddp_model, valid_sm_compl_loader, rank, gpu, world_size, 0, verbose=True)
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
@@ -816,7 +817,7 @@ class Trainer():
|
||||
'valid_loss': valid_loss,
|
||||
'valid_acc': valid_acc,
|
||||
'best_loss': best_valid_loss},
|
||||
self.checkpoint_fn(self.model_name, 'last'))
|
||||
self.checkpoint_fn(self.model_name, str(epoch)))
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user