mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
train and inference atom23 with custom datasets and na ss transform
This commit is contained in:
committed by
Raktim Mitra
parent
716c16a7ec
commit
f0ab0fedae
103
models/rfd3/configs/datasets/design_base_rfd3na.yaml
Normal file
103
models/rfd3/configs/datasets/design_base_rfd3na.yaml
Normal file
@@ -0,0 +1,103 @@
|
||||
# base training dataset for training AF3 design models (atom14 variants):
|
||||
# protein subsampling only.
|
||||
|
||||
defaults:
|
||||
# Grab datasets
|
||||
- train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface
|
||||
- train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit
|
||||
#- train/rfd3_monomer_distillation@train
|
||||
- train/rna_monomer_distillation@train
|
||||
|
||||
# Customized validation datasets
|
||||
- val/unconditional@val.unconditional
|
||||
- val/unconditional_deep@val.unconditional_deep
|
||||
- val/indexed@val.indexed
|
||||
- val/pseudoknot@val.pseudoknot
|
||||
|
||||
# Customized train masks
|
||||
- conditions/unconditional@global_transform_args.train_conditions.unconditional
|
||||
- conditions/island@global_transform_args.train_conditions.island
|
||||
- conditions/tipatom@global_transform_args.train_conditions.tipatom
|
||||
- conditions/sequence_design@global_transform_args.train_conditions.sequence_design
|
||||
- conditions/ppi@global_transform_args.train_conditions.ppi
|
||||
|
||||
- _self_
|
||||
|
||||
# Create a dictionary used for transform arguments
|
||||
pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline
|
||||
|
||||
# Base config overrides:
|
||||
diffusion_batch_size_train: 32
|
||||
diffusion_batch_size_inference: 8
|
||||
crop_size: 384
|
||||
n_recycles_train: 2
|
||||
n_recycles_validation: 1
|
||||
max_atoms_in_crop: 3840 # ~10x crop size.
|
||||
|
||||
# Global transform arguments are necessary for arguments shared between training and inference
|
||||
global_transform_args:
|
||||
n_atoms_per_token: 14
|
||||
central_atom: CB
|
||||
sigma_perturb: 2.0
|
||||
sigma_perturb_com: 1.0
|
||||
association_scheme: dense
|
||||
center_option: diffuse # options are ["all", "motif", "diffuse"]
|
||||
|
||||
# Reference conformer policy
|
||||
generate_conformers: True
|
||||
generate_conformers_for_non_protein_only: True
|
||||
provide_reference_conformer_when_unmasked: True
|
||||
ground_truth_conformer_policy: IGNORE # Other options: REPLACE, ADD, FALLBACK. See atomworks.enums for details
|
||||
provide_elements_for_unindexed_components: True
|
||||
use_element_for_atom_names_of_atomized_tokens: True # TODO: correct name, implies unindexed do too
|
||||
|
||||
# PPI Cropping
|
||||
keep_full_binder_in_spatial_crop: False
|
||||
max_binder_length: 170
|
||||
|
||||
# PPI Hotspots
|
||||
max_ppi_hotspots_frac_to_provide: 0.2
|
||||
ppi_hotspot_max_distance: 4.5
|
||||
|
||||
# Secondary structure features
|
||||
max_ss_frac_to_provide: 0.4
|
||||
min_ss_island_len: 1
|
||||
max_ss_island_len: 10
|
||||
|
||||
# Nucleic acid features
|
||||
add_na_pair_features: false
|
||||
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 5.0
|
||||
sequence_design:
|
||||
frequency: 2.0
|
||||
island:
|
||||
frequency: 1.0
|
||||
tipatom:
|
||||
frequency: 0.0
|
||||
ppi:
|
||||
frequency: 0.0
|
||||
|
||||
# Used to create simple boolean flags for downstream conditioning
|
||||
meta_conditioning_probabilities:
|
||||
calculate_NA_SS: 1.0
|
||||
calculate_hbonds: 0.2
|
||||
calculate_rasa: 0.6
|
||||
|
||||
keep_protein_motif_rasa: 0.1 # Small to prevent noisy input to model
|
||||
hbond_subsample: 0.5
|
||||
|
||||
# fully indexed training
|
||||
unindex_leak_global_index: 0.10
|
||||
unindex_insert_random_break: 0.10
|
||||
unindex_remove_random_break: 0.10
|
||||
|
||||
# Probability of adding 1d secondary structure conditioning
|
||||
add_1d_ss_features: 0.1
|
||||
featurize_plddt: 0.9 # Applied for monomer distillation only
|
||||
add_global_is_non_loopy_feature: 0.99
|
||||
|
||||
# PPI
|
||||
add_ppi_hotspots: 0.75
|
||||
full_binder_crop: 0.75
|
||||
@@ -0,0 +1,39 @@
|
||||
defaults:
|
||||
- pdb/base_transform_args@rna_monomer_distillation
|
||||
- _self_
|
||||
|
||||
rna_monomer_distillation:
|
||||
dataset:
|
||||
_target_: atomworks.ml.datasets.StructuralDatasetWrapper
|
||||
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||
|
||||
# cif parser arguments
|
||||
cif_parser_args:
|
||||
cache_dir: null
|
||||
load_from_cache: False
|
||||
save_to_cache: False
|
||||
|
||||
# metadata parser
|
||||
dataset_parser:
|
||||
_target_: atomworks.ml.datasets.parsers.GenericDFParser
|
||||
pn_unit_iid_colnames: null
|
||||
|
||||
# metadata dataset
|
||||
dataset:
|
||||
_target_: atomworks.ml.datasets.PandasDataset
|
||||
name: rna_monomer_distillation
|
||||
id_column: example_id
|
||||
data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet
|
||||
columns_to_load:
|
||||
- example_id
|
||||
- path
|
||||
- cluster_id
|
||||
- seq_hash
|
||||
- overall_plddt
|
||||
- overall_pde
|
||||
- overall_pae
|
||||
|
||||
transform:
|
||||
crop_contiguous_probability: 0.67
|
||||
crop_spatial_probability: 0.33
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
|
||||
defaults:
|
||||
- unconditional
|
||||
- design_validation_base
|
||||
- _self_
|
||||
|
||||
dataset:
|
||||
name: pseudoknot
|
||||
eval_every_n: 1
|
||||
data: /home/afavor/git/RFD3/modelhub/projects/aa_design/tests/test_data/pseudoknot.json
|
||||
data: ${paths.data.design_benchmark_data_dir}/pseudoknot.json
|
||||
|
||||
@@ -5,6 +5,7 @@ defaults:
|
||||
- /debug/default
|
||||
- override /model: rfd3_base
|
||||
- override /logger: null
|
||||
- override /datasets: design_base_rfd3na
|
||||
- _self_
|
||||
|
||||
name: rfd3na-SScond
|
||||
@@ -54,7 +55,7 @@ datasets:
|
||||
max_atoms_in_crop: 2560 # ~10x crop size.
|
||||
global_transform_args:
|
||||
association_scheme: atom23
|
||||
add_na_pair_features: true
|
||||
#add_na_pair_features: true
|
||||
train_conditions:
|
||||
unconditional:
|
||||
frequency: 2.0
|
||||
@@ -69,19 +70,19 @@ datasets:
|
||||
train:
|
||||
# These are the ratios used in the preprint but we set all pdb sampling by default since not everyone might download the distillation data.
|
||||
pdb:
|
||||
probability: 1.0
|
||||
# rna_monomer_distillation:
|
||||
# probability: 1.0
|
||||
probability: 0.5
|
||||
rna_monomer_distillation:
|
||||
probability: 0.5
|
||||
|
||||
# val:
|
||||
# pseudoknot:
|
||||
# dataset:
|
||||
# # eval_every_n: 10
|
||||
# eval_every_n: 2
|
||||
val:
|
||||
pseudoknot:
|
||||
dataset:
|
||||
# eval_every_n: 10
|
||||
eval_every_n: 1
|
||||
|
||||
trainer:
|
||||
devices_per_node: 1
|
||||
limit_train_batches: 10
|
||||
limit_val_batches: 1
|
||||
validate_every_n_epochs: 5
|
||||
prevalidate: false
|
||||
prevalidate: true
|
||||
|
||||
@@ -9,7 +9,7 @@ from os import PathLike
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from atomworks.constants import STANDARD_AA
|
||||
from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA
|
||||
from atomworks.io.parser import parse_atom_array
|
||||
|
||||
# from atomworks.ml.datasets.datasets import BaseDataset
|
||||
@@ -119,7 +119,9 @@ class DesignInputSpecification(BaseModel):
|
||||
validate_assignment=False,
|
||||
str_strip_whitespace=True,
|
||||
str_min_length=1,
|
||||
extra="forbid",
|
||||
#extra="forbid", ####################################################
|
||||
extra="allow"
|
||||
## for now allowing extra for rfd3na-ss purposes, can decide later ##
|
||||
)
|
||||
# fmt: off
|
||||
# ========================================================================
|
||||
@@ -494,6 +496,12 @@ class DesignInputSpecification(BaseModel):
|
||||
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
|
||||
is_bkbn, False, dtype=int
|
||||
)
|
||||
elif aa.res_name[start] in (STANDARD_DNA + STANDARD_RNA) and self.redesign_motif_sidechains:
|
||||
is_bkbn = np.isin(aa.atom_name[start:end], backbone_atoms_RNA)
|
||||
aa.is_motif_atom_with_fixed_coord[start:end] = is_bkbn.astype(int)
|
||||
aa.is_motif_atom_with_fixed_seq[start:end] = np.full_like(
|
||||
is_bkbn, False, dtype=int
|
||||
)
|
||||
|
||||
# ... Apply selections on top
|
||||
apply_selections(start, end)
|
||||
@@ -509,17 +517,18 @@ class DesignInputSpecification(BaseModel):
|
||||
atom_array_input_annotated = copy.deepcopy(self.atom_array_input)
|
||||
|
||||
########## reorder NA atoms ###########
|
||||
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
|
||||
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
|
||||
dna_array = atom_array_input_annotated[is_dna]
|
||||
rna_array = atom_array_input_annotated[is_rna]
|
||||
if exists(atom_array_input_annotated):
|
||||
is_dna = np.isin(atom_array_input_annotated.res_name, ["DA", "DC", "DG", "DT"])
|
||||
is_rna = np.isin(atom_array_input_annotated.res_name, ["A", "C", "G", "U"])
|
||||
dna_array = atom_array_input_annotated[is_dna]
|
||||
rna_array = atom_array_input_annotated[is_rna]
|
||||
|
||||
atom_array_input_annotated[is_dna] = reorder_atoms_per_residue(
|
||||
dna_array, backbone_atoms_DNA
|
||||
)
|
||||
atom_array_input_annotated[is_rna] = reorder_atoms_per_residue(
|
||||
rna_array, backbone_atoms_RNA
|
||||
)
|
||||
atom_array_input_annotated[is_dna] = reorder_atoms_per_residue(
|
||||
dna_array, backbone_atoms_DNA
|
||||
)
|
||||
atom_array_input_annotated[is_rna] = reorder_atoms_per_residue(
|
||||
rna_array, backbone_atoms_RNA
|
||||
)
|
||||
#######################################
|
||||
|
||||
atom_array = self._build_init(atom_array_input_annotated)
|
||||
@@ -928,11 +937,11 @@ def create_diffused_residues(n, additional_annotations=None, polymer_type="P"):
|
||||
elif polymer_type == "R":
|
||||
res_name = "A"
|
||||
bb_len = len(backbone_atoms_RNA)
|
||||
bb_atom_names = strip_list(backbone_atoms_RNA)
|
||||
bb_atom_names = backbone_atoms_RNA
|
||||
elif polymer_type == "D":
|
||||
res_name = "DA"
|
||||
bb_len = len(backbone_atoms_DNA)
|
||||
bb_atom_names = strip_list(backbone_atoms_DNA)
|
||||
bb_atom_names = backbone_atoms_DNA
|
||||
else:
|
||||
raise ValueError(
|
||||
f"invalid polymer type detected: {polymer_type}, check contig!"
|
||||
@@ -955,12 +964,13 @@ def create_diffused_residues(n, additional_annotations=None, polymer_type="P"):
|
||||
for idx in range(1, n + 1)
|
||||
]
|
||||
array = struc.array(atoms)
|
||||
array.set_annotation("element", np.array(bb_elements * n, dtype="<U2"))
|
||||
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U2"))
|
||||
array.set_annotation("element", np.array(bb_elements * n, dtype="<U3"))
|
||||
array.set_annotation("atom_name", np.array(bb_atom_names * n, dtype="<U3"))
|
||||
array = set_default_conditioning_annotations(
|
||||
array, motif=False, additional=additional_annotations
|
||||
)
|
||||
array = set_common_annotations(array)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
|
||||
@@ -536,13 +536,14 @@ def create_atom_array_from_design_specification_legacy(
|
||||
optional_conditions = []
|
||||
|
||||
########## reorder NA atoms ###########
|
||||
is_dna = np.isin(atom_array_input.res_name, ["DA", "DC", "DG", "DT"])
|
||||
is_rna = np.isin(atom_array_input.res_name, ["A", "C", "G", "U"])
|
||||
dna_array = atom_array_input[is_dna]
|
||||
rna_array = atom_array_input[is_rna]
|
||||
if exists(atom_array_input):
|
||||
is_dna = np.isin(atom_array_input.res_name, ["DA", "DC", "DG", "DT"])
|
||||
is_rna = np.isin(atom_array_input.res_name, ["A", "C", "G", "U"])
|
||||
dna_array = atom_array_input[is_dna]
|
||||
rna_array = atom_array_input[is_rna]
|
||||
|
||||
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
|
||||
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
|
||||
atom_array_input[is_dna] = reorder_atoms_per_residue(dna_array, backbone_atoms_DNA)
|
||||
atom_array_input[is_rna] = reorder_atoms_per_residue(rna_array, backbone_atoms_RNA)
|
||||
#######################################
|
||||
|
||||
if exists(atomwise_rasa):
|
||||
|
||||
@@ -291,6 +291,7 @@ class BackboneMetrics(Metric):
|
||||
3.0 # maximum closest-neighbour distance before considered a floating atom
|
||||
)
|
||||
self.standard_ca_dist = 3.8
|
||||
self.standard_PP_dist = 6.4
|
||||
self.compute_for_diffused_region_only = compute_for_diffused_region_only
|
||||
|
||||
@property
|
||||
@@ -310,6 +311,8 @@ class BackboneMetrics(Metric):
|
||||
) # N_atoms x N_atoms
|
||||
|
||||
is_protein = f["is_protein"][tok_idx].cpu().numpy() # n_atoms
|
||||
is_rna = f["is_rna"][tok_idx].cpu().numpy()
|
||||
is_dna = f["is_dna"][tok_idx].cpu().numpy()
|
||||
|
||||
mask = np.zeros_like(dists, dtype=bool)
|
||||
mask = mask | (np.eye(dists.shape[-1], dtype=bool))[None]
|
||||
@@ -362,23 +365,42 @@ class BackboneMetrics(Metric):
|
||||
if self.compute_for_diffused_region_only:
|
||||
is_ca = is_ca[diffused_region]
|
||||
is_protein = is_protein[diffused_region]
|
||||
idx_mask = is_ca & is_protein
|
||||
is_dna = is_dna[diffused_region]
|
||||
is_rna = is_rna[diffused_region]
|
||||
protein_idx_mask = is_ca & (is_protein)
|
||||
na_idx_mask = is_ca & (is_rna | is_dna)
|
||||
|
||||
if self.compute_for_diffused_region_only:
|
||||
xyz = X_L.cpu()[:, diffused_region][:, idx_mask]
|
||||
xyz_protein = X_L.cpu()[:, diffused_region][:, protein_idx_mask]
|
||||
xyz_na = X_L.cpu()[:, diffused_region][:, na_idx_mask]
|
||||
else:
|
||||
xyz = X_L.cpu()[:, idx_mask]
|
||||
xyz_protein = X_L.cpu()[:, protein_idx_mask]
|
||||
xyz_na = X_L.cpu()[:, na_idx_mask]
|
||||
|
||||
ca_dists = torch.norm(xyz[:, 1:] - xyz[:, :-1], dim=-1)
|
||||
deviation = torch.abs(ca_dists - self.standard_ca_dist) # B, (I-1)
|
||||
is_chainbreak = deviation > 0.75
|
||||
ca_dists_protein = torch.norm(xyz_protein[:, 1:] - xyz_protein[:, :-1], dim=-1)
|
||||
ca_dists_na = torch.norm(xyz_na[:, 1:] - xyz_na[:, :-1], dim=-1)
|
||||
|
||||
deviation_protein = torch.abs(ca_dists_protein - self.standard_ca_dist) # B, (I-1)
|
||||
deviation_na = torch.abs(ca_dists_na - self.standard_PP_dist) # B, (I-1)
|
||||
is_chainbreak_protein = deviation_protein > 0.75
|
||||
is_chainbreak_na = deviation_na > 1
|
||||
|
||||
o["max_ca_deviation"] = float(deviation.max(-1).values.mean())
|
||||
o["fraction_chainbreaks"] = float(is_chainbreak.float().mean(-1).mean())
|
||||
o["n_chainbreaks"] = float(is_chainbreak.float().sum(-1).mean())
|
||||
try:
|
||||
o["max_ca_deviation_protein"] = float(deviation_protein.max(-1).values.mean())
|
||||
o["fraction_chainbreaks_protein"] = float(is_chainbreak_protein.float().mean(-1).mean())
|
||||
o["n_chainbreaks_protein"] = float(is_chainbreak_protein.float().sum(-1).mean())
|
||||
except:
|
||||
print("No protein in this example, skipping protein chainbreak metrics")
|
||||
|
||||
try:
|
||||
o["max_ca_deviation_na"] = float(deviation_na.max(-1).values.mean())
|
||||
o["fraction_chainbreaks_na"] = float(is_chainbreak_na.float().mean(-1).mean())
|
||||
o["n_chainbreaks_na"] = float(is_chainbreak_na.float().sum(-1).mean())
|
||||
except:
|
||||
print("No NA in this example, skipping NA chainbreak metrics")
|
||||
|
||||
return o
|
||||
|
||||
|
||||
class PPIMetrics(Metric):
|
||||
"""PPI-specific metrics"""
|
||||
|
||||
|
||||
@@ -243,6 +243,10 @@ class SampleConditioningType(Transform):
|
||||
)
|
||||
self.meta_conditioning_probabilities = meta_conditioning_probabilities
|
||||
self.train_conditions = train_conditions
|
||||
|
||||
for item in self.train_conditions:
|
||||
self.train_conditions[item].association_scheme = association_scheme
|
||||
|
||||
self.sequence_encoding = sequence_encoding
|
||||
self.association_scheme = association_scheme
|
||||
|
||||
@@ -261,12 +265,15 @@ class SampleConditioningType(Transform):
|
||||
assert "conditions" in data, "Conditioning dict not initialized"
|
||||
|
||||
def forward(self, data):
|
||||
#for item in self.train_conditions:
|
||||
# print(self.train_conditions[item].is_valid_for_example(data))
|
||||
|
||||
valid_conditions = [
|
||||
cond
|
||||
for cond in self.train_conditions.values()
|
||||
if cond.frequency > 0 and cond.is_valid_for_example(data)
|
||||
if cond.is_valid_for_example(data) and cond.frequency > 0
|
||||
]
|
||||
|
||||
|
||||
if len(valid_conditions) == 0:
|
||||
raise InvalidSampledConditionException("No valid condition was found.")
|
||||
|
||||
@@ -280,8 +287,6 @@ class SampleConditioningType(Transform):
|
||||
i_cond = np.random.choice(np.arange(len(p_cond)), p=p_cond)
|
||||
cond = valid_conditions[i_cond]
|
||||
|
||||
cond.association_scheme = self.association_scheme
|
||||
|
||||
data["sampled_condition"] = cond
|
||||
data["sampled_condition_name"] = cond.name
|
||||
data["sampled_condition_cls"] = cond.__class__
|
||||
|
||||
@@ -71,8 +71,10 @@ class SubsampleToTypes(Transform):
|
||||
def __init__(
|
||||
self,
|
||||
allowed_types: list | str = ["is_protein"],
|
||||
association_scheme: str = 'atom14'
|
||||
):
|
||||
self.allowed_types = allowed_types
|
||||
self.association_scheme = association_scheme
|
||||
if not self.allowed_types == "ALL":
|
||||
for k in allowed_types:
|
||||
if not k.startswith("is_"):
|
||||
@@ -104,7 +106,7 @@ class SubsampleToTypes(Transform):
|
||||
)
|
||||
)
|
||||
|
||||
if atom_array.is_protein.sum() == 0:
|
||||
if self.association_scheme != 'atom23' and atom_array.is_protein.sum() == 0:
|
||||
raise ValueError(
|
||||
"No protein atoms found in the atom array. Example ID: {}".format(
|
||||
data.get("example_id", "unknown")
|
||||
|
||||
@@ -119,7 +119,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
def __init__(
|
||||
self,
|
||||
is_inference,
|
||||
add_nucleic_ss_feats: bool = True,
|
||||
meta_conditioning_probabilities,
|
||||
|
||||
p_is_nucleic_ss_example: float = 0.3,
|
||||
p_show_partial_feats: float = 0.5,
|
||||
@@ -136,9 +136,18 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
):
|
||||
# Critical, must always have to know how to handle
|
||||
self.is_inference = is_inference
|
||||
|
||||
if not self.is_inference:
|
||||
## relevant in training
|
||||
self.sampling_prob = meta_conditioning_probabilities['calculate_NA_SS']
|
||||
else:
|
||||
## irrelevant in inference
|
||||
self.sampling_prob = 0
|
||||
# For sampling whether we add nucleic-ss features (extra t2d)
|
||||
self.add_nucleic_ss_feats = add_nucleic_ss_feats
|
||||
|
||||
# relevant in training
|
||||
self.add_nucleic_ss_feats = (self.sampling_prob > 0)
|
||||
######
|
||||
|
||||
self.p_canonical_bp_filter = p_canonical_bp_filter # enforce that bp labels are only canonical
|
||||
self.p_is_nucleic_ss_example = p_is_nucleic_ss_example
|
||||
self.nucleic_ss_min_shown = nucleic_ss_min_shown
|
||||
@@ -189,7 +198,7 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
n_tokens = len(token_starts)
|
||||
print(" DO I NEED TO CHANGE TO TOKEN_ID???")
|
||||
# Handle the training case with ground truth and masking:
|
||||
if not self.is_inference:
|
||||
if not self.is_inference and (np.random.rand() < self.sampling_prob):
|
||||
|
||||
# First, annotate as usual
|
||||
# atom_array = annotate_na_ss(atom_array, **kwargs)
|
||||
@@ -226,8 +235,9 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
- 3). Lists of paired indices
|
||||
|
||||
"""
|
||||
is_nucleic_ss_example=True
|
||||
give_partial_feats=False
|
||||
#is_nucleic_ss_example=True
|
||||
#give_partial_feats=False
|
||||
|
||||
atom_array = annotate_na_ss_from_data_specification(
|
||||
data,
|
||||
overwrite=True,
|
||||
@@ -282,4 +292,4 @@ class CalculateNucleicAcidGeomFeats(Transform):
|
||||
n_islands_max=self.n_islands_max,
|
||||
max_length=max_length,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1321,6 +1321,8 @@ def bp_partner_to_ss_matrix(
|
||||
*,
|
||||
feature_info: Optional[dict] = None,
|
||||
mol_info: Optional[NucMolInfo] = None,
|
||||
NA_only: Optional[bool] = False,
|
||||
planar_only: Optional[bool] = False,
|
||||
include_loops: bool = True,
|
||||
token_level_data: Optional[dict] = None,
|
||||
) -> np.ndarray:
|
||||
|
||||
@@ -196,6 +196,7 @@ def get_crop_transform(
|
||||
max_binder_length: int,
|
||||
max_atoms_in_crop: int | None,
|
||||
allowed_types: List[str],
|
||||
association_scheme: str,
|
||||
):
|
||||
if (
|
||||
crop_contiguous_probability > 0
|
||||
@@ -215,7 +216,7 @@ def get_crop_transform(
|
||||
), "Crop center cutoff distance must be greater than 0"
|
||||
|
||||
pre_crop_transforms = [
|
||||
SubsampleToTypes(allowed_types=allowed_types),
|
||||
SubsampleToTypes(allowed_types=allowed_types, association_scheme=association_scheme),
|
||||
]
|
||||
|
||||
cropping_transform = RandomRoute(
|
||||
@@ -360,8 +361,11 @@ def build_atom14_base_pipeline_(
|
||||
max_ss_frac_to_provide: float,
|
||||
min_ss_island_len: int,
|
||||
max_ss_island_len: int,
|
||||
# Nucleic acid features
|
||||
add_na_pair_features: bool,
|
||||
|
||||
## Nucleic acid features #####
|
||||
#add_na_pair_features: bool,
|
||||
## This should not be necessary, controlled through feature names in model, and meta conditioning probabilities, inference behavior handled in transform itself #####
|
||||
|
||||
**_, # dump additional kwargs (e.g. msa stuff)
|
||||
):
|
||||
"""
|
||||
@@ -411,6 +415,7 @@ def build_atom14_base_pipeline_(
|
||||
max_binder_length=max_binder_length,
|
||||
max_atoms_in_crop=max_atoms_in_crop,
|
||||
allowed_types=allowed_types,
|
||||
association_scheme=association_scheme
|
||||
)
|
||||
|
||||
if zero_occ_on_exposure_after_cropping:
|
||||
@@ -443,14 +448,15 @@ def build_atom14_base_pipeline_(
|
||||
)
|
||||
)
|
||||
# Add nucleic acid geometry features
|
||||
if add_na_pair_features:
|
||||
transforms.append(
|
||||
CalculateNucleicAcidGeomFeats(
|
||||
is_inference,
|
||||
NA_only=False,
|
||||
planar_only=True,
|
||||
)
|
||||
#if add_na_pair_features:
|
||||
transforms.append(
|
||||
CalculateNucleicAcidGeomFeats(
|
||||
is_inference,
|
||||
meta_conditioning_probabilities,
|
||||
NA_only=False,
|
||||
planar_only=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Design Transforms
|
||||
transforms += [
|
||||
@@ -618,7 +624,6 @@ def build_atom14_base_pipeline(
|
||||
Wrapper around pipeline construction to handle empty training args
|
||||
Sets default behaviour for inference to keep backward compatibility
|
||||
"""
|
||||
|
||||
if is_inference:
|
||||
# Provide explicit defaults for training-only args
|
||||
kwargs.setdefault("crop_size", 512)
|
||||
@@ -634,7 +639,8 @@ def build_atom14_base_pipeline(
|
||||
kwargs.setdefault("min_ss_island_len", 0)
|
||||
kwargs.setdefault("max_ss_island_len", 999)
|
||||
kwargs.setdefault("max_binder_length", 999)
|
||||
kwargs.setdefault("add_na_pair_features", False)
|
||||
# This should not be necessary.
|
||||
#kwargs.setdefault("add_na_pair_features", False)
|
||||
|
||||
kwargs.setdefault("b_factor_min", None)
|
||||
kwargs.setdefault("zero_occ_on_exposure_after_cropping", False)
|
||||
|
||||
@@ -96,13 +96,13 @@ class IslandCondition(TrainingCondition):
|
||||
is_rna = data["atom_array"].is_rna
|
||||
### updating this to allow other polymers
|
||||
if self.association_scheme == "atom23":
|
||||
if not np.any(is_protein | is_dna | is_rna):
|
||||
return False
|
||||
if np.any(is_protein | is_dna | is_rna):
|
||||
return True
|
||||
else:
|
||||
if not np.any(is_protein):
|
||||
return False
|
||||
|
||||
return True
|
||||
if np.any(is_protein):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def sample_motif_tokens(self, atom_array):
|
||||
"""
|
||||
|
||||
@@ -37,7 +37,7 @@ def assert_single_representative(token, central_atom="CB"):
|
||||
mask = get_af3_token_representative_masks(token, central_atom=central_atom)
|
||||
assert (
|
||||
np.sum(mask) == 1
|
||||
), f"No representative atom (CB) found. mask: {mask}\nToken: {token}"
|
||||
), f"No representative atom ({central_atom}) found. mask: {mask}\nToken: {token}"
|
||||
|
||||
|
||||
def assert_single_token(token):
|
||||
|
||||
@@ -225,17 +225,21 @@ class PadTokensWithVirtualAtoms(Transform):
|
||||
# First, pad with virtual atoms if needed
|
||||
if self.association_scheme == "atom23" and atom_array[start].is_dna:
|
||||
n_atoms_per_token = 22
|
||||
central_atom = "C1'"
|
||||
elif self.association_scheme == "atom23" and atom_array[start].is_rna:
|
||||
n_atoms_per_token = 23
|
||||
central_atom = "C1'"
|
||||
else:
|
||||
n_atoms_per_token = self.n_atoms_per_token
|
||||
central_atom = self.atom_to_pad_from
|
||||
|
||||
n_pad = n_atoms_per_token - len(token)
|
||||
|
||||
if n_pad > 0:
|
||||
mask = get_af3_token_representative_masks(
|
||||
token, central_atom=self.atom_to_pad_from
|
||||
token, central_atom=central_atom
|
||||
)
|
||||
assert_single_representative(token)
|
||||
assert_single_representative(token, central_atom=central_atom)
|
||||
|
||||
# ... Create virtual atoms
|
||||
pad_atoms = token[mask].copy()
|
||||
|
||||
Reference in New Issue
Block a user