From fdcb72e8460117c95459497424608bb76294f71d Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Wed, 2 Aug 2023 15:15:06 -0400 Subject: [PATCH] Bug fixes for multimer inference and monomer training --- openfold/config.py | 267 +++++++++--------- openfold/data/data_modules.py | 2 +- openfold/data/data_pipeline.py | 67 +---- openfold/data/data_transforms.py | 23 +- openfold/data/data_transforms_multimer.py | 176 ++++++++++++ openfold/data/input_pipeline.py | 3 +- openfold/data/templates.py | 50 ++-- openfold/model/embedders.py | 30 +- openfold/model/model.py | 11 +- openfold/model/structure_module.py | 250 +++++++++++++++- .../model/triangular_multiplicative_update.py | 16 +- openfold/utils/all_atom_multimer.py | 12 +- openfold/utils/geometry/quat_rigid.py | 2 +- openfold/utils/geometry/rotation_matrix.py | 20 +- openfold/utils/import_weights.py | 48 +++- openfold/utils/loss.py | 2 +- run_pretrained_openfold.py | 2 +- scripts/generate_alphafold_feature_dict.py | 2 +- tests/test_import_weights.py | 8 +- train_openfold.py | 2 +- 20 files changed, 679 insertions(+), 314 deletions(-) diff --git a/openfold/config.py b/openfold/config.py index dbf9c83..01bdf39 100644 --- a/openfold/config.py +++ b/openfold/config.py @@ -155,14 +155,17 @@ def model_config( c.loss.tm.weight = 0.1 elif "multimer" in name: c.globals.is_multimer = True - c.globals.bfloat16 = True + c.globals.bfloat16 = False c.globals.bfloat16_output = False c.loss.masked_msa.num_classes = 22 c.data.common.max_recycling_iters = 20 - for k,v in multimer_model_config_update.items(): + for k,v in multimer_model_config_update['model'].items(): c.model[k] = v + for k, v in multimer_model_config_update['loss'].items(): + c.loss[k] = v + # TODO: Change max_msa_clusters and max_extra_msa to multimer feats within model if re.fullmatch("^model_[1-5]_multimer(_v2)?$", name): #c.model.input_embedder.num_msa = 252 @@ -590,6 +593,12 @@ config = mlc.ConfigDict( "c_out": 37, }, }, + # A negative value indicates that no early stopping will occur, i.e. + # the model will always run `max_recycling_iters` number of recycling + # iterations. A positive value will enable early stopping if the + # difference in pairwise distances is less than the tolerance between + # recycling steps. + "recycle_early_stop_tolerance": -1. }, "relax": { "max_iterations": 0, # no max @@ -670,157 +679,154 @@ config = mlc.ConfigDict( "eps": eps, }, "ema": {"decay": 0.999}, - # A negative value indicates that no early stopping will occur, i.e. - # the model will always run `max_recycling_iters` number of recycling - # iterations. A positive value will enable early stopping if the - # difference in pairwise distances is less than the tolerance between - # recycling steps. - "recycle_early_stop_tolerance": -1 } ) multimer_model_config_update = { - "input_embedder": { - "tf_dim": 21, - "msa_dim": 49, - #"num_msa": 508, - "c_z": c_z, - "c_m": c_m, - "relpos_k": 32, - "max_relative_chain": 2, - "max_relative_idx": 32, - "use_chain_relative": True, - }, - "template": { - "distogram": { - "min_bin": 3.25, - "max_bin": 50.75, - "no_bins": 39, - }, - "template_pair_embedder": { + "model": { + "input_embedder": { + "tf_dim": 21, + "msa_dim": 49, + #"num_msa": 508, "c_z": c_z, - "c_out": 64, - "c_dgram": 39, - "c_aatype": 22, - }, - "template_single_embedder": { - "c_in": 34, "c_m": c_m, + "relpos_k": 32, + "max_relative_chain": 2, + "max_relative_idx": 32, + "use_chain_relative": True, }, - "template_pair_stack": { + "template": { + "distogram": { + "min_bin": 3.25, + "max_bin": 50.75, + "no_bins": 39, + }, + "template_pair_embedder": { + "c_z": c_z, + "c_out": 64, + "c_dgram": 39, + "c_aatype": 22, + }, + "template_single_embedder": { + "c_in": 34, + "c_m": c_m, + }, + "template_pair_stack": { + "c_t": c_t, + # DISCREPANCY: c_hidden_tri_att here is given in the supplement + # as 64. In the code, it's 16. + "c_hidden_tri_att": 16, + "c_hidden_tri_mul": 64, + "no_blocks": 2, + "no_heads": 4, + "pair_transition_n": 2, + "dropout_rate": 0.25, + "tri_mul_first": True, + "fuse_projection_weights": True, + "blocks_per_ckpt": blocks_per_ckpt, + "inf": 1e9, + }, "c_t": c_t, - # DISCREPANCY: c_hidden_tri_att here is given in the supplement - # as 64. In the code, it's 16. - "c_hidden_tri_att": 16, - "c_hidden_tri_mul": 64, - "no_blocks": 2, - "no_heads": 4, - "pair_transition_n": 2, - "dropout_rate": 0.25, - "tri_mul_first": True, - "fuse_projection_weights": True, - "blocks_per_ckpt": blocks_per_ckpt, - "inf": 1e9, - }, - "c_t": c_t, - "c_z": c_z, - "inf": 1e5, # 1e9, - "eps": eps, # 1e-6, - "enabled": templates_enabled, - "embed_angles": embed_template_torsion_angles, - "use_unit_vector": True - }, - "extra_msa": { - "extra_msa_embedder": { - "c_in": 25, - "c_out": c_e, - #"num_extra_msa": 2048 - }, - "extra_msa_stack": { - "c_m": c_e, "c_z": c_z, - "c_hidden_msa_att": 8, + "inf": 1e5, # 1e9, + "eps": eps, # 1e-6, + "enabled": templates_enabled, + "embed_angles": embed_template_torsion_angles, + "use_unit_vector": True + }, + "extra_msa": { + "extra_msa_embedder": { + "c_in": 25, + "c_out": c_e, + #"num_extra_msa": 2048 + }, + "extra_msa_stack": { + "c_m": c_e, + "c_z": c_z, + "c_hidden_msa_att": 8, + "c_hidden_opm": 32, + "c_hidden_mul": 128, + "c_hidden_pair_att": 32, + "no_heads_msa": 8, + "no_heads_pair": 4, + "no_blocks": 4, + "transition_n": 4, + "msa_dropout": 0.15, + "pair_dropout": 0.25, + "opm_first": True, + "fuse_projection_weights": True, + "clear_cache_between_blocks": True, + "inf": 1e9, + "eps": eps, # 1e-10, + "ckpt": blocks_per_ckpt is not None, + }, + "enabled": True, + }, + "evoformer_stack": { + "c_m": c_m, + "c_z": c_z, + "c_hidden_msa_att": 32, "c_hidden_opm": 32, "c_hidden_mul": 128, "c_hidden_pair_att": 32, + "c_s": c_s, "no_heads_msa": 8, "no_heads_pair": 4, - "no_blocks": 4, + "no_blocks": 48, "transition_n": 4, "msa_dropout": 0.15, "pair_dropout": 0.25, "opm_first": True, "fuse_projection_weights": True, - "clear_cache_between_blocks": True, + "blocks_per_ckpt": blocks_per_ckpt, + "clear_cache_between_blocks": False, "inf": 1e9, "eps": eps, # 1e-10, - "ckpt": blocks_per_ckpt is not None, }, - "enabled": True, - }, - "evoformer_stack": { - "c_m": c_m, - "c_z": c_z, - "c_hidden_msa_att": 32, - "c_hidden_opm": 32, - "c_hidden_mul": 128, - "c_hidden_pair_att": 32, - "c_s": c_s, - "no_heads_msa": 8, - "no_heads_pair": 4, - "no_blocks": 48, - "transition_n": 4, - "msa_dropout": 0.15, - "pair_dropout": 0.25, - "opm_first": True, - "fuse_projection_weights": True, - "blocks_per_ckpt": blocks_per_ckpt, - "clear_cache_between_blocks": False, - "inf": 1e9, - "eps": eps, # 1e-10, - }, - "structure_module": { - "c_s": c_s, - "c_z": c_z, - "c_ipa": 16, - "c_resnet": 128, - "no_heads_ipa": 12, - "no_qk_points": 4, - "no_v_points": 8, - "dropout_rate": 0.1, - "no_blocks": 8, - "no_transition_layers": 1, - "no_resnet_blocks": 2, - "no_angles": 7, - "trans_scale_factor": 20, - "epsilon": eps, # 1e-12, - "inf": 1e5, - }, - "heads": { - "lddt": { - "no_bins": 50, - "c_in": c_s, - "c_hidden": 128, - }, - "distogram": { - "c_z": c_z, - "no_bins": aux_distogram_bins, - }, - "tm": { - "c_z": c_z, - "no_bins": aux_distogram_bins, - "ptm_weight": 0.2, - "iptm_weight": 0.8, - "enabled": True, - }, - "masked_msa": { - "c_m": c_m, - "c_out": 22, - }, - "experimentally_resolved": { + "structure_module": { "c_s": c_s, - "c_out": 37, + "c_z": c_z, + "c_ipa": 16, + "c_resnet": 128, + "no_heads_ipa": 12, + "no_qk_points": 4, + "no_v_points": 8, + "dropout_rate": 0.1, + "no_blocks": 8, + "no_transition_layers": 1, + "no_resnet_blocks": 2, + "no_angles": 7, + "trans_scale_factor": 20, + "epsilon": eps, # 1e-12, + "inf": 1e5, }, + "heads": { + "lddt": { + "no_bins": 50, + "c_in": c_s, + "c_hidden": 128, + }, + "distogram": { + "c_z": c_z, + "no_bins": aux_distogram_bins, + }, + "tm": { + "c_z": c_z, + "no_bins": aux_distogram_bins, + "ptm_weight": 0.2, + "iptm_weight": 0.8, + "enabled": True, + }, + "masked_msa": { + "c_m": c_m, + "c_out": 22, + }, + "experimentally_resolved": { + "c_s": c_s, + "c_out": 37, + }, + }, + "recycle_early_stop_tolerance": 0.5 }, "loss": { "distogram": { @@ -897,6 +903,5 @@ multimer_model_config_update = { "enabled": True, }, "eps": eps, - }, - "recycle_early_stop_tolerance": 0.5 + } } diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py index 1bd7949..ab1d92a 100644 --- a/openfold/data/data_modules.py +++ b/openfold/data/data_modules.py @@ -151,7 +151,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): chain: i for i, chain in enumerate(self._chain_ids) } - template_featurizer = templates.TemplateHitFeaturizer( + template_featurizer = templates.HhsearchHitFeaturizer( mmcif_dir=template_mmcif_dir, max_template_date=max_template_date, max_hits=max_template_hits, diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index 9c02327..c436a55 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -24,7 +24,7 @@ from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union import numpy as np from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer -from openfold.data.templates import get_custom_template_features +from openfold.data.templates import get_custom_template_features, empty_template_feats from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch from openfold.data.tools.utils import to_date from openfold.np import residue_constants, protein @@ -34,22 +34,10 @@ FeatureDict = MutableMapping[str, np.ndarray] TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] -def empty_template_feats(n_res) -> FeatureDict: - return { - "template_aatype": np.zeros((0, n_res)).astype(np.int64), - "template_all_atom_positions": - np.zeros((0, n_res, 37, 3)).astype(np.float32), - "template_sum_probs": np.zeros((0, 1)).astype(np.float32), - "template_all_atom_mask": np.zeros((0, n_res, 37)).astype(np.float32), - } - - def make_template_features( input_sequence: str, hits: Sequence[Any], template_featurizer: Any, - query_pdb_code: Optional[str] = None, - query_release_date: Optional[str] = None, ) -> FeatureDict: hits_cat = sum(hits.values(), []) if(len(hits_cat) == 0 or template_featurizer is None): @@ -61,11 +49,6 @@ def make_template_features( ) template_features = templates_result.features - # The template featurizer doesn't format empty template features - # properly. This is a quick fix. - if(template_features["template_aatype"].shape[0] == 0): - template_features = empty_template_feats(len(input_sequence)) - return template_features @@ -453,7 +436,8 @@ class AlignmentRunner: if(uniprot_database_path is not None): self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer( binary_path=jackhmmer_binary_path, - database_path=uniprot_database_path + database_path=uniprot_database_path, + n_cpu=no_cpus ) if(template_searcher is not None and @@ -800,37 +784,6 @@ class DataPipeline: return all_hits - def _parse_template_hits( - self, - alignment_dir: str, - alignment_index: Optional[Any] = None - ) -> Mapping[str, Any]: - all_hits = {} - if (alignment_index is not None): - fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb') - - def read_template(start, size): - fp.seek(start) - return fp.read(size).decode("utf-8") - - for (name, start, size) in alignment_index["files"]: - ext = os.path.splitext(name)[-1] - - if (ext == ".hhr"): - hits = parsers.parse_hhr(read_template(start, size)) - all_hits[name] = hits - - fp.close() - else: - for f in os.listdir(alignment_dir): - path = os.path.join(alignment_dir, f) - ext = os.path.splitext(f)[-1] - - if (ext == ".hhr"): - with open(path, "r") as fp: - hits = parsers.parse_hhr(fp.read()) - all_hits[f] = hits - def _get_msas(self, alignment_dir: str, input_sequence: Optional[str] = None, @@ -935,15 +888,15 @@ class DataPipeline: mmcif_feats = make_mmcif_features(mmcif, chain_id) input_sequence = mmcif.chain_to_seqres[chain_id] - hits = self._parse_template_hits( + hits = self._parse_template_hit_files( alignment_dir, + input_sequence, alignment_index) template_features = make_template_features( input_sequence, hits, - self.template_featurizer, - query_release_date=to_date(mmcif.header["release_date"]) + self.template_featurizer ) msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index) @@ -984,8 +937,9 @@ class DataPipeline: is_distillation=is_distillation ) - hits = self._parse_template_hits( + hits = self._parse_template_hit_files( alignment_dir, + input_sequence, alignment_index ) @@ -1016,8 +970,9 @@ class DataPipeline: description = os.path.splitext(os.path.basename(core_path))[0].upper() core_feats = make_protein_features(protein_object, description) - hits = self._parse_template_hits( + hits = self._parse_template_hit_files( alignment_dir, + input_sequence, alignment_index ) @@ -1107,7 +1062,7 @@ class DataPipeline: alignment_dir = os.path.join( super_alignment_dir, desc ) - hits = self._parse_template_hits(alignment_dir, alignment_index=None) + hits = self._parse_template_hit_files(alignment_dir, seq, alignment_index=None) template_features = make_template_features( seq, hits, diff --git a/openfold/data/data_transforms.py b/openfold/data/data_transforms.py index e172633..2ec193a 100755 --- a/openfold/data/data_transforms.py +++ b/openfold/data/data_transforms.py @@ -89,18 +89,17 @@ def make_all_atom_aatype(protein): def fix_templates_aatype(protein): # Map one-hot to indices num_templates = protein["template_aatype"].shape[0] - if(num_templates > 0): - protein["template_aatype"] = torch.argmax( - protein["template_aatype"], dim=-1 - ) - # Map hhsearch-aatype to our aatype. - new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE - new_order = torch.tensor( - new_order_list, dtype=torch.int64, device=protein["template_aatype"].device, - ).expand(num_templates, -1) - protein["template_aatype"] = torch.gather( - new_order, 1, index=protein["template_aatype"] - ) + protein["template_aatype"] = torch.argmax( + protein["template_aatype"], dim=-1 + ) + # Map hhsearch-aatype to our aatype. + new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE + new_order = torch.tensor( + new_order_list, dtype=torch.int64, device=protein["template_aatype"].device, + ).expand(num_templates, -1) + protein["template_aatype"] = torch.gather( + new_order, 1, index=protein["template_aatype"] + ) return protein diff --git a/openfold/data/data_transforms_multimer.py b/openfold/data/data_transforms_multimer.py index b495766..40434c2 100644 --- a/openfold/data/data_transforms_multimer.py +++ b/openfold/data/data_transforms_multimer.py @@ -2,7 +2,9 @@ from typing import Sequence import torch +from openfold.config import NUM_RES from openfold.data.data_transforms import curry1 +from openfold.np import residue_constants as rc from openfold.utils.tensor_utils import masked_mean @@ -301,3 +303,177 @@ def make_msa_profile(batch): ) return batch + + +def get_interface_residues(positions, atom_mask, asym_id, interface_threshold): + coord_diff = positions[..., None, :, :] - positions[..., None, :, :, :] + pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1)) + + diff_chain_mask = (asym_id[..., None, :] != asym_id[..., :, None]).float() + pair_mask = atom_mask[..., None, :] * atom_mask[..., None, :, :] + mask = diff_chain_mask[..., None] * pair_mask + + min_dist_per_res = torch.where(mask, pairwise_dists, torch.inf).min(dim=-1) + + valid_interfaces = torch.sum((min_dist_per_res < interface_threshold).float(), dim=-1) + interface_residues_idxs = torch.nonzero(valid_interfaces, as_tuple=True)[0] + + return interface_residues_idxs + + +def get_spatial_crop_idx(protein, crop_size, interface_threshold, generator): + positions = protein["all_atom_positions"] + atom_mask = protein["all_atom_mask"] + asym_id = protein["asym_id"] + + interface_residues = get_interface_residues(positions=positions, + atom_mask=atom_mask, + asym_id=asym_id, + interface_threshold=interface_threshold) + + if not torch.any(interface_residues): + return get_contiguous_crop_idx(protein, crop_size, generator) + + target_res = interface_residues[int(torch.randint(0, interface_residues.shape[-1], (1,), + device=positions.device, generator=generator)[0])] + + ca_idx = rc.atom_order["CA"] + ca_positions = positions[..., ca_idx, :] + ca_mask = atom_mask[..., ca_idx].bool() + + coord_diff = ca_positions[..., None, :] - ca_positions[..., None, :, :] + ca_pairwise_dists = torch.sqrt(torch.sum(coord_diff ** 2, dim=-1)) + + to_target_distances = ca_pairwise_dists[target_res] + break_tie = ( + torch.arange( + 0, to_target_distances.shape[-1], device=positions.device + ).float() + * 1e-3 + ) + to_target_distances = torch.where(ca_mask[..., None], to_target_distances, torch.inf) + break_tie + + ret = torch.argsort(to_target_distances)[:crop_size] + return ret.sort().values + + +def randint(lower, upper, generator, device): + return int(torch.randint( + lower, + upper + 1, + (1,), + device=device, + generator=generator, + )[0]) + + +def get_contiguous_crop_idx(protein, crop_size, generator): + num_res = protein["aatype"].shape[0] + if num_res <= crop_size: + return torch.arange(num_res) + + _, chain_lens = protein["asym_id"].unique(return_counts=True) + shuffle_idx = torch.randperm(chain_lens.shape[-1], device=chain_lens.device, generator=generator) + num_remaining = int(chain_lens.sum()) + num_budget = crop_size + crop_idxs = [] + asym_offset = torch.tensor(0, dtype=torch.int64) + for j, idx in enumerate(shuffle_idx): + this_len = int(chain_lens[idx]) + num_remaining -= this_len + # num res at most we can keep in this ent + crop_size_max = min(num_budget, this_len) + # num res at least we shall keep in this ent + crop_size_min = min(this_len, max(0, num_budget - num_remaining)) + chain_crop_size = randint(lower=crop_size_min, + upper=crop_size_max + 1, + generator=generator, + device=chain_lens.device) + + chain_start = randint(lower=0, + upper=this_len - chain_crop_size + 1, + generator=generator, + device=chain_lens.device) + crop_idxs.append( + torch.arange(asym_offset + chain_start, asym_offset + chain_start + chain_crop_size) + ) + asym_offset += this_len + + num_budget -= chain_crop_size + + return torch.concat(crop_idxs) + + +@curry1 +def random_crop_to_size( + protein, + crop_size, + max_templates, + shape_schema, + spatial_crop_prob, + interface_threshold, + subsample_templates=False, + seed=None, +): + """Crop randomly to `crop_size`, or keep as is if shorter than that.""" + # We want each ensemble to be cropped the same way + g = torch.Generator(device=protein["seq_length"].device) + if seed is not None: + g.manual_seed(seed) + + use_spatial_crop = torch.rand((1,), + device=protein["seq_length"].device, + generator=g) < spatial_crop_prob + if use_spatial_crop: + crop_idxs = get_spatial_crop_idx(protein, crop_size, interface_threshold, g) + else: + crop_idxs = get_contiguous_crop_idx(protein, crop_size, g) + + if "template_mask" in protein: + num_templates = protein["template_mask"].shape[-1] + else: + num_templates = 0 + + # No need to subsample templates if there aren't any + subsample_templates = subsample_templates and num_templates + + if subsample_templates: + templates_crop_start = randint(lower=0, + upper=num_templates + 1, + generator=g, + device=protein["seq_length"].device) + templates_select_indices = torch.randperm( + num_templates, device=protein["seq_length"].device, generator=g + ) + else: + templates_crop_start = 0 + + num_res_crop_size = min(int(protein["seq_length"]), crop_size) + num_templates_crop_size = min( + num_templates - templates_crop_start, max_templates + ) + + for k, v in protein.items(): + if k not in shape_schema or ( + "template" not in k and NUM_RES not in shape_schema[k] + ): + continue + + # randomly permute the templates before cropping them. + if k.startswith("template") and subsample_templates: + v = v[templates_select_indices] + + for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)): + is_num_res = dim_size == NUM_RES + if i == 0 and k.startswith("template"): + crop_size = num_templates_crop_size + crop_start = templates_crop_start + v = v[slice(crop_start, crop_start + crop_size)] + elif is_num_res: + v = torch.index_select(v, i, crop_idxs) + + protein[k] = v + + protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size) + + return protein diff --git a/openfold/data/input_pipeline.py b/openfold/data/input_pipeline.py index 779fed7..35192ca 100644 --- a/openfold/data/input_pipeline.py +++ b/openfold/data/input_pipeline.py @@ -104,7 +104,8 @@ def ensembled_transform_fns(common_cfg, mode_cfg, ensemble_seed): # the masked locations and secret corrupted locations. transforms.append( data_transforms.make_masked_msa( - common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction + common_cfg.masked_msa, mode_cfg.masked_msa_replace_fraction, + seed=(msa_seed + 1) if msa_seed else None, ) ) diff --git a/openfold/data/templates.py b/openfold/data/templates.py index b92f14c..af6d37a 100644 --- a/openfold/data/templates.py +++ b/openfold/data/templates.py @@ -89,6 +89,24 @@ TEMPLATE_FEATURES = { } +def empty_template_feats(n_res): + return { + "template_aatype": np.zeros( + (0, n_res, len(residue_constants.restypes_with_x_and_gap)), + np.float32 + ), + "template_all_atom_mask": np.zeros( + (0, n_res, residue_constants.atom_type_num), np.float32 + ), + "template_all_atom_positions": np.zeros( + (0, n_res, residue_constants.atom_type_num, 3), np.float32 + ), + "template_domain_names": np.array([''.encode()], dtype=np.object), + "template_sequence": np.array([''.encode()], dtype=np.object), + "template_sum_probs": np.zeros((0, 1), dtype=np.float32), + } + + def _get_pdb_id_and_chain(hit: parsers.TemplateHit) -> Tuple[str, str]: """Returns PDB id and chain id for an HHSearch Hit.""" # PDB ID: 4 letters. Chain ID: 1+ alphanumeric letters or "." if unknown. @@ -1163,21 +1181,7 @@ class HhsearchHitFeaturizer(TemplateHitFeaturizer): else: num_res = len(query_sequence) # Construct a default template with all zeros. - template_features = { - "template_aatype": np.zeros( - (1, num_res, len(residue_constants.restypes_with_x_and_gap)), - np.float32 - ), - "template_all_atom_masks": np.zeros( - (1, num_res, residue_constants.atom_type_num), np.float32 - ), - "template_all_atom_positions": np.zeros( - (1, num_res, residue_constants.atom_type_num, 3), np.float32 - ), - "template_domain_names": np.array([''.encode()], dtype=np.object), - "template_sequence": np.array([''.encode()], dtype=np.object), - "template_sum_probs": np.array([0], dtype=np.float32), - } + template_features = empty_template_feats(num_res) return TemplateSearchResult( features=template_features, errors=errors, warnings=warnings @@ -1276,21 +1280,7 @@ class HmmsearchHitFeaturizer(TemplateHitFeaturizer): else: num_res = len(query_sequence) # Construct a default template with all zeros. - template_features = { - "template_aatype": np.zeros( - (1, num_res, len(residue_constants.restypes_with_x_and_gap)), - np.float32 - ), - "template_all_atom_masks": np.zeros( - (1, num_res, residue_constants.atom_type_num), np.float32 - ), - "template_all_atom_positions": np.zeros( - (1, num_res, residue_constants.atom_type_num, 3), np.float32 - ), - "template_domain_names": np.array([''.encode()], dtype=np.object), - "template_sequence": np.array([''.encode()], dtype=np.object), - "template_sum_probs": np.array([0], dtype=np.float32), - } + template_features = empty_template_feats(num_res) return TemplateSearchResult( features=template_features, diff --git a/openfold/model/embedders.py b/openfold/model/embedders.py index 5124716..c0c2e42 100644 --- a/openfold/model/embedders.py +++ b/openfold/model/embedders.py @@ -242,7 +242,7 @@ class InputEmbedderMultimer(nn.Module): entity_id = batch["entity_id"] entity_id_same = (entity_id[..., None] == entity_id[..., None, :]) - rel_feats.append(entity_id_same[..., None]) + rel_feats.append(entity_id_same[..., None].to(dtype=rel_pos.dtype)) sym_id = batch["sym_id"] rel_sym_id = sym_id[..., None] - sym_id[..., None, :] @@ -577,7 +577,7 @@ class TemplateEmbedder(nn.Module): # a second copy during the stack later on t_pair = z.new_zeros( z.shape[:-3] + - (n_templ, n, n, self.config.template_pair_embedder.c_t) + (n_templ, n, n, self.config.template_pair_embedder.c_out) ) for i in range(n_templ): @@ -667,17 +667,17 @@ class TemplatePairEmbedderMultimer(nn.Module): ): super(TemplatePairEmbedderMultimer, self).__init__() - self.dgram_linear = Linear(c_dgram, c_out) - self.aatype_linear_1 = Linear(c_aatype, c_out) - self.aatype_linear_2 = Linear(c_aatype, c_out) + self.dgram_linear = Linear(c_dgram, c_out, init='relu') + self.aatype_linear_1 = Linear(c_aatype, c_out, init='relu') + self.aatype_linear_2 = Linear(c_aatype, c_out, init='relu') self.query_embedding_layer_norm = LayerNorm(c_z) - self.query_embedding_linear = Linear(c_z, c_out) + self.query_embedding_linear = Linear(c_z, c_out, init='relu') - self.pseudo_beta_mask_linear = Linear(1, c_out) - self.x_linear = Linear(1, c_out) - self.y_linear = Linear(1, c_out) - self.z_linear = Linear(1, c_out) - self.backbone_mask_linear = Linear(1, c_out) + self.pseudo_beta_mask_linear = Linear(1, c_out, init='relu') + self.x_linear = Linear(1, c_out, init='relu') + self.y_linear = Linear(1, c_out, init='relu') + self.z_linear = Linear(1, c_out, init='relu') + self.backbone_mask_linear = Linear(1, c_out, init='relu') def forward(self, template_dgram: torch.Tensor, @@ -812,10 +812,10 @@ class TemplateEmbedderMultimer(nn.Module): single_template_embeds = {} act = 0. - template_positions, pseudo_beta_mask = ( - single_template_feats["template_pseudo_beta"], - single_template_feats["template_pseudo_beta_mask"], - ) + template_positions, pseudo_beta_mask = pseudo_beta_fn( + single_template_feats["template_aatype"], + single_template_feats["template_all_atom_positions"], + single_template_feats["template_all_atom_mask"]) template_dgram = dgram_from_positions( template_positions, diff --git a/openfold/model/model.py b/openfold/model/model.py index 5bedf4d..d31303b 100644 --- a/openfold/model/model.py +++ b/openfold/model/model.py @@ -186,11 +186,6 @@ class AlphaFold(nn.Module): if self.config.recycle_early_stop_tolerance < 0: return False - if no_batch_dims == 0: - prev_pos = prev_pos.unsqueeze(dim=0) - next_pos = next_pos.unsqueeze(dim=0) - mask = mask.unsqueeze(dim=0) - ca_idx = residue_constants.atom_order['CA'] sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2 mask = mask[..., None] * mask[..., None, :] @@ -265,7 +260,7 @@ class AlphaFold(nn.Module): requires_grad=False, ) - x_prev = pseudo_beta_fn( + pseudo_beta_x_prev = pseudo_beta_fn( feats["aatype"], x_prev, None ).to(dtype=z.dtype) @@ -279,10 +274,12 @@ class AlphaFold(nn.Module): m_1_prev_emb, z_prev_emb = self.recycling_embedder( m_1_prev, z_prev, - x_prev, + pseudo_beta_x_prev, inplace_safe=inplace_safe, ) + del pseudo_beta_x_prev + if(self.globals.offload_inference and inplace_safe): m = m.to(m_1_prev_emb.device) z = z.to(z_prev.device) diff --git a/openfold/model/structure_module.py b/openfold/model/structure_module.py index 1b611ea..989f51e 100644 --- a/openfold/model/structure_module.py +++ b/openfold/model/structure_module.py @@ -166,12 +166,14 @@ class PointProjection(nn.Module): c_hidden: int, num_points: int, no_heads: int, + is_multimer: bool, return_local_points: bool = False, ): super().__init__() self.return_local_points = return_local_points self.no_heads = no_heads self.num_points = num_points + self.is_multimer = is_multimer self.linear = Linear(c_hidden, no_heads * 3 * num_points) @@ -181,24 +183,19 @@ class PointProjection(nn.Module): ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: # TODO: Needs to run in high precision during training points_local = self.linear(activations) + out_shape = points_local.shape[:-1] + (self.no_heads, self.num_points, 3) - if isinstance(rigids, Rigid3Array): - points_local = points_local.reshape( - *points_local.shape[:-1], - self.no_heads, - -1, + if self.is_multimer: + points_local = points_local.view( + points_local.shape[:-1] + (self.no_heads, -1) ) points_local = torch.split( points_local, points_local.shape[-1] // 3, dim=-1 ) - points_local = torch.stack(points_local, dim=-1) + points_local = torch.stack(points_local, dim=-1).view(out_shape) - if not isinstance(rigids, Rigid3Array): - points_local = points_local.reshape( - *points_local.shape[:-2], self.no_heads, -1, 3 - ) points_global = rigids[..., None, None].apply(points_local) if(self.return_local_points): @@ -260,7 +257,8 @@ class InvariantPointAttention(nn.Module): self.linear_q_points = PointProjection( self.c_s, self.no_qk_points, - self.no_heads + self.no_heads, + self.is_multimer ) if(is_multimer): @@ -270,12 +268,14 @@ class InvariantPointAttention(nn.Module): self.c_s, self.no_qk_points, self.no_heads, + self.is_multimer ) self.linear_v_points = PointProjection( self.c_s, self.no_v_points, self.no_heads, + self.is_multimer ) else: self.linear_kv = Linear(self.c_s, 2 * hc) @@ -283,6 +283,7 @@ class InvariantPointAttention(nn.Module): self.c_s, self.no_qk_points + self.no_v_points, self.no_heads, + self.is_multimer ) self.linear_b = Linear(self.c_z, self.no_heads) @@ -504,6 +505,230 @@ class InvariantPointAttention(nn.Module): return s +#TODO: This module follows the refactoring done in IPA for multimer. Running the regular IPA above +# in multimer mode should be equivalent, but tests do not pass unless using this version. Determine +# whether or not the increase in test error matters in practice. +class InvariantPointAttentionMultimer(nn.Module): + """ + Implements Algorithm 22. + """ + def __init__( + self, + c_s: int, + c_z: int, + c_hidden: int, + no_heads: int, + no_qk_points: int, + no_v_points: int, + inf: float = 1e5, + eps: float = 1e-8, + is_multimer: bool = True, + ): + """ + Args: + c_s: + Single representation channel dimension + c_z: + Pair representation channel dimension + c_hidden: + Hidden channel dimension + no_heads: + Number of attention heads + no_qk_points: + Number of query/key points to generate + no_v_points: + Number of value points to generate + """ + super(InvariantPointAttentionMultimer, self).__init__() + + self.c_s = c_s + self.c_z = c_z + self.c_hidden = c_hidden + self.no_heads = no_heads + self.no_qk_points = no_qk_points + self.no_v_points = no_v_points + self.inf = inf + self.eps = eps + + # These linear layers differ from their specifications in the + # supplement. There, they lack bias and use Glorot initialization. + # Here as in the official source, they have bias and use the default + # Lecun initialization. + hc = self.c_hidden * self.no_heads + self.linear_q = Linear(self.c_s, hc, bias=False) + + self.linear_q_points = PointProjection( + self.c_s, + self.no_qk_points, + self.no_heads, + is_multimer=True + ) + + self.linear_k = Linear(self.c_s, hc, bias=False) + self.linear_v = Linear(self.c_s, hc, bias=False) + self.linear_k_points = PointProjection( + self.c_s, + self.no_qk_points, + self.no_heads, + is_multimer=True + ) + + self.linear_v_points = PointProjection( + self.c_s, + self.no_v_points, + self.no_heads, + is_multimer=True + ) + + self.linear_b = Linear(self.c_z, self.no_heads) + + self.head_weights = nn.Parameter(torch.zeros((no_heads))) + ipa_point_weights_init_(self.head_weights) + + concat_out_dim = self.no_heads * ( + self.c_z + self.c_hidden + self.no_v_points * 4 + ) + self.linear_out = Linear(concat_out_dim, self.c_s, init="final") + + self.softmax = nn.Softmax(dim=-2) + + def forward( + self, + s: torch.Tensor, + z: Optional[torch.Tensor], + r: Union[Rigid, Rigid3Array], + mask: torch.Tensor, + inplace_safe: bool = False, + _offload_inference: bool = False, + _z_reference_list: Optional[Sequence[torch.Tensor]] = None, + ) -> torch.Tensor: + """ + Args: + s: + [*, N_res, C_s] single representation + z: + [*, N_res, N_res, C_z] pair representation + r: + [*, N_res] transformation object + mask: + [*, N_res] mask + Returns: + [*, N_res, C_s] single representation update + """ + if(_offload_inference and inplace_safe): + z = _z_reference_list + else: + z = [z] + + a = 0. + + point_variance = (max(self.no_qk_points, 1) * 9.0 / 2) + point_weights = math.sqrt(1.0 / point_variance) + + softplus = lambda x: torch.logaddexp(x, torch.zeros_like(x)) + + head_weights = softplus(self.head_weights) + point_weights = point_weights * head_weights + + ####################################### + # Generate scalar and point activations + ####################################### + + # [*, N_res, H, P_qk] + q_pts = Vec3Array.from_array(self.linear_q_points(s, r)) + + # [*, N_res, H, P_qk, 3] + k_pts = Vec3Array.from_array(self.linear_k_points(s, r)) + + pt_att = square_euclidean_distance(q_pts.unsqueeze(-3), k_pts.unsqueeze(-4), epsilon=0.) + pt_att = torch.sum(pt_att * point_weights[..., None], dim=-1) * (-0.5) + a = a + pt_att + + scalar_variance = max(self.c_hidden, 1) * 1. + scalar_weights = math.sqrt(1.0 / scalar_variance) + + # [*, N_res, H * C_hidden] + q = self.linear_q(s) + k = self.linear_k(s) + + # [*, N_res, H, C_hidden] + q = q.view(q.shape[:-1] + (self.no_heads, -1)) + k = k.view(k.shape[:-1] + (self.no_heads, -1)) + + q = q * scalar_weights + a = a + torch.einsum('...qhc,...khc->...qkh', q, k) + + ########################## + # Compute attention scores + ########################## + # [*, N_res, N_res, H] + b = self.linear_b(z[0]) + + if (_offload_inference): + assert (sys.getrefcount(z[0]) == 2) + z[0] = z[0].cpu() + + a = a + b + + # [*, N_res, N_res] + square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) + square_mask = self.inf * (square_mask - 1) + + a = a + square_mask.unsqueeze(-1) + a = a * math.sqrt(1. / 3) # Normalize by number of logit terms (3) + a = self.softmax(a) + + # [*, N_res, H * C_hidden] + v = self.linear_v(s) + + # [*, N_res, H, C_hidden] + v = v.view(v.shape[:-1] + (self.no_heads, -1)) + + o = torch.einsum('...qkh, ...khc->...qhc', a, v) + + # [*, N_res, H * C_hidden] + o = flatten_final_dims(o, 2) + + # [*, N_res, H, P_v, 3] + v_pts = Vec3Array.from_array(self.linear_v_points(s, r)) + + # [*, N_res, H, P_v] + o_pt = v_pts[..., None, :, :, :] * a.unsqueeze(-1) + o_pt = o_pt.sum(dim=-3) + # o_pt = Vec3Array( + # torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].x, dim=-3), + # torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].y, dim=-3), + # torch.sum(a.unsqueeze(-1) * v_pts[..., None, :, :, :].z, dim=-3), + # ) + + # [*, N_res, H * P_v, 3] + o_pt = o_pt.reshape(o_pt.shape[:-2] + (-1,)) + + # [*, N_res, H, P_v] + o_pt = r[..., None].apply_inverse_to_point(o_pt) + o_pt_flat = [o_pt.x, o_pt.y, o_pt.z] + + # [*, N_res, H * P_v] + o_pt_norm = o_pt.norm(epsilon=1e-8) + + if (_offload_inference): + z[0] = z[0].to(o_pt.device) + + o_pair = torch.einsum('...ijh, ...ijc->...ihc', a, z[0].to(dtype=a.dtype)) + + # [*, N_res, H * C_z] + o_pair = flatten_final_dims(o_pair, 2) + + # [*, N_res, C_s] + s = self.linear_out( + torch.cat( + (o, *o_pt_flat, o_pt_norm, o_pair), dim=-1 + ).to(dtype=z[0].dtype) + ) + + return s + + class BackboneUpdate(nn.Module): """ Implements part of Algorithm 23. @@ -670,7 +895,8 @@ class StructureModule(nn.Module): self.linear_in = Linear(self.c_s, self.c_s) - self.ipa = InvariantPointAttention( + ipa = InvariantPointAttention if not self.is_multimer else InvariantPointAttentionMultimer + self.ipa = ipa( self.c_s, self.c_z, self.c_ipa, diff --git a/openfold/model/triangular_multiplicative_update.py b/openfold/model/triangular_multiplicative_update.py index 984eb69..4365d5a 100644 --- a/openfold/model/triangular_multiplicative_update.py +++ b/openfold/model/triangular_multiplicative_update.py @@ -521,12 +521,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate): def compute_projection(pair, mask): p = compute_projection_helper(pair, mask) - if self._outgoing: - left = p[..., :self.c_hidden] - right = p[..., self.c_hidden:] - else: - left = p[..., self.c_hidden:] - right = p[..., :self.c_hidden] + left = p[..., :self.c_hidden] + right = p[..., self.c_hidden:] return left, right @@ -580,12 +576,8 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate): ab = ab * self.sigmoid(self.linear_ab_g(z)) ab = ab * self.linear_ab_p(z) - if self._outgoing: - a = ab[..., :self.c_hidden] - b = ab[..., self.c_hidden:] - else: - b = ab[..., :self.c_hidden] - a = ab[..., self.c_hidden:] + a = ab[..., :self.c_hidden] + b = ab[..., self.c_hidden:] # Prevents overflow of torch.matmul in combine projections in # reduced-precision modes diff --git a/openfold/utils/all_atom_multimer.py b/openfold/utils/all_atom_multimer.py index 3e15aeb..18a2779 100644 --- a/openfold/utils/all_atom_multimer.py +++ b/openfold/utils/all_atom_multimer.py @@ -36,9 +36,9 @@ def get_rc_tensor(rc_np, aatype): def atom14_to_atom37( atom14_data: torch.Tensor, # (*, N, 14, ...) aatype: torch.Tensor # (*, N) -) -> torch.Tensor: # (*, N, 37, ...) +) -> Tuple: # (*, N, 37, ...) """Convert atom14 to atom37 representation.""" - idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype) + idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype).long() no_batch_dims = len(aatype.shape) - 1 atom37_data = tensor_utils.batched_gather( atom14_data, @@ -50,10 +50,10 @@ def atom14_to_atom37( if len(atom14_data.shape) == no_batch_dims + 2: atom37_data *= atom37_mask elif len(atom14_data.shape) == no_batch_dims + 3: - atom37_data *= atom37_mask[..., None].astype(atom37_data.dtype) + atom37_data *= atom37_mask[..., None].to(dtype=atom37_data.dtype) else: raise ValueError("Incorrectly shaped data") - return atom37_data + return atom37_data, atom37_mask def atom37_to_atom14(aatype, all_atom_pos, all_atom_mask): @@ -230,13 +230,13 @@ def torsion_angles_to_frames( num_residues = aatype.shape[-1] sin_angles = torch.cat( [ - torch.zeros_like(aatype).unsqueeze(), + torch.zeros_like(aatype).unsqueeze(dim=-1), sin_angles, ], dim=-1) cos_angles = torch.cat( [ - torch.ones_like(aatype).unsqueeze(), + torch.ones_like(aatype).unsqueeze(dim=-1), cos_angles ], dim=-1 diff --git a/openfold/utils/geometry/quat_rigid.py b/openfold/utils/geometry/quat_rigid.py index 2e7b107..88c0253 100644 --- a/openfold/utils/geometry/quat_rigid.py +++ b/openfold/utils/geometry/quat_rigid.py @@ -20,7 +20,7 @@ class QuatRigid(nn.Module): def forward(self, activations: torch.Tensor) -> Rigid3Array: # NOTE: During training, this needs to be run in higher precision - rigid_flat = self.linear(activations.to(torch.float32)) + rigid_flat = self.linear(activations) rigid_flat = torch.unbind(rigid_flat, dim=-1) if(self.full_quat): diff --git a/openfold/utils/geometry/rotation_matrix.py b/openfold/utils/geometry/rotation_matrix.py index d992292..8835d08 100644 --- a/openfold/utils/geometry/rotation_matrix.py +++ b/openfold/utils/geometry/rotation_matrix.py @@ -172,20 +172,20 @@ class Rot3Array: ) -> Rot3Array: """Construct Rot3Array from components of quaternion.""" if normalize: - inv_norm = torch.rsqrt(eps + w**2 + x**2 + y**2 + z**2) + inv_norm = torch.rsqrt(torch.clamp(w**2 + x**2 + y**2 + z**2, min=eps)) w = w * inv_norm x = x * inv_norm y = y * inv_norm z = z * inv_norm - xx = 1 - 2 * (y ** 2 + z ** 2) - xy = 2 * (x * y - w * z) - xz = 2 * (x * z + w * y) - yx = 2 * (x * y + w * z) - yy = 1 - 2 * (x ** 2 + z ** 2) - yz = 2 * (y * z - w * x) - zx = 2 * (x * z - w * y) - zy = 2 * (y * z + w * x) - zz = 1 - 2 * (x ** 2 + y ** 2) + xx = 1.0 - 2.0 * (y ** 2 + z ** 2) + xy = 2.0 * (x * y - w * z) + xz = 2.0 * (x * z + w * y) + yx = 2.0 * (x * y + w * z) + yy = 1.0 - 2.0 * (x ** 2 + z ** 2) + yz = 2.0 * (y * z - w * x) + zx = 2.0 * (x * z - w * y) + zy = 2.0 * (y * z + w * x) + zz = 1.0 - 2.0 * (x ** 2 + y ** 2) return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz) def reshape(self, new_shape): diff --git a/openfold/utils/import_weights.py b/openfold/utils/import_weights.py index 43f12a4..c69913c 100644 --- a/openfold/utils/import_weights.py +++ b/openfold/utils/import_weights.py @@ -28,7 +28,7 @@ _NPZ_KEY_PREFIX = "alphafold/alphafold_iteration/" # With Param, a poor man's enum with attributes (Rust-style) class ParamType(Enum): LinearWeight = partial( # hack: partial prevents fns from becoming methods - lambda w: w.transpose(-1, -2) + lambda w: w.unsqueeze(-1) if len(w.shape) == 1 else w.transpose(-1, -2) ) LinearWeightMHA = partial( lambda w: w.reshape(*w.shape[:-2], -1).transpose(-1, -2) @@ -58,6 +58,7 @@ class Param: param: Union[torch.Tensor, List[torch.Tensor]] param_type: ParamType = ParamType.Other stacked: bool = False + swap: bool = False def process_translation_dict(d, top_layer=True): @@ -101,6 +102,7 @@ def stacked(param_dict_list, out=None): param=[param.param for param in v], param_type=v[0].param_type, stacked=True, + swap=v[0].swap ) out[k] = stacked_param @@ -122,7 +124,12 @@ def assign(translation_dict, orig_weights): try: weights = list(map(param_type.transformation, weights)) for p, w in zip(ref, weights): - p.copy_(w) + if param.swap: + index = p.shape[0] // 2 + p[:index].copy_(w[index:]) + p[index:].copy_(w[:index]) + else: + p.copy_(w) except: print(k) print(ref[0].shape) @@ -145,12 +152,24 @@ def generate_translation_dict(model, version, is_multimer=False): LinearBiasMultimer = lambda l: ( Param(l, param_type=ParamType.LinearBiasMultimer) ) + LinearWeightSwap = lambda l: (Param(l, param_type=ParamType.LinearWeight, swap=True)) + LinearBiasSwap = lambda l: (Param(l, swap=True)) LinearParams = lambda l: { "weights": LinearWeight(l.weight), "bias": LinearBias(l.bias), } + LinearParamsMHA = lambda l: { + "weights": LinearWeightMHA(l.weight), + "bias": LinearBiasMHA(l.bias), + } + + LinearParamsSwap = lambda l: { + "weights": LinearWeightSwap(l.weight), + "bias": LinearBiasSwap(l.bias), + } + LinearParamsMultimer = lambda l: { "weights": LinearWeightMultimer(l.weight), "bias": LinearBiasMultimer(l.bias), @@ -194,10 +213,11 @@ def generate_translation_dict(model, version, is_multimer=False): def TriMulOutParams(tri_mul, outgoing=True): if re.fullmatch("^model_[1-5]_multimer_v3$", version): + lin_param_type = LinearParams if outgoing else LinearParamsSwap d = { "left_norm_input": LayerNormParams(tri_mul.layer_norm_in), - "projection": LinearParams(tri_mul.linear_ab_p), - "gate": LinearParams(tri_mul.linear_ab_g), + "projection": lin_param_type(tri_mul.linear_ab_p), + "gate": lin_param_type(tri_mul.linear_ab_g), "center_norm": LayerNormParams(tri_mul.layer_norm_out), } else: @@ -276,24 +296,24 @@ def generate_translation_dict(model, version, is_multimer=False): } PointProjectionParams = lambda pp: { - "point_projection": LinearParamsMultimer( + "point_projection": LinearParamsMHA( pp.linear, ), } IPAParamsMultimer = lambda ipa: { "q_scalar_projection": { - "weights": LinearWeightMultimer( + "weights": LinearWeightMHA( ipa.linear_q.weight, ), }, "k_scalar_projection": { - "weights": LinearWeightMultimer( + "weights": LinearWeightMHA( ipa.linear_k.weight, ), }, "v_scalar_projection": { - "weights": LinearWeightMultimer( + "weights": LinearWeightMHA( ipa.linear_v.weight, ), }, @@ -574,7 +594,7 @@ def generate_translation_dict(model, version, is_multimer=False): "template_pair_embedding_0": LinearParams( temp_embedder.template_pair_embedder.dgram_linear ), - "template_pair_embedding_1": LinearParamsMultimer( + "template_pair_embedding_1": LinearParams( temp_embedder.template_pair_embedder.pseudo_beta_mask_linear ), "template_pair_embedding_2": LinearParams( @@ -583,16 +603,16 @@ def generate_translation_dict(model, version, is_multimer=False): "template_pair_embedding_3": LinearParams( temp_embedder.template_pair_embedder.aatype_linear_2 ), - "template_pair_embedding_4": LinearParamsMultimer( + "template_pair_embedding_4": LinearParams( temp_embedder.template_pair_embedder.x_linear ), - "template_pair_embedding_5": LinearParamsMultimer( + "template_pair_embedding_5": LinearParams( temp_embedder.template_pair_embedder.y_linear ), - "template_pair_embedding_6": LinearParamsMultimer( + "template_pair_embedding_6": LinearParams( temp_embedder.template_pair_embedder.z_linear ), - "template_pair_embedding_7": LinearParamsMultimer( + "template_pair_embedding_7": LinearParams( temp_embedder.template_pair_embedder.backbone_mask_linear ), "template_pair_embedding_8": LinearParams( @@ -600,7 +620,7 @@ def generate_translation_dict(model, version, is_multimer=False): ), "template_embedding_iteration": tps_blocks_params, "output_layer_norm": LayerNormParams( - model.template_embedder.template_pair_stack.layer_norm + temp_embedder.template_pair_stack.layer_norm ), }, "output_linear": LinearParams( diff --git a/openfold/utils/loss.py b/openfold/utils/loss.py index ca7dd72..666efd5 100644 --- a/openfold/utils/loss.py +++ b/openfold/utils/loss.py @@ -1643,7 +1643,7 @@ def chain_center_of_mass_loss( all_atom_positions = all_atom_positions[..., ca_pos, :] all_atom_mask = all_atom_mask[..., ca_pos: (ca_pos + 1)] # keep dim - chains, _ = asym_id.unique(return_counts=True) + chains = asym_id.unique() one_hot = torch.nn.functional.one_hot(asym_id, num_classes=chains.shape[0]).to(dtype=all_atom_mask.dtype) one_hot = one_hot * all_atom_mask chain_pos_mask = one_hot.transpose(-2, -1) diff --git a/run_pretrained_openfold.py b/run_pretrained_openfold.py index 773b493..c6c729c 100644 --- a/run_pretrained_openfold.py +++ b/run_pretrained_openfold.py @@ -431,7 +431,7 @@ if __name__ == "__main__": help="""Postfix for output prediction filenames""" ) parser.add_argument( - "--data_random_seed", type=str, default=None + "--data_random_seed", type=int, default=None ) parser.add_argument( "--skip_relaxation", action="store_true", default=False, diff --git a/scripts/generate_alphafold_feature_dict.py b/scripts/generate_alphafold_feature_dict.py index f80217a..bde5c76 100644 --- a/scripts/generate_alphafold_feature_dict.py +++ b/scripts/generate_alphafold_feature_dict.py @@ -45,7 +45,7 @@ def main(args): uniref90_database_path=args.uniref90_database_path, mgnify_database_path=args.mgnify_database_path, bfd_database_path=args.bfd_database_path, - uniclust30_database_path=args.uniclust30_database_path, + uniref30_database_path=args.uniref30_database_path, small_bfd_database_path=None, template_featurizer=template_featurizer, template_searcher=template_searcher, diff --git a/tests/test_import_weights.py b/tests/test_import_weights.py index dab7547..65bc66f 100644 --- a/tests/test_import_weights.py +++ b/tests/test_import_weights.py @@ -15,7 +15,9 @@ import torch import numpy as np import unittest +from pathlib import Path +from tests.config import consts from openfold.config import model_config from openfold.model.model import AlphaFold from openfold.utils.import_weights import import_jax_weights_ @@ -23,15 +25,17 @@ from openfold.utils.import_weights import import_jax_weights_ class TestImportWeights(unittest.TestCase): def test_import_jax_weights_(self): - npz_path = "openfold/resources/params/params_model_1_ptm.npz" + npz_path = Path(__file__).parent.resolve() / f"../openfold/resources/params/params_{consts.model}.npz" - c = model_config("model_1_ptm") + c = model_config(consts.model) c.globals.blocks_per_ckpt = None model = AlphaFold(c) + model.eval() import_jax_weights_( model, npz_path, + version=consts.model ) data = np.load(npz_path) diff --git a/train_openfold.py b/train_openfold.py index 47a96f7..0194e41 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -22,7 +22,7 @@ from openfold.data.data_modules import ( from openfold.model.model import AlphaFold from openfold.model.torchscript import script_preset_ from openfold.np import residue_constants -from openfold.utils.argparse import remove_arguments +from openfold.utils.argparse_utils import remove_arguments from openfold.utils.callbacks import ( EarlyStoppingVerbose, )