mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Bug fixes for multimer inference and monomer training
This commit is contained in:
@@ -155,14 +155,17 @@ def model_config(
|
||||
c.loss.tm.weight = 0.1
|
||||
elif "multimer" in name:
|
||||
c.globals.is_multimer = True
|
||||
c.globals.bfloat16 = True
|
||||
c.globals.bfloat16 = False
|
||||
c.globals.bfloat16_output = False
|
||||
c.loss.masked_msa.num_classes = 22
|
||||
c.data.common.max_recycling_iters = 20
|
||||
|
||||
for k,v in multimer_model_config_update.items():
|
||||
for k,v in multimer_model_config_update['model'].items():
|
||||
c.model[k] = v
|
||||
|
||||
for k, v in multimer_model_config_update['loss'].items():
|
||||
c.loss[k] = v
|
||||
|
||||
# TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model
|
||||
if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name):
|
||||
#c.model.input_embedder.num_msa = 252
|
||||
@@ -590,6 +593,12 @@ config = mlc.ConfigDict(
|
||||
"c_out": 37,
|
||||
},
|
||||
},
|
||||
# A negative value indicates that no early stopping will occur, i.e.
|
||||
# the model will always run `max_recycling_iters` number of recycling
|
||||
# iterations. A positive value will enable early stopping if the
|
||||
# difference in pairwise distances is less than the tolerance between
|
||||
# recycling steps.
|
||||
"recycle_early_stop_tolerance": -1.
|
||||
},
|
||||
"relax": {
|
||||
"max_iterations": 0, # no max
|
||||
@@ -670,157 +679,154 @@ config = mlc.ConfigDict(
|
||||
"eps": eps,
|
||||
},
|
||||
"ema": {"decay": 0.999},
|
||||
# A negative value indicates that no early stopping will occur, i.e.
|
||||
# the model will always run `max_recycling_iters` number of recycling
|
||||
# iterations. A positive value will enable early stopping if the
|
||||
# difference in pairwise distances is less than the tolerance between
|
||||
# recycling steps.
|
||||
"recycle_early_stop_tolerance": -1
|
||||
}
|
||||
)
|
||||
|
||||
multimer_model_config_update = {
|
||||
"input_embedder": {
|
||||
"tf_dim": 21,
|
||||
"msa_dim": 49,
|
||||
#"num_msa": 508,
|
||||
"c_z": c_z,
|
||||
"c_m": c_m,
|
||||
"relpos_k": 32,
|
||||
"max_relative_chain": 2,
|
||||
"max_relative_idx": 32,
|
||||
"use_chain_relative": True,
|
||||
},
|
||||
"template": {
|
||||
"distogram": {
|
||||
"min_bin": 3.25,
|
||||
"max_bin": 50.75,
|
||||
"no_bins": 39,
|
||||
},
|
||||
"template_pair_embedder": {
|
||||
"model": {
|
||||
"input_embedder": {
|
||||
"tf_dim": 21,
|
||||
"msa_dim": 49,
|
||||
#"num_msa": 508,
|
||||
"c_z": c_z,
|
||||
"c_out": 64,
|
||||
"c_dgram": 39,
|
||||
"c_aatype": 22,
|
||||
},
|
||||
"template_single_embedder": {
|
||||
"c_in": 34,
|
||||
"c_m": c_m,
|
||||
"relpos_k": 32,
|
||||
"max_relative_chain": 2,
|
||||
"max_relative_idx": 32,
|
||||
"use_chain_relative": True,
|
||||
},
|
||||
"template_pair_stack": {
|
||||
"template": {
|
||||
"distogram": {
|
||||
"min_bin": 3.25,
|
||||
"max_bin": 50.75,
|
||||
"no_bins": 39,
|
||||
},
|
||||
"template_pair_embedder": {
|
||||
"c_z": c_z,
|
||||
"c_out": 64,
|
||||
"c_dgram": 39,
|
||||
"c_aatype": 22,
|
||||
},
|
||||
"template_single_embedder": {
|
||||
"c_in": 34,
|
||||
"c_m": c_m,
|
||||
},
|
||||
"template_pair_stack": {
|
||||
"c_t": c_t,
|
||||
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
|
||||
# as 64. In the code, it's 16.
|
||||
"c_hidden_tri_att": 16,
|
||||
"c_hidden_tri_mul": 64,
|
||||
"no_blocks": 2,
|
||||
"no_heads": 4,
|
||||
"pair_transition_n": 2,
|
||||
"dropout_rate": 0.25,
|
||||
"tri_mul_first": True,
|
||||
"fuse_projection_weights": True,
|
||||
"blocks_per_ckpt": blocks_per_ckpt,
|
||||
"inf": 1e9,
|
||||
},
|
||||
"c_t": c_t,
|
||||
# DISCREPANCY: c_hidden_tri_att here is given in the supplement
|
||||
# as 64. In the code, it's 16.
|
||||
"c_hidden_tri_att": 16,
|
||||
"c_hidden_tri_mul": 64,
|
||||
"no_blocks": 2,
|
||||
"no_heads": 4,
|
||||
"pair_transition_n": 2,
|
||||
"dropout_rate": 0.25,
|
||||
"tri_mul_first": True,
|
||||
"fuse_projection_weights": True,
|
||||
"blocks_per_ckpt": blocks_per_ckpt,
|
||||
"inf": 1e9,
|
||||
},
|
||||
"c_t": c_t,
|
||||
"c_z": c_z,
|
||||
"inf": 1e5, # 1e9,
|
||||
"eps": eps, # 1e-6,
|
||||
"enabled": templates_enabled,
|
||||
"embed_angles": embed_template_torsion_angles,
|
||||
"use_unit_vector": True
|
||||
},
|
||||
"extra_msa": {
|
||||
"extra_msa_embedder": {
|
||||
"c_in": 25,
|
||||
"c_out": c_e,
|
||||
#"num_extra_msa": 2048
|
||||
},
|
||||
"extra_msa_stack": {
|
||||
"c_m": c_e,
|
||||
"c_z": c_z,
|
||||
"c_hidden_msa_att": 8,
|
||||
"inf": 1e5, # 1e9,
|
||||
"eps": eps, # 1e-6,
|
||||
"enabled": templates_enabled,
|
||||
"embed_angles": embed_template_torsion_angles,
|
||||
"use_unit_vector": True
|
||||
},
|
||||
"extra_msa": {
|
||||
"extra_msa_embedder": {
|
||||
"c_in": 25,
|
||||
"c_out": c_e,
|
||||
#"num_extra_msa": 2048
|
||||
},
|
||||
"extra_msa_stack": {
|
||||
"c_m": c_e,
|
||||
"c_z": c_z,
|
||||
"c_hidden_msa_att": 8,
|
||||
"c_hidden_opm": 32,
|
||||
"c_hidden_mul": 128,
|
||||
"c_hidden_pair_att": 32,
|
||||
"no_heads_msa": 8,
|
||||
"no_heads_pair": 4,
|
||||
"no_blocks": 4,
|
||||
"transition_n": 4,
|
||||
"msa_dropout": 0.15,
|
||||
"pair_dropout": 0.25,
|
||||
"opm_first": True,
|
||||
"fuse_projection_weights": True,
|
||||
"clear_cache_between_blocks": True,
|
||||
"inf": 1e9,
|
||||
"eps": eps, # 1e-10,
|
||||
"ckpt": blocks_per_ckpt is not None,
|
||||
},
|
||||
"enabled": True,
|
||||
},
|
||||
"evoformer_stack": {
|
||||
"c_m": c_m,
|
||||
"c_z": c_z,
|
||||
"c_hidden_msa_att": 32,
|
||||
"c_hidden_opm": 32,
|
||||
"c_hidden_mul": 128,
|
||||
"c_hidden_pair_att": 32,
|
||||
"c_s": c_s,
|
||||
"no_heads_msa": 8,
|
||||
"no_heads_pair": 4,
|
||||
"no_blocks": 4,
|
||||
"no_blocks": 48,
|
||||
"transition_n": 4,
|
||||
"msa_dropout": 0.15,
|
||||
"pair_dropout": 0.25,
|
||||
"opm_first": True,
|
||||
"fuse_projection_weights": True,
|
||||
"clear_cache_between_blocks": True,
|
||||
"blocks_per_ckpt": blocks_per_ckpt,
|
||||
"clear_cache_between_blocks": False,
|
||||
"inf": 1e9,
|
||||
"eps": eps, # 1e-10,
|
||||
"ckpt": blocks_per_ckpt is not None,
|
||||
},
|
||||
"enabled": True,
|
||||
},
|
||||
"evoformer_stack": {
|
||||
"c_m": c_m,
|
||||
"c_z": c_z,
|
||||
"c_hidden_msa_att": 32,
|
||||
"c_hidden_opm": 32,
|
||||
"c_hidden_mul": 128,
|
||||
"c_hidden_pair_att": 32,
|
||||
"c_s": c_s,
|
||||
"no_heads_msa": 8,
|
||||
"no_heads_pair": 4,
|
||||
"no_blocks": 48,
|
||||
"transition_n": 4,
|
||||
"msa_dropout": 0.15,
|
||||
"pair_dropout": 0.25,
|
||||
"opm_first": True,
|
||||
"fuse_projection_weights": True,
|
||||
"blocks_per_ckpt": blocks_per_ckpt,
|
||||
"clear_cache_between_blocks": False,
|
||||
"inf": 1e9,
|
||||
"eps": eps, # 1e-10,
|
||||
},
|
||||
"structure_module": {
|
||||
"c_s": c_s,
|
||||
"c_z": c_z,
|
||||
"c_ipa": 16,
|
||||
"c_resnet": 128,
|
||||
"no_heads_ipa": 12,
|
||||
"no_qk_points": 4,
|
||||
"no_v_points": 8,
|
||||
"dropout_rate": 0.1,
|
||||
"no_blocks": 8,
|
||||
"no_transition_layers": 1,
|
||||
"no_resnet_blocks": 2,
|
||||
"no_angles": 7,
|
||||
"trans_scale_factor": 20,
|
||||
"epsilon": eps, # 1e-12,
|
||||
"inf": 1e5,
|
||||
},
|
||||
"heads": {
|
||||
"lddt": {
|
||||
"no_bins": 50,
|
||||
"c_in": c_s,
|
||||
"c_hidden": 128,
|
||||
},
|
||||
"distogram": {
|
||||
"c_z": c_z,
|
||||
"no_bins": aux_distogram_bins,
|
||||
},
|
||||
"tm": {
|
||||
"c_z": c_z,
|
||||
"no_bins": aux_distogram_bins,
|
||||
"ptm_weight": 0.2,
|
||||
"iptm_weight": 0.8,
|
||||
"enabled": True,
|
||||
},
|
||||
"masked_msa": {
|
||||
"c_m": c_m,
|
||||
"c_out": 22,
|
||||
},
|
||||
"experimentally_resolved": {
|
||||
"structure_module": {
|
||||
"c_s": c_s,
|
||||
"c_out": 37,
|
||||
"c_z": c_z,
|
||||
"c_ipa": 16,
|
||||
"c_resnet": 128,
|
||||
"no_heads_ipa": 12,
|
||||
"no_qk_points": 4,
|
||||
"no_v_points": 8,
|
||||
"dropout_rate": 0.1,
|
||||
"no_blocks": 8,
|
||||
"no_transition_layers": 1,
|
||||
"no_resnet_blocks": 2,
|
||||
"no_angles": 7,
|
||||
"trans_scale_factor": 20,
|
||||
"epsilon": eps, # 1e-12,
|
||||
"inf": 1e5,
|
||||
},
|
||||
"heads": {
|
||||
"lddt": {
|
||||
"no_bins": 50,
|
||||
"c_in": c_s,
|
||||
"c_hidden": 128,
|
||||
},
|
||||
"distogram": {
|
||||
"c_z": c_z,
|
||||
"no_bins": aux_distogram_bins,
|
||||
},
|
||||
"tm": {
|
||||
"c_z": c_z,
|
||||
"no_bins": aux_distogram_bins,
|
||||
"ptm_weight": 0.2,
|
||||
"iptm_weight": 0.8,
|
||||
"enabled": True,
|
||||
},
|
||||
"masked_msa": {
|
||||
"c_m": c_m,
|
||||
"c_out": 22,
|
||||
},
|
||||
"experimentally_resolved": {
|
||||
"c_s": c_s,
|
||||
"c_out": 37,
|
||||
},
|
||||
},
|
||||
"recycle_early_stop_tolerance": 0.5
|
||||
},
|
||||
"loss": {
|
||||
"distogram": {
|
||||
@@ -897,6 +903,5 @@ multimer_model_config_update = {
|
||||
"enabled": True,
|
||||
},
|
||||
"eps": eps,
|
||||
},
|
||||
"recycle_early_stop_tolerance": 0.5
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,7 +151,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
|
||||
chain: i for i, chain in enumerate(self._chain_ids)
|
||||
}
|
||||
|
||||
template_featurizer = templates.TemplateHitFeaturizer(
|
||||
template_featurizer = templates.HhsearchHitFeaturizer(
|
||||
mmcif_dir=template_mmcif_dir,
|
||||
max_template_date=max_template_date,
|
||||
max_hits=max_template_hits,
|
||||
|
||||
@@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union
|
||||
|
||||
import numpy as np
|
||||
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
|
||||
from openfold.data.templates import get_custom_template_features
|
||||
from openfold.data.templates import get_custom_template_features, empty_template_feats
|
||||
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
|
||||
from openfold.data.tools.utils import to_date
|
||||
from openfold.np import residue_constants, protein
|
||||
@@ -34,22 +34,10 @@ FeatureDict = MutableMapping[str, np.ndarray]
|
||||
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
|
||||
|
||||
|
||||
def empty_template_feats(n_res) -> FeatureDict:
|
||||
return {
|
||||
"template_aatype": np.zeros((0, n_res)).astype(np.int64),
|
||||
"template_all_atom_positions":
|
||||
np.zeros((0, n_res, 37, 3)).astype(np.float32),
|
||||
"template_sum_probs": np.zeros((0, 1)).astype(np.float32),
|
||||
"template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32),
|
||||
}
|
||||
|
||||
|
||||
def make_template_features(
|
||||
input_sequence: str,
|
||||
hits: Sequence[Any],
|
||||
template_featurizer: Any,
|
||||
query_pdb_code: Optional[str] = None,
|
||||
query_release_date: Optional[str] = None,
|
||||
) -> FeatureDict:
|
||||
hits_cat = sum(hits.values(), [])
|
||||
if(len(hits_cat) == 0 or template_featurizer is None):
|
||||
@@ -61,11 +49,6 @@ def make_template_features(
|
||||
)
|
||||
template_features = templates_result.features
|
||||
|
||||
# The template featurizer doesn't format empty template features
|
||||
# properly. This is a quick fix.
|
||||
if(template_features["template_aatype"].shape[0] == 0):
|
||||
template_features = empty_template_feats(len(input_sequence))
|
||||
|
||||
return template_features
|
||||
|
||||
|
||||
@@ -453,7 +436,8 @@ class AlignmentRunner:
|
||||
if(uniprot_database_path is not None):
|
||||
self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer(
|
||||
binary_path=jackhmmer_binary_path,
|
||||
database_path=uniprot_database_path
|
||||
database_path=uniprot_database_path,
|
||||
n_cpu=no_cpus
|
||||
)
|
||||
|
||||
if(template_searcher is not None and
|
||||
@@ -800,37 +784,6 @@ class DataPipeline:
|
||||
|
||||
return all_hits
|
||||
|
||||
def _parse_template_hits(
|
||||
self,
|
||||
alignment_dir: str,
|
||||
alignment_index: Optional[Any] = None
|
||||
) -> Mapping[str, Any]:
|
||||
all_hits = {}
|
||||
if (alignment_index is not None):
|
||||
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
|
||||
|
||||
def read_template(start, size):
|
||||
fp.seek(start)
|
||||
return fp.read(size).decode("utf-8")
|
||||
|
||||
for (name, start, size) in alignment_index["files"]:
|
||||
ext = os.path.splitext(name)[-1]
|
||||
|
||||
if (ext == ".hhr"):
|
||||
hits = parsers.parse_hhr(read_template(start, size))
|
||||
all_hits[name] = hits
|
||||
|
||||
fp.close()
|
||||
else:
|
||||
for f in os.listdir(alignment_dir):
|
||||
path = os.path.join(alignment_dir, f)
|
||||
ext = os.path.splitext(f)[-1]
|
||||
|
||||
if (ext == ".hhr"):
|
||||
with open(path, "r") as fp:
|
||||
hits = parsers.parse_hhr(fp.read())
|
||||
all_hits[f] = hits
|
||||
|
||||
def _get_msas(self,
|
||||
alignment_dir: str,
|
||||
input_sequence: Optional[str] = None,
|
||||
@@ -935,15 +888,15 @@ class DataPipeline:
|
||||
mmcif_feats = make_mmcif_features(mmcif, chain_id)
|
||||
|
||||
input_sequence = mmcif.chain_to_seqres[chain_id]
|
||||
hits = self._parse_template_hits(
|
||||
hits = self._parse_template_hit_files(
|
||||
alignment_dir,
|
||||
input_sequence,
|
||||
alignment_index)
|
||||
|
||||
template_features = make_template_features(
|
||||
input_sequence,
|
||||
hits,
|
||||
self.template_featurizer,
|
||||
query_release_date=to_date(mmcif.header["release_date"])
|
||||
self.template_featurizer
|
||||
)
|
||||
|
||||
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
|
||||
@@ -984,8 +937,9 @@ class DataPipeline:
|
||||
is_distillation=is_distillation
|
||||
)
|
||||
|
||||
hits = self._parse_template_hits(
|
||||
hits = self._parse_template_hit_files(
|
||||
alignment_dir,
|
||||
input_sequence,
|
||||
alignment_index
|
||||
)
|
||||
|
||||
@@ -1016,8 +970,9 @@ class DataPipeline:
|
||||
description = os.path.splitext(os.path.basename(core_path))[0].upper()
|
||||
core_feats = make_protein_features(protein_object, description)
|
||||
|
||||
hits = self._parse_template_hits(
|
||||
hits = self._parse_template_hit_files(
|
||||
alignment_dir,
|
||||
input_sequence,
|
||||
alignment_index
|
||||
)
|
||||
|
||||
@@ -1107,7 +1062,7 @@ class DataPipeline:
|
||||
alignment_dir = os.path.join(
|
||||
super_alignment_dir, desc
|
||||
)
|
||||
hits = self._parse_template_hits(alignment_dir, alignment_index=None)
|
||||
hits = self._parse_template_hit_files(alignment_dir, seq, alignment_index=None)
|
||||
template_features = make_template_features(
|
||||
seq,
|
||||
hits,
|
||||
|
||||
@@ -89,18 +89,17 @@ def make_all_atom_aatype(protein):
|
||||
def fix_templates_aatype(protein):
|
||||
# Map one-hot to indices
|
||||
num_templates = protein["template_aatype"].shape[0]
|
||||
if(num_templates > 0):
|
||||
protein["template_aatype"] = torch.argmax(
|
||||
protein["template_aatype"], dim=-1
|
||||
)
|
||||
# Map hhsearch-aatype to our aatype.
|
||||
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
|
||||
new_order = torch.tensor(
|
||||
new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
|
||||
).expand(num_templates, -1)
|
||||
protein["template_aatype"] = torch.gather(
|
||||
new_order, 1, index=protein["template_aatype"]
|
||||
)
|
||||
protein["template_aatype"] = torch.argmax(
|
||||
protein["template_aatype"], dim=-1
|
||||
)
|
||||
# Map hhsearch-aatype to our aatype.
|
||||
new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
|
||||
new_order = torch.tensor(
|
||||
new_order_list, dtype=torch.int64, device=protein["template_aatype"].device,
|
||||
).expand(num_templates, -1)
|
||||
protein["template_aatype"] = torch.gather(
|
||||
new_order, 1, index=protein["template_aatype"]
|
||||
)
|
||||
|
||||
return protein
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@ from typing import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from openfold.config import NUM_RES
|
||||
from openfold.data.data_transforms import curry1
|
||||
from openfold.np import residue_constants as rc
|
||||
from openfold.utils.tensor_utils import masked_mean
|
||||
|
||||
|
||||
@@ -301,3 +303,177 @@ def make_msa_profile(batch):
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def get_interface_residues(positions, atom_mask, asym_id, interface_threshold):
|
||||
coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :]
|
||||
pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
|
||||
|
||||
diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float()
|
||||
pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :]
|
||||
mask = diff_chain_mask[..., None] * pair_mask
|
||||
|
||||
min_dist_per_res = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1)
|
||||
|
||||
valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1)
|
||||
interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0]
|
||||
|
||||
return interface_residues_idxs
|
||||
|
||||
|
||||
def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator):
|
||||
positions = protein["all_atom_positions"]
|
||||
atom_mask = protein["all_atom_mask"]
|
||||
asym_id = protein["asym_id"]
|
||||
|
||||
interface_residues = get_interface_residues(positions=positions,
|
||||
atom_mask=atom_mask,
|
||||
asym_id=asym_id,
|
||||
interface_threshold=interface_threshold)
|
||||
|
||||
if not torch.any(interface_residues):
|
||||
return get_contiguous_crop_idx(protein, crop_size, generator)
|
||||
|
||||
target_res = interface_residues[int(torch.randint(0, interface_residues.shape[-1], (1,),
|
||||
device=positions.device, generator=generator)[0])]
|
||||
|
||||
ca_idx = rc.atom_order["CA"]
|
||||
ca_positions = positions[..., ca_idx, :]
|
||||
ca_mask = atom_mask[..., ca_idx].bool()
|
||||
|
||||
coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :]
|
||||
ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1))
|
||||
|
||||
to_target_distances = ca_pairwise_dists[target_res]
|
||||
break_tie = (
|
||||
torch.arange(
|
||||
0, to_target_distances.shape[-1], device=positions.device
|
||||
).float()
|
||||
* 1e-3
|
||||
)
|
||||
to_target_distances = torch.where(ca_mask[..., None], to_target_distances, torch.inf) + break_tie
|
||||
|
||||
ret = torch.argsort(to_target_distances)[:crop_size]
|
||||
return ret.sort().values
|
||||
|
||||
|
||||
def randint(lower, upper, generator, device):
|
||||
return int(torch.randint(
|
||||
lower,
|
||||
upper + 1,
|
||||
(1,),
|
||||
device=device,
|
||||
generator=generator,
|
||||
)[0])
|
||||
|
||||
|
||||
def get_contiguous_crop_idx(protein, crop_size, generator):
|
||||
num_res = protein["aatype"].shape[0]
|
||||
if num_res <= crop_size:
|
||||
return torch.arange(num_res)
|
||||
|
||||
_, chain_lens = protein["asym_id"].unique(return_counts=True)
|
||||
shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator)
|
||||
num_remaining = int(chain_lens.sum())
|
||||
num_budget = crop_size
|
||||
crop_idxs = []
|
||||
asym_offset = torch.tensor(0, dtype=torch.int64)
|
||||
for j, idx in enumerate(shuffle_idx):
|
||||
this_len = int(chain_lens[idx])
|
||||
num_remaining -= this_len
|
||||
# num res at most we can keep in this ent
|
||||
crop_size_max = min(num_budget, this_len)
|
||||
# num res at least we shall keep in this ent
|
||||
crop_size_min = min(this_len, max(0, num_budget - num_remaining))
|
||||
chain_crop_size = randint(lower=crop_size_min,
|
||||
upper=crop_size_max + 1,
|
||||
generator=generator,
|
||||
device=chain_lens.device)
|
||||
|
||||
chain_start = randint(lower=0,
|
||||
upper=this_len - chain_crop_size + 1,
|
||||
generator=generator,
|
||||
device=chain_lens.device)
|
||||
crop_idxs.append(
|
||||
torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size)
|
||||
)
|
||||
asym_offset += this_len
|
||||
|
||||
num_budget -= chain_crop_size
|
||||
|
||||
return torch.concat(crop_idxs)
|
||||
|
||||
|
||||
@curry1
|
||||
def random_crop_to_size(
|
||||
protein,
|
||||
crop_size,
|
||||
max_templates,
|
||||
shape_schema,
|
||||
spatial_crop_prob,
|
||||
interface_threshold,
|
||||
subsample_templates=False,
|
||||
seed=None,
|
||||
):
|
||||
"""Crop randomly to `crop_size`, or keep as is if shorter than that."""
|
||||
# We want each ensemble to be cropped the same way
|
||||
g = torch.Generator(device=protein["seq_length"].device)
|
||||
if seed is not None:
|
||||
g.manual_seed(seed)
|
||||
|
||||
use_spatial_crop = torch.rand((1,),
|
||||
device=protein["seq_length"].device,
|
||||
generator=g) < spatial_crop_prob
|
||||
if use_spatial_crop:
|
||||
crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g)
|
||||
else:
|
||||
crop_idxs = get_contiguous_crop_idx(protein, crop_size, g)
|
||||
|
||||
if "template_mask" in protein:
|
||||
num_templates = protein["template_mask"].shape[-1]
|
||||
else:
|
||||
num_templates = 0
|
||||
|
||||
# No need to subsample templates if there aren't any
|
||||
subsample_templates = subsample_templates and num_templates
|
||||
|
||||
if subsample_templates:
|
||||
templates_crop_start = randint(lower=0,
|
||||
upper=num_templates + 1,
|
||||
generator=g,
|
||||
device=protein["seq_length"].device)
|
||||
templates_select_indices = torch.randperm(
|
||||
num_templates, device=protein["seq_length"].device, generator=g
|
||||
)
|
||||
else:
|
||||
templates_crop_start = 0
|
||||
|
||||
num_res_crop_size = min(int(protein["seq_length"]), crop_size)
|
||||
num_templates_crop_size = min(
|
||||
num_templates - templates_crop_start, max_templates
|
||||
)
|
||||
|
||||
for k, v in protein.items():
|
||||
if k not in shape_schema or (
|
||||
"template" not in k and NUM_RES not in shape_schema[k]
|
||||
):
|
||||
continue
|
||||
|
||||
# randomly permute the templates before cropping them.
|
||||
if k.startswith("template") and subsample_templates:
|
||||
v = v[templates_select_indices]
|
||||
|
||||
for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
|
||||
is_num_res = dim_size == NUM_RES
|
||||
if i == 0 and k.startswith("template"):
|
||||
crop_size = num_templates_crop_size
|
||||
crop_start = templates_crop_start
|
||||
v = v[slice(crop_start, crop_start + crop_size)]
|
||||
elif is_num_res:
|
||||
v = torch.index_select(v, i, crop_idxs)
|
||||
|
||||
protein[k] = v
|
||||
|
||||
protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
|
||||
|
||||
return protein
|
||||
|
||||
@@ -104,7 +104,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed):
|
||||
# the masked locations and secret corrupted locations.
|
||||
transforms.append(
|
||||
data_transforms.make_masked_msa(
|
||||
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction
|
||||
common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction,
|
||||
seed=(msa_seed + 1) if msa_seed else None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -89,6 +89,24 @@ TEMPLATE_FEATURES = {
|
||||
}
|
||||
|
||||
|
||||
def empty_template_feats(n_res):
|
||||
return {
|
||||
"template_aatype": np.zeros(
|
||||
(0, n_res, len(residue_constants.restypes_with_x_and_gap)),
|
||||
np.float32
|
||||
),
|
||||
"template_all_atom_mask": np.zeros(
|
||||
(0, n_res, residue_constants.atom_type_num), np.float32
|
||||
),
|
||||
"template_all_atom_positions": np.zeros(
|
||||
(0, n_res, residue_constants.atom_type_num, 3), np.float32
|
||||
),
|
||||
"template_domain_names": np.array([''.encode()], dtype=np.object),
|
||||
"template_sequence": np.array([''.encode()], dtype=np.object),
|
||||
"template_sum_probs": np.zeros((0, 1), dtype=np.float32),
|
||||
}
|
||||
|
||||
|
||||
def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]:
|
||||
"""Returns PDB id and chain id for an HHSearch Hit."""
|
||||
# PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown.
|
||||
@@ -1163,21 +1181,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer):
|
||||
else:
|
||||
num_res = len(query_sequence)
|
||||
# Construct a default template with all zeros.
|
||||
template_features = {
|
||||
"template_aatype": np.zeros(
|
||||
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
|
||||
np.float32
|
||||
),
|
||||
"template_all_atom_masks": np.zeros(
|
||||
(1, num_res, residue_constants.atom_type_num), np.float32
|
||||
),
|
||||
"template_all_atom_positions": np.zeros(
|
||||
(1, num_res, residue_constants.atom_type_num, 3), np.float32
|
||||
),
|
||||
"template_domain_names": np.array([''.encode()], dtype=np.object),
|
||||
"template_sequence": np.array([''.encode()], dtype=np.object),
|
||||
"template_sum_probs": np.array([0], dtype=np.float32),
|
||||
}
|
||||
template_features = empty_template_feats(num_res)
|
||||
|
||||
return TemplateSearchResult(
|
||||
features=template_features, errors=errors, warnings=warnings
|
||||
@@ -1276,21 +1280,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer):
|
||||
else:
|
||||
num_res = len(query_sequence)
|
||||
# Construct a default template with all zeros.
|
||||
template_features = {
|
||||
"template_aatype": np.zeros(
|
||||
(1, num_res, len(residue_constants.restypes_with_x_and_gap)),
|
||||
np.float32
|
||||
),
|
||||
"template_all_atom_masks": np.zeros(
|
||||
(1, num_res, residue_constants.atom_type_num), np.float32
|
||||
),
|
||||
"template_all_atom_positions": np.zeros(
|
||||
(1, num_res, residue_constants.atom_type_num, 3), np.float32
|
||||
),
|
||||
"template_domain_names": np.array([''.encode()], dtype=np.object),
|
||||
"template_sequence": np.array([''.encode()], dtype=np.object),
|
||||
"template_sum_probs": np.array([0], dtype=np.float32),
|
||||
}
|
||||
template_features = empty_template_feats(num_res)
|
||||
|
||||
return TemplateSearchResult(
|
||||
features=template_features,
|
||||
|
||||
@@ -242,7 +242,7 @@ class InputEmbedderMultimer(nn.Module):
|
||||
|
||||
entity_id = batch["entity_id"]
|
||||
entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
|
||||
rel_feats.append(entity_id_same[..., None])
|
||||
rel_feats.append(entity_id_same[..., None].to(dtype=rel_pos.dtype))
|
||||
|
||||
sym_id = batch["sym_id"]
|
||||
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
|
||||
@@ -577,7 +577,7 @@ class TemplateEmbedder(nn.Module):
|
||||
# a second copy during the stack later on
|
||||
t_pair = z.new_zeros(
|
||||
z.shape[:-3] +
|
||||
(n_templ, n, n, self.config.template_pair_embedder.c_t)
|
||||
(n_templ, n, n, self.config.template_pair_embedder.c_out)
|
||||
)
|
||||
|
||||
for i in range(n_templ):
|
||||
@@ -667,17 +667,17 @@ class TemplatePairEmbedderMultimer(nn.Module):
|
||||
):
|
||||
super(TemplatePairEmbedderMultimer, self).__init__()
|
||||
|
||||
self.dgram_linear = Linear(c_dgram, c_out)
|
||||
self.aatype_linear_1 = Linear(c_aatype, c_out)
|
||||
self.aatype_linear_2 = Linear(c_aatype, c_out)
|
||||
self.dgram_linear = Linear(c_dgram, c_out, init='relu')
|
||||
self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu')
|
||||
self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu')
|
||||
self.query_embedding_layer_norm = LayerNorm(c_z)
|
||||
self.query_embedding_linear = Linear(c_z, c_out)
|
||||
self.query_embedding_linear = Linear(c_z, c_out, init='relu')
|
||||
|
||||
self.pseudo_beta_mask_linear = Linear(1, c_out)
|
||||
self.x_linear = Linear(1, c_out)
|
||||
self.y_linear = Linear(1, c_out)
|
||||
self.z_linear = Linear(1, c_out)
|
||||
self.backbone_mask_linear = Linear(1, c_out)
|
||||
self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu')
|
||||
self.x_linear = Linear(1, c_out, init='relu')
|
||||
self.y_linear = Linear(1, c_out, init='relu')
|
||||
self.z_linear = Linear(1, c_out, init='relu')
|
||||
self.backbone_mask_linear = Linear(1, c_out, init='relu')
|
||||
|
||||
def forward(self,
|
||||
template_dgram: torch.Tensor,
|
||||
@@ -812,10 +812,10 @@ class TemplateEmbedderMultimer(nn.Module):
|
||||
single_template_embeds = {}
|
||||
act = 0.
|
||||
|
||||
template_positions, pseudo_beta_mask = (
|
||||
single_template_feats["template_pseudo_beta"],
|
||||
single_template_feats["template_pseudo_beta_mask"],
|
||||
)
|
||||
template_positions, pseudo_beta_mask = pseudo_beta_fn(
|
||||
single_template_feats["template_aatype"],
|
||||
single_template_feats["template_all_atom_positions"],
|
||||
single_template_feats["template_all_atom_mask"])
|
||||
|
||||
template_dgram = dgram_from_positions(
|
||||
template_positions,
|
||||
|
||||
@@ -186,11 +186,6 @@ class AlphaFold(nn.Module):
|
||||
if self.config.recycle_early_stop_tolerance < 0:
|
||||
return False
|
||||
|
||||
if no_batch_dims == 0:
|
||||
prev_pos = prev_pos.unsqueeze(dim=0)
|
||||
next_pos = next_pos.unsqueeze(dim=0)
|
||||
mask = mask.unsqueeze(dim=0)
|
||||
|
||||
ca_idx = residue_constants.atom_order['CA']
|
||||
sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
|
||||
mask = mask[..., None] * mask[..., None, :]
|
||||
@@ -265,7 +260,7 @@ class AlphaFold(nn.Module):
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
x_prev = pseudo_beta_fn(
|
||||
pseudo_beta_x_prev = pseudo_beta_fn(
|
||||
feats["aatype"], x_prev, None
|
||||
).to(dtype=z.dtype)
|
||||
|
||||
@@ -279,10 +274,12 @@ class AlphaFold(nn.Module):
|
||||
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
|
||||
m_1_prev,
|
||||
z_prev,
|
||||
x_prev,
|
||||
pseudo_beta_x_prev,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
|
||||
del pseudo_beta_x_prev
|
||||
|
||||
if(self.globals.offload_inference and inplace_safe):
|
||||
m = m.to(m_1_prev_emb.device)
|
||||
z = z.to(z_prev.device)
|
||||
|
||||
@@ -166,12 +166,14 @@ class PointProjection(nn.Module):
|
||||
c_hidden: int,
|
||||
num_points: int,
|
||||
no_heads: int,
|
||||
is_multimer: bool,
|
||||
return_local_points: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.return_local_points = return_local_points
|
||||
self.no_heads = no_heads
|
||||
self.num_points = num_points
|
||||
self.is_multimer = is_multimer
|
||||
|
||||
self.linear = Linear(c_hidden, no_heads * 3 * num_points)
|
||||
|
||||
@@ -181,24 +183,19 @@ class PointProjection(nn.Module):
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# TODO: Needs to run in high precision during training
|
||||
points_local = self.linear(activations)
|
||||
out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3)
|
||||
|
||||
if isinstance(rigids, Rigid3Array):
|
||||
points_local = points_local.reshape(
|
||||
*points_local.shape[:-1],
|
||||
self.no_heads,
|
||||
-1,
|
||||
if self.is_multimer:
|
||||
points_local = points_local.view(
|
||||
points_local.shape[:-1] + (self.no_heads, -1)
|
||||
)
|
||||
|
||||
points_local = torch.split(
|
||||
points_local, points_local.shape[-1] // 3, dim=-1
|
||||
)
|
||||
|
||||
points_local = torch.stack(points_local, dim=-1)
|
||||
points_local = torch.stack(points_local, dim=-1).view(out_shape)
|
||||
|
||||
if not isinstance(rigids, Rigid3Array):
|
||||
points_local = points_local.reshape(
|
||||
*points_local.shape[:-2], self.no_heads, -1, 3
|
||||
)
|
||||
points_global = rigids[..., None, None].apply(points_local)
|
||||
|
||||
if(self.return_local_points):
|
||||
@@ -260,7 +257,8 @@ class InvariantPointAttention(nn.Module):
|
||||
self.linear_q_points = PointProjection(
|
||||
self.c_s,
|
||||
self.no_qk_points,
|
||||
self.no_heads
|
||||
self.no_heads,
|
||||
self.is_multimer
|
||||
)
|
||||
|
||||
if(is_multimer):
|
||||
@@ -270,12 +268,14 @@ class InvariantPointAttention(nn.Module):
|
||||
self.c_s,
|
||||
self.no_qk_points,
|
||||
self.no_heads,
|
||||
self.is_multimer
|
||||
)
|
||||
|
||||
self.linear_v_points = PointProjection(
|
||||
self.c_s,
|
||||
self.no_v_points,
|
||||
self.no_heads,
|
||||
self.is_multimer
|
||||
)
|
||||
else:
|
||||
self.linear_kv = Linear(self.c_s, 2 * hc)
|
||||
@@ -283,6 +283,7 @@ class InvariantPointAttention(nn.Module):
|
||||
self.c_s,
|
||||
self.no_qk_points + self.no_v_points,
|
||||
self.no_heads,
|
||||
self.is_multimer
|
||||
)
|
||||
|
||||
self.linear_b = Linear(self.c_z, self.no_heads)
|
||||
@@ -504,6 +505,230 @@ class InvariantPointAttention(nn.Module):
|
||||
return s
|
||||
|
||||
|
||||
#TODO: This module follows the refactoring done in IPA for multimer. Running the regular IPA above
|
||||
# in multimer mode should be equivalent, but tests do not pass unless using this version. Determine
|
||||
# whether or not the increase in test error matters in practice.
|
||||
class InvariantPointAttentionMultimer(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 22.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
c_s: int,
|
||||
c_z: int,
|
||||
c_hidden: int,
|
||||
no_heads: int,
|
||||
no_qk_points: int,
|
||||
no_v_points: int,
|
||||
inf: float = 1e5,
|
||||
eps: float = 1e-8,
|
||||
is_multimer: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_s:
|
||||
Single representation channel dimension
|
||||
c_z:
|
||||
Pair representation channel dimension
|
||||
c_hidden:
|
||||
Hidden channel dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
no_qk_points:
|
||||
Number of query/key points to generate
|
||||
no_v_points:
|
||||
Number of value points to generate
|
||||
"""
|
||||
super(InvariantPointAttentionMultimer, self).__init__()
|
||||
|
||||
self.c_s = c_s
|
||||
self.c_z = c_z
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.no_qk_points = no_qk_points
|
||||
self.no_v_points = no_v_points
|
||||
self.inf = inf
|
||||
self.eps = eps
|
||||
|
||||
# These linear layers differ from their specifications in the
|
||||
# supplement. There, they lack bias and use Glorot initialization.
|
||||
# Here as in the official source, they have bias and use the default
|
||||
# Lecun initialization.
|
||||
hc = self.c_hidden * self.no_heads
|
||||
self.linear_q = Linear(self.c_s, hc, bias=False)
|
||||
|
||||
self.linear_q_points = PointProjection(
|
||||
self.c_s,
|
||||
self.no_qk_points,
|
||||
self.no_heads,
|
||||
is_multimer=True
|
||||
)
|
||||
|
||||
self.linear_k = Linear(self.c_s, hc, bias=False)
|
||||
self.linear_v = Linear(self.c_s, hc, bias=False)
|
||||
self.linear_k_points = PointProjection(
|
||||
self.c_s,
|
||||
self.no_qk_points,
|
||||
self.no_heads,
|
||||
is_multimer=True
|
||||
)
|
||||
|
||||
self.linear_v_points = PointProjection(
|
||||
self.c_s,
|
||||
self.no_v_points,
|
||||
self.no_heads,
|
||||
is_multimer=True
|
||||
)
|
||||
|
||||
self.linear_b = Linear(self.c_z, self.no_heads)
|
||||
|
||||
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
|
||||
ipa_point_weights_init_(self.head_weights)
|
||||
|
||||
concat_out_dim = self.no_heads * (
|
||||
self.c_z + self.c_hidden + self.no_v_points * 4
|
||||
)
|
||||
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
|
||||
|
||||
self.softmax = nn.Softmax(dim=-2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
s: torch.Tensor,
|
||||
z: Optional[torch.Tensor],
|
||||
r: Union[Rigid, Rigid3Array],
|
||||
mask: torch.Tensor,
|
||||
inplace_safe: bool = False,
|
||||
_offload_inference: bool = False,
|
||||
_z_reference_list: Optional[Sequence[torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
s:
|
||||
[*, N_res, C_s] single representation
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair representation
|
||||
r:
|
||||
[*, N_res] transformation object
|
||||
mask:
|
||||
[*, N_res] mask
|
||||
Returns:
|
||||
[*, N_res, C_s] single representation update
|
||||
"""
|
||||
if(_offload_inference and inplace_safe):
|
||||
z = _z_reference_list
|
||||
else:
|
||||
z = [z]
|
||||
|
||||
a = 0.
|
||||
|
||||
point_variance = (max(self.no_qk_points, 1) * 9.0 / 2)
|
||||
point_weights = math.sqrt(1.0 / point_variance)
|
||||
|
||||
softplus = lambda x: torch.logaddexp(x, torch.zeros_like(x))
|
||||
|
||||
head_weights = softplus(self.head_weights)
|
||||
point_weights = point_weights * head_weights
|
||||
|
||||
#######################################
|
||||
# Generate scalar and point activations
|
||||
#######################################
|
||||
|
||||
# [*, N_res, H, P_qk]
|
||||
q_pts = Vec3Array.from_array(self.linear_q_points(s, r))
|
||||
|
||||
# [*, N_res, H, P_qk, 3]
|
||||
k_pts = Vec3Array.from_array(self.linear_k_points(s, r))
|
||||
|
||||
pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.)
|
||||
pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5)
|
||||
a = a + pt_att
|
||||
|
||||
scalar_variance = max(self.c_hidden, 1) * 1.
|
||||
scalar_weights = math.sqrt(1.0 / scalar_variance)
|
||||
|
||||
# [*, N_res, H * C_hidden]
|
||||
q = self.linear_q(s)
|
||||
k = self.linear_k(s)
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||
k = k.view(k.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
q = q * scalar_weights
|
||||
a = a + torch.einsum('...qhc,...khc->...qkh', q, k)
|
||||
|
||||
##########################
|
||||
# Compute attention scores
|
||||
##########################
|
||||
# [*, N_res, N_res, H]
|
||||
b = self.linear_b(z[0])
|
||||
|
||||
if (_offload_inference):
|
||||
assert (sys.getrefcount(z[0]) == 2)
|
||||
z[0] = z[0].cpu()
|
||||
|
||||
a = a + b
|
||||
|
||||
# [*, N_res, N_res]
|
||||
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
|
||||
square_mask = self.inf * (square_mask - 1)
|
||||
|
||||
a = a + square_mask.unsqueeze(-1)
|
||||
a = a * math.sqrt(1. / 3) # Normalize by number of logit terms (3)
|
||||
a = self.softmax(a)
|
||||
|
||||
# [*, N_res, H * C_hidden]
|
||||
v = self.linear_v(s)
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
v = v.view(v.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
o = torch.einsum('...qkh, ...khc->...qhc', a, v)
|
||||
|
||||
# [*, N_res, H * C_hidden]
|
||||
o = flatten_final_dims(o, 2)
|
||||
|
||||
# [*, N_res, H, P_v, 3]
|
||||
v_pts = Vec3Array.from_array(self.linear_v_points(s, r))
|
||||
|
||||
# [*, N_res, H, P_v]
|
||||
o_pt = v_pts[..., None, :, :, :] * a.unsqueeze(-1)
|
||||
o_pt = o_pt.sum(dim=-3)
|
||||
# o_pt = Vec3Array(
|
||||
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].x, dim=-3),
|
||||
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].y, dim=-3),
|
||||
# torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].z, dim=-3),
|
||||
# )
|
||||
|
||||
# [*, N_res, H * P_v, 3]
|
||||
o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,))
|
||||
|
||||
# [*, N_res, H, P_v]
|
||||
o_pt = r[..., None].apply_inverse_to_point(o_pt)
|
||||
o_pt_flat = [o_pt.x, o_pt.y, o_pt.z]
|
||||
|
||||
# [*, N_res, H * P_v]
|
||||
o_pt_norm = o_pt.norm(epsilon=1e-8)
|
||||
|
||||
if (_offload_inference):
|
||||
z[0] = z[0].to(o_pt.device)
|
||||
|
||||
o_pair = torch.einsum('...ijh, ...ijc->...ihc', a, z[0].to(dtype=a.dtype))
|
||||
|
||||
# [*, N_res, H * C_z]
|
||||
o_pair = flatten_final_dims(o_pair, 2)
|
||||
|
||||
# [*, N_res, C_s]
|
||||
s = self.linear_out(
|
||||
torch.cat(
|
||||
(o, *o_pt_flat, o_pt_norm, o_pair), dim=-1
|
||||
).to(dtype=z[0].dtype)
|
||||
)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class BackboneUpdate(nn.Module):
|
||||
"""
|
||||
Implements part of Algorithm 23.
|
||||
@@ -670,7 +895,8 @@ class StructureModule(nn.Module):
|
||||
|
||||
self.linear_in = Linear(self.c_s, self.c_s)
|
||||
|
||||
self.ipa = InvariantPointAttention(
|
||||
ipa = InvariantPointAttention if not self.is_multimer else InvariantPointAttentionMultimer
|
||||
self.ipa = ipa(
|
||||
self.c_s,
|
||||
self.c_z,
|
||||
self.c_ipa,
|
||||
|
||||
@@ -521,12 +521,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
|
||||
def compute_projection(pair, mask):
|
||||
p = compute_projection_helper(pair, mask)
|
||||
if self._outgoing:
|
||||
left = p[..., :self.c_hidden]
|
||||
right = p[..., self.c_hidden:]
|
||||
else:
|
||||
left = p[..., self.c_hidden:]
|
||||
right = p[..., :self.c_hidden]
|
||||
left = p[..., :self.c_hidden]
|
||||
right = p[..., self.c_hidden:]
|
||||
|
||||
return left, right
|
||||
|
||||
@@ -580,12 +576,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
ab = ab * self.sigmoid(self.linear_ab_g(z))
|
||||
ab = ab * self.linear_ab_p(z)
|
||||
|
||||
if self._outgoing:
|
||||
a = ab[..., :self.c_hidden]
|
||||
b = ab[..., self.c_hidden:]
|
||||
else:
|
||||
b = ab[..., :self.c_hidden]
|
||||
a = ab[..., self.c_hidden:]
|
||||
a = ab[..., :self.c_hidden]
|
||||
b = ab[..., self.c_hidden:]
|
||||
|
||||
# Prevents overflow of torch.matmul in combine projections in
|
||||
# reduced-precision modes
|
||||
|
||||
@@ -36,9 +36,9 @@ def get_rc_tensor(rc_np, aatype):
|
||||
def atom14_to_atom37(
|
||||
atom14_data: torch.Tensor, # (*, N, 14, ...)
|
||||
aatype: torch.Tensor # (*, N)
|
||||
) -> torch.Tensor: # (*, N, 37, ...)
|
||||
) -> Tuple: # (*, N, 37, ...)
|
||||
"""Convert atom14 to atom37 representation."""
|
||||
idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype)
|
||||
idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype).long()
|
||||
no_batch_dims = len(aatype.shape) - 1
|
||||
atom37_data = tensor_utils.batched_gather(
|
||||
atom14_data,
|
||||
@@ -50,10 +50,10 @@ def atom14_to_atom37(
|
||||
if len(atom14_data.shape) == no_batch_dims + 2:
|
||||
atom37_data *= atom37_mask
|
||||
elif len(atom14_data.shape) == no_batch_dims + 3:
|
||||
atom37_data *= atom37_mask[..., None].astype(atom37_data.dtype)
|
||||
atom37_data *= atom37_mask[..., None].to(dtype=atom37_data.dtype)
|
||||
else:
|
||||
raise ValueError("Incorrectly shaped data")
|
||||
return atom37_data
|
||||
return atom37_data, atom37_mask
|
||||
|
||||
|
||||
def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask):
|
||||
@@ -230,13 +230,13 @@ def torsion_angles_to_frames(
|
||||
num_residues = aatype.shape[-1]
|
||||
sin_angles = torch.cat(
|
||||
[
|
||||
torch.zeros_like(aatype).unsqueeze(),
|
||||
torch.zeros_like(aatype).unsqueeze(dim=-1),
|
||||
sin_angles,
|
||||
],
|
||||
dim=-1)
|
||||
cos_angles = torch.cat(
|
||||
[
|
||||
torch.ones_like(aatype).unsqueeze(),
|
||||
torch.ones_like(aatype).unsqueeze(dim=-1),
|
||||
cos_angles
|
||||
],
|
||||
dim=-1
|
||||
|
||||
@@ -20,7 +20,7 @@ class QuatRigid(nn.Module):
|
||||
|
||||
def forward(self, activations: torch.Tensor) -> Rigid3Array:
|
||||
# NOTE: During training, this needs to be run in higher precision
|
||||
rigid_flat = self.linear(activations.to(torch.float32))
|
||||
rigid_flat = self.linear(activations)
|
||||
|
||||
rigid_flat = torch.unbind(rigid_flat, dim=-1)
|
||||
if(self.full_quat):
|
||||
|
||||
@@ -172,20 +172,20 @@ class Rot3Array:
|
||||
) -> Rot3Array:
|
||||
"""Construct Rot3Array from components of quaternion."""
|
||||
if normalize:
|
||||
inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2)
|
||||
inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps))
|
||||
w = w * inv_norm
|
||||
x = x * inv_norm
|
||||
y = y * inv_norm
|
||||
z = z * inv_norm
|
||||
xx = 1 - 2 * (y ** 2 + z ** 2)
|
||||
xy = 2 * (x * y - w * z)
|
||||
xz = 2 * (x * z + w * y)
|
||||
yx = 2 * (x * y + w * z)
|
||||
yy = 1 - 2 * (x ** 2 + z ** 2)
|
||||
yz = 2 * (y * z - w * x)
|
||||
zx = 2 * (x * z - w * y)
|
||||
zy = 2 * (y * z + w * x)
|
||||
zz = 1 - 2 * (x ** 2 + y ** 2)
|
||||
xx = 1.0 - 2.0 * (y ** 2 + z ** 2)
|
||||
xy = 2.0 * (x * y - w * z)
|
||||
xz = 2.0 * (x * z + w * y)
|
||||
yx = 2.0 * (x * y + w * z)
|
||||
yy = 1.0 - 2.0 * (x ** 2 + z ** 2)
|
||||
yz = 2.0 * (y * z - w * x)
|
||||
zx = 2.0 * (x * z - w * y)
|
||||
zy = 2.0 * (y * z + w * x)
|
||||
zz = 1.0 - 2.0 * (x ** 2 + y ** 2)
|
||||
return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
|
||||
|
||||
def reshape(self, new_shape):
|
||||
|
||||
@@ -28,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/"
|
||||
# With Param, a poor man's enum with attributes (Rust-style)
|
||||
class ParamType(Enum):
|
||||
LinearWeight = partial( # hack: partial prevents fns from becoming methods
|
||||
lambda w: w.transpose(-1, -2)
|
||||
lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2)
|
||||
)
|
||||
LinearWeightMHA = partial(
|
||||
lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2)
|
||||
@@ -58,6 +58,7 @@ class Param:
|
||||
param: Union[torch.Tensor, List[torch.Tensor]]
|
||||
param_type: ParamType = ParamType.Other
|
||||
stacked: bool = False
|
||||
swap: bool = False
|
||||
|
||||
|
||||
def process_translation_dict(d, top_layer=True):
|
||||
@@ -101,6 +102,7 @@ def stacked(param_dict_list, out=None):
|
||||
param=[param.param for param in v],
|
||||
param_type=v[0].param_type,
|
||||
stacked=True,
|
||||
swap=v[0].swap
|
||||
)
|
||||
|
||||
out[k] = stacked_param
|
||||
@@ -122,7 +124,12 @@ def assign(translation_dict, orig_weights):
|
||||
try:
|
||||
weights = list(map(param_type.transformation, weights))
|
||||
for p, w in zip(ref, weights):
|
||||
p.copy_(w)
|
||||
if param.swap:
|
||||
index = p.shape[0] // 2
|
||||
p[:index].copy_(w[index:])
|
||||
p[index:].copy_(w[:index])
|
||||
else:
|
||||
p.copy_(w)
|
||||
except:
|
||||
print(k)
|
||||
print(ref[0].shape)
|
||||
@@ -145,12 +152,24 @@ def generate_translation_dict(model, version, is_multimer=False):
|
||||
LinearBiasMultimer = lambda l: (
|
||||
Param(l, param_type=ParamType.LinearBiasMultimer)
|
||||
)
|
||||
LinearWeightSwap = lambda l: (Param(l, param_type=ParamType.LinearWeight, swap=True))
|
||||
LinearBiasSwap = lambda l: (Param(l, swap=True))
|
||||
|
||||
LinearParams = lambda l: {
|
||||
"weights": LinearWeight(l.weight),
|
||||
"bias": LinearBias(l.bias),
|
||||
}
|
||||
|
||||
LinearParamsMHA = lambda l: {
|
||||
"weights": LinearWeightMHA(l.weight),
|
||||
"bias": LinearBiasMHA(l.bias),
|
||||
}
|
||||
|
||||
LinearParamsSwap = lambda l: {
|
||||
"weights": LinearWeightSwap(l.weight),
|
||||
"bias": LinearBiasSwap(l.bias),
|
||||
}
|
||||
|
||||
LinearParamsMultimer = lambda l: {
|
||||
"weights": LinearWeightMultimer(l.weight),
|
||||
"bias": LinearBiasMultimer(l.bias),
|
||||
@@ -194,10 +213,11 @@ def generate_translation_dict(model, version, is_multimer=False):
|
||||
|
||||
def TriMulOutParams(tri_mul, outgoing=True):
|
||||
if re.fullmatch("^model_[1-5]_multimer_v3$", version):
|
||||
lin_param_type = LinearParams if outgoing else LinearParamsSwap
|
||||
d = {
|
||||
"left_norm_input": LayerNormParams(tri_mul.layer_norm_in),
|
||||
"projection": LinearParams(tri_mul.linear_ab_p),
|
||||
"gate": LinearParams(tri_mul.linear_ab_g),
|
||||
"projection": lin_param_type(tri_mul.linear_ab_p),
|
||||
"gate": lin_param_type(tri_mul.linear_ab_g),
|
||||
"center_norm": LayerNormParams(tri_mul.layer_norm_out),
|
||||
}
|
||||
else:
|
||||
@@ -276,24 +296,24 @@ def generate_translation_dict(model, version, is_multimer=False):
|
||||
}
|
||||
|
||||
PointProjectionParams = lambda pp: {
|
||||
"point_projection": LinearParamsMultimer(
|
||||
"point_projection": LinearParamsMHA(
|
||||
pp.linear,
|
||||
),
|
||||
}
|
||||
|
||||
IPAParamsMultimer = lambda ipa: {
|
||||
"q_scalar_projection": {
|
||||
"weights": LinearWeightMultimer(
|
||||
"weights": LinearWeightMHA(
|
||||
ipa.linear_q.weight,
|
||||
),
|
||||
},
|
||||
"k_scalar_projection": {
|
||||
"weights": LinearWeightMultimer(
|
||||
"weights": LinearWeightMHA(
|
||||
ipa.linear_k.weight,
|
||||
),
|
||||
},
|
||||
"v_scalar_projection": {
|
||||
"weights": LinearWeightMultimer(
|
||||
"weights": LinearWeightMHA(
|
||||
ipa.linear_v.weight,
|
||||
),
|
||||
},
|
||||
@@ -574,7 +594,7 @@ def generate_translation_dict(model, version, is_multimer=False):
|
||||
"template_pair_embedding_0": LinearParams(
|
||||
temp_embedder.template_pair_embedder.dgram_linear
|
||||
),
|
||||
"template_pair_embedding_1": LinearParamsMultimer(
|
||||
"template_pair_embedding_1": LinearParams(
|
||||
temp_embedder.template_pair_embedder.pseudo_beta_mask_linear
|
||||
),
|
||||
"template_pair_embedding_2": LinearParams(
|
||||
@@ -583,16 +603,16 @@ def generate_translation_dict(model, version, is_multimer=False):
|
||||
"template_pair_embedding_3": LinearParams(
|
||||
temp_embedder.template_pair_embedder.aatype_linear_2
|
||||
),
|
||||
"template_pair_embedding_4": LinearParamsMultimer(
|
||||
"template_pair_embedding_4": LinearParams(
|
||||
temp_embedder.template_pair_embedder.x_linear
|
||||
),
|
||||
"template_pair_embedding_5": LinearParamsMultimer(
|
||||
"template_pair_embedding_5": LinearParams(
|
||||
temp_embedder.template_pair_embedder.y_linear
|
||||
),
|
||||
"template_pair_embedding_6": LinearParamsMultimer(
|
||||
"template_pair_embedding_6": LinearParams(
|
||||
temp_embedder.template_pair_embedder.z_linear
|
||||
),
|
||||
"template_pair_embedding_7": LinearParamsMultimer(
|
||||
"template_pair_embedding_7": LinearParams(
|
||||
temp_embedder.template_pair_embedder.backbone_mask_linear
|
||||
),
|
||||
"template_pair_embedding_8": LinearParams(
|
||||
@@ -600,7 +620,7 @@ def generate_translation_dict(model, version, is_multimer=False):
|
||||
),
|
||||
"template_embedding_iteration": tps_blocks_params,
|
||||
"output_layer_norm": LayerNormParams(
|
||||
model.template_embedder.template_pair_stack.layer_norm
|
||||
temp_embedder.template_pair_stack.layer_norm
|
||||
),
|
||||
},
|
||||
"output_linear": LinearParams(
|
||||
|
||||
@@ -1643,7 +1643,7 @@ def chain_center_of_mass_loss(
|
||||
all_atom_positions = all_atom_positions[..., ca_pos, :]
|
||||
all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim
|
||||
|
||||
chains, _ = asym_id.unique(return_counts=True)
|
||||
chains = asym_id.unique()
|
||||
one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype)
|
||||
one_hot = one_hot * all_atom_mask
|
||||
chain_pos_mask = one_hot.transpose(-2, -1)
|
||||
|
||||
@@ -431,7 +431,7 @@ if __name__ == "__main__":
|
||||
help="""Postfix for output prediction filenames"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--data_random_seed", type=str, default=None
|
||||
"--data_random_seed", type=int, default=None
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_relaxation", action="store_true", default=False,
|
||||
|
||||
@@ -45,7 +45,7 @@ def main(args):
|
||||
uniref90_database_path=args.uniref90_database_path,
|
||||
mgnify_database_path=args.mgnify_database_path,
|
||||
bfd_database_path=args.bfd_database_path,
|
||||
uniclust30_database_path=args.uniclust30_database_path,
|
||||
uniref30_database_path=args.uniref30_database_path,
|
||||
small_bfd_database_path=None,
|
||||
template_featurizer=template_featurizer,
|
||||
template_searcher=template_searcher,
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from tests.config import consts
|
||||
from openfold.config import model_config
|
||||
from openfold.model.model import AlphaFold
|
||||
from openfold.utils.import_weights import import_jax_weights_
|
||||
@@ -23,15 +25,17 @@ from openfold.utils.import_weights import import_jax_weights_
|
||||
|
||||
class TestImportWeights(unittest.TestCase):
|
||||
def test_import_jax_weights_(self):
|
||||
npz_path = "openfold/resources/params/params_model_1_ptm.npz"
|
||||
npz_path = Path(__file__).parent.resolve() / f"../openfold/resources/params/params_{consts.model}.npz"
|
||||
|
||||
c = model_config("model_1_ptm")
|
||||
c = model_config(consts.model)
|
||||
c.globals.blocks_per_ckpt = None
|
||||
model = AlphaFold(c)
|
||||
model.eval()
|
||||
|
||||
import_jax_weights_(
|
||||
model,
|
||||
npz_path,
|
||||
version=consts.model
|
||||
)
|
||||
|
||||
data = np.load(npz_path)
|
||||
|
||||
@@ -22,7 +22,7 @@ from openfold.data.data_modules import (
|
||||
from openfold.model.model import AlphaFold
|
||||
from openfold.model.torchscript import script_preset_
|
||||
from openfold.np import residue_constants
|
||||
from openfold.utils.argparse import remove_arguments
|
||||
from openfold.utils.argparse_utils import remove_arguments
|
||||
from openfold.utils.callbacks import (
|
||||
EarlyStoppingVerbose,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user