mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
wip: stage 2, all working
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
# Model architecture
|
||||
_target_: modelhub.model.AF3.AF3
|
||||
_target_: modelhub.model.RF3.RF3
|
||||
|
||||
# +---------- Channel dimensions ----------+
|
||||
c_s: 384
|
||||
|
||||
@@ -2,11 +2,11 @@ defaults:
|
||||
- rf3_net
|
||||
|
||||
# Model architecture
|
||||
_target_: modelhub.model.AF3.AF3WithConfidence
|
||||
_target_: modelhub.model.RF3.RF3WithConfidence
|
||||
|
||||
# +---------- Mini rollout sampler ----------+
|
||||
# From the AF-3 main text:
|
||||
# > ...To remedy this, we developed a diffusion ‘rollout’ procedure for the full-structure prediction generation during training (using a larger step size than normal)
|
||||
# > ... To remedy this, we developed a diffusion ‘rollout’ procedure for the full-structure prediction generation during training (using a larger step size than normal)
|
||||
# They do not further elaborate on how they adjusted the step size during diffusion rollout, but this may be a fruitful area of exploration moving forwards
|
||||
mini_rollout_sampler:
|
||||
solver: "af3"
|
||||
|
||||
@@ -6,7 +6,6 @@ defaults:
|
||||
- _self_
|
||||
|
||||
net:
|
||||
_target_: projects.rfscore.model.RFScore.RFScore
|
||||
feature_initializer:
|
||||
input_feature_embedder:
|
||||
atom_attention_encoder:
|
||||
|
||||
@@ -4,4 +4,4 @@ defaults:
|
||||
- _self_
|
||||
|
||||
net:
|
||||
_target_: projects.rfscore.model.RFScore.RFScoreWithConfidence
|
||||
_target_: modelhub.model.RF3.RF3WithConfidence
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +0,0 @@
|
||||
defaults:
|
||||
- default
|
||||
|
||||
fit_and_evaluate_on_experimental_data_callback:
|
||||
_target_: projects.rfscore.callbacks.experimental_validation.FitAndEvaluateOnExperimentalDataCallback
|
||||
# Metric outputs to use as features for experimental validation
|
||||
feature_metrics:
|
||||
- selected_atom_by_atom_distances
|
||||
datasets:
|
||||
- ts1_unconditional
|
||||
@@ -1,30 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
name: rfscore-arbitrary-templating
|
||||
|
||||
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
|
||||
defaults:
|
||||
- override /callbacks: default
|
||||
- override /datasets: rfscore_arbitrary_templating
|
||||
- override /model: rfscore
|
||||
- override /logger: csv
|
||||
- override /trainer/metrics: structure_prediction
|
||||
|
||||
tags:
|
||||
# list of tags to add to the run (& on wandb to easily find & filter runs)
|
||||
- arbitrary-templating
|
||||
- rfscore
|
||||
|
||||
model:
|
||||
lr_scheduler:
|
||||
base_lr: 0.9e-3 # 1/2 of original learning rate (1.8e-3)
|
||||
|
||||
ckpt_config:
|
||||
_target_: modelhub.utils.weights.CheckpointConfig
|
||||
path: /net/software/containers/versions/modelhub_inference/ckpts/modelhub_af3_with_confidence_latest.ckpt
|
||||
reset_optimizer: true
|
||||
weight_loading_config:
|
||||
_target_: modelhub.utils.weights.WeightLoadingConfig
|
||||
fallback_policy: copy_and_zero_pad
|
||||
param_policies:
|
||||
"*.recycler.template_embedder.emb_templ.weight": zero_init
|
||||
@@ -1,12 +0,0 @@
|
||||
# @package _global_
|
||||
# ^ The "package" determines where the content of the config is placed in the output config
|
||||
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
||||
|
||||
# This line enables us to access the common configs in the root-level `configs` directory
|
||||
# For more information, see: https://hydra.cc/docs/advanced/search_path/
|
||||
hydra:
|
||||
searchpath:
|
||||
- pkg://configs
|
||||
|
||||
defaults:
|
||||
- inference_engine: ???
|
||||
@@ -1,10 +0,0 @@
|
||||
# @package _global_
|
||||
defaults:
|
||||
- /inference_engine/af3
|
||||
- _self_
|
||||
|
||||
_target_: projects.rfscore.inference_engines.rfscore.RFScoreInferenceEngine
|
||||
|
||||
ckpt_path: /projects/ml/rfscore/inference/rfscore_no_confidence_2025-03-03.ckpt
|
||||
|
||||
template_selection_strings: null
|
||||
@@ -1,36 +0,0 @@
|
||||
# @package _global_
|
||||
|
||||
defaults:
|
||||
- /model/af3
|
||||
- _self_
|
||||
|
||||
model:
|
||||
net:
|
||||
_target_: projects.rfscore.model.RFScore.RFScore
|
||||
|
||||
feature_initializer:
|
||||
input_feature_embedder:
|
||||
atom_attention_encoder:
|
||||
c_atom_1d_features: 392 # Add three channels to concatenate the ground-truth reference conformer as well
|
||||
atom_1d_features:
|
||||
- ref_pos
|
||||
- ref_charge
|
||||
- ref_mask
|
||||
- ref_element
|
||||
- ref_atom_name_chars
|
||||
- ref_pos_ground_truth
|
||||
|
||||
recycler:
|
||||
template_embedder:
|
||||
raw_template_dim: 66
|
||||
|
||||
diffusion_module:
|
||||
atom_attention_encoder:
|
||||
c_atom_1d_features: 392 # Add three channels to concatenate the ground-truth reference conformer as well
|
||||
atom_1d_features:
|
||||
- ref_pos
|
||||
- ref_charge
|
||||
- ref_mask
|
||||
- ref_element
|
||||
- ref_atom_name_chars
|
||||
- ref_pos_ground_truth
|
||||
@@ -1,52 +0,0 @@
|
||||
# @package _global_
|
||||
# ^ The "package" determines where the content of the config is placed in the output config
|
||||
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
||||
|
||||
# This line enables us to access the common configs in the root-level `configs` directory
|
||||
# For more information, see: https://hydra.cc/docs/advanced/search_path/
|
||||
hydra:
|
||||
searchpath:
|
||||
- pkg://configs
|
||||
|
||||
# @package _global_
|
||||
# ^ The "package" determines where the content of the config is placed in the output config
|
||||
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
||||
|
||||
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
||||
defaults:
|
||||
- callbacks: default
|
||||
- logger: csv
|
||||
- trainer: af3
|
||||
- paths: default
|
||||
- datasets: ???
|
||||
- dataloader: default
|
||||
- hydra: default
|
||||
- model: rfscore
|
||||
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
|
||||
- _self_
|
||||
|
||||
# experiment configs allow for version control of specific hyperparameters
|
||||
# e.g. best hyperparameters for given model and datamodule
|
||||
- experiment: ???
|
||||
|
||||
# debug configs to add onto any experiment for quickly testing or debugging code
|
||||
- debug: null
|
||||
|
||||
|
||||
# DO NOT set these here. Set them in the relevant experiment config file.
|
||||
# ... these are just here to ensure users always specify these fields in their experiment configs.
|
||||
name: ???
|
||||
tags: ???
|
||||
|
||||
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
|
||||
# here.
|
||||
# ... task name (determines the output directory path)
|
||||
task_name: "train"
|
||||
|
||||
project: rfscore # required for W&B logging
|
||||
|
||||
seed: 42
|
||||
|
||||
# Provide checkpoint path to resume training from a checkpoint
|
||||
# NOTE: If using W&B, must also set the `id` and `resume` fields in the `logger/wandb` config
|
||||
ckpt_path: null
|
||||
@@ -1,8 +0,0 @@
|
||||
defaults:
|
||||
- structure_prediction
|
||||
|
||||
extra_info:
|
||||
_target_: projects.rfscore.metrics.metadata.ExtraInfo
|
||||
|
||||
selected_atom_by_atom_distances:
|
||||
_target_: projects.rfscore.metrics.distances.SelectedAtomByAtomDistances
|
||||
@@ -1,55 +0,0 @@
|
||||
# @package _global_
|
||||
# ^ The "package" determines where the content of the config is placed in the output config
|
||||
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
||||
|
||||
# This line enables us to access the common configs in the root-level `configs` directory
|
||||
# For more information, see: https://hydra.cc/docs/advanced/search_path/
|
||||
hydra:
|
||||
searchpath:
|
||||
- pkg://configs
|
||||
|
||||
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
||||
defaults:
|
||||
- callbacks: default
|
||||
- logger: csv
|
||||
- trainer: af3
|
||||
- paths: default
|
||||
- datasets: ???
|
||||
- dataloader: default
|
||||
- hydra: default
|
||||
- model: rfscore
|
||||
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
|
||||
- _self_
|
||||
|
||||
# experiment configs allow for version control of specific hyperparameters
|
||||
# e.g. best hyperparameters for given model and datamodule
|
||||
- experiment: ???
|
||||
|
||||
# debug configs to add onto any experiment for quickly testing or debugging code
|
||||
- debug: null
|
||||
|
||||
|
||||
# DO NOT set these here. Set them in the relevant experiment config file.
|
||||
# ... these are just here to ensure users always specify these fields in their experiment configs.
|
||||
name: ???
|
||||
tags: ???
|
||||
|
||||
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
|
||||
# here.
|
||||
# ... task name (determines the output directory path)
|
||||
task_name: "validate"
|
||||
|
||||
project: rfscore # required for W&B logging
|
||||
|
||||
seed: 42
|
||||
|
||||
# Dump CIF files for validation structures
|
||||
callbacks:
|
||||
dump_validation_structures_callback:
|
||||
dump_predictions: True
|
||||
one_model_per_file: False
|
||||
dump_trajectories: False
|
||||
|
||||
# passing checkpoint path required for validation
|
||||
# DO NOT set here; set in the experiment config file
|
||||
ckpt_path: ???
|
||||
@@ -1,4 +1,4 @@
|
||||
from modelhub.model.AF3 import AF3, AF3WithConfidence
|
||||
from modelhub.model.RF3 import AF3, AF3WithConfidence
|
||||
from projects.rfscore.model.recycler import RFScoreRecycler
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from modelhub.model.AF3_structure import Recycler
|
||||
from projects.rfscore.model.template_embedder import RFScoreTemplateEmbedder
|
||||
from modelhub.model.RF3_structure import Recycler
|
||||
from projects.rfscore.model.template_embedder import RF3TemplateEmbedder
|
||||
|
||||
|
||||
class RFScoreRecycler(Recycler):
|
||||
@@ -10,7 +10,7 @@ class RFScoreRecycler(Recycler):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# ... override the template embedder to use the RFScore template embedder, which provides additional conditioning
|
||||
self.template_embedder = RFScoreTemplateEmbedder(
|
||||
self.template_embedder = RF3TemplateEmbedder(
|
||||
c_z=kwargs["c_z"],
|
||||
**kwargs["template_embedder"],
|
||||
)
|
||||
|
||||
@@ -4,14 +4,12 @@ from torch.nn.functional import relu
|
||||
|
||||
from modelhub.model.layers.pairformer_layers import PairformerBlock
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
from modelhub.utils.torch_utils import device_of
|
||||
from projects.rfscore.model.embeddings import FourierEmbedding
|
||||
from projects.rfscore.transforms.ground_truth_template import (
|
||||
from modelhub.data.ground_truth_template import (
|
||||
af3_noise_scale_to_noise_level,
|
||||
)
|
||||
|
||||
|
||||
class RFScoreTemplateEmbedder(nn.Module):
|
||||
class RF3TemplateEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_block,
|
||||
|
||||
@@ -1,6 +1,15 @@
|
||||
from datahub.transforms._checks import check_atom_array_annotation
|
||||
from datahub.transforms.crop import compute_local_hash
|
||||
|
||||
from modelhub.data.ground_truth_template import (
|
||||
FeaturizeNoisedGroundTruthAsTemplateDistogram,
|
||||
TokenGroupNoiseScaleSampler,
|
||||
af3_noise_scale_distribution_wrapped,
|
||||
af3_noise_scale_to_noise_level,
|
||||
)
|
||||
from omegaconf import DictConfig
|
||||
from cifutils.enums import ChainType
|
||||
from functools import partial
|
||||
import torch
|
||||
|
||||
def annotate_pre_crop_hash(data: dict) -> dict:
|
||||
hash_pre = compute_local_hash(data["atom_array"])
|
||||
@@ -29,3 +38,89 @@ def set_to_occupancy_0_where_crop_hashes_differ(data: dict) -> dict:
|
||||
atom_array.occupancy[mask] = 0
|
||||
|
||||
return data
|
||||
|
||||
def _is_atomized(arr):
    """Return the ``atomize`` annotation of ``arr`` (selects atomized entries)."""
    return arr.atomize


def _is_not_atomized(arr):
    """Return the inverse of the ``atomize`` annotation of ``arr``."""
    return ~arr.atomize


def _constant_noise_scale(size, noise_scale):
    """Return a tensor of ones with shape ``size`` multiplied by ``noise_scale``."""
    return torch.ones(size) * noise_scale


def build_ground_truth_distogram_transform(
    *,
    template_noise_scales: dict[str, float | None] | DictConfig,
    allowed_chain_types_for_conditioning: list[ChainType] | None = None,
    p_condition_per_token: float = 1.0,
    p_provide_inter_molecule_distances: float = 0.0,
    is_inference: bool = False,
) -> FeaturizeNoisedGroundTruthAsTemplateDistogram:
    """
    Build a FeaturizeNoisedGroundTruthAsTemplateDistogram transform for either training or inference.

    For inference, we must be deterministic, so we:
        - Use the configured noise scales as constants (no sampling)
        - Always apply token-level conditioning (``p_condition_per_token`` is forced to 1.0)

    Args:
        template_noise_scales (dict[str, float | None] | DictConfig):
            Noise scales for the 'atomized' and 'not_atomized' token groups; both keys must be present.
            If is_inference=True, these are used as constants.
            If is_inference=False, these are converted to noise levels and used as upper bounds for the
            noise scale distribution.
            A value of None means no sampler is registered for that token group.
        allowed_chain_types_for_conditioning (list[ChainType] | None):
            List of allowed chain types for conditioning. None disables conditioning.
        p_condition_per_token (float):
            Probability of conditioning each eligible token. Ignored for inference, where 1.0 is used.
        p_provide_inter_molecule_distances (float):
            Probability of providing inter-molecule (inter-chain) distances.
        is_inference (bool):
            If True, use constant noise scales and always condition. If False, use distributions and the
            provided probability.

    Returns:
        FeaturizeNoisedGroundTruthAsTemplateDistogram: The configured transform.
    """
    # NOTE: module-level functions and `partial` objects are used instead of lambdas so that the
    # (mask_fn, sampling_fn) pairs do not contain anonymous closures, which cannot be pickled.
    # The group order (atomized first, then not_atomized) is the order the samplers are registered in.
    token_groups = (
        ("atomized", _is_atomized),
        ("not_atomized", _is_not_atomized),
    )

    mask_and_sampling_fns = []
    for group_name, mask_fn in token_groups:
        noise_scale = template_noise_scales[group_name]
        if noise_scale is None:
            # ... no sampler for this token group
            continue

        if is_inference:
            # Use a constant noise scale for inference, rather than sampling (no stochasticity)
            sampling_fn = partial(_constant_noise_scale, noise_scale=noise_scale)
        else:
            # Use a noise scale distribution for training, bounded above by the configured scale
            sampling_fn = partial(
                af3_noise_scale_distribution_wrapped,
                upper_noise_level=af3_noise_scale_to_noise_level(noise_scale).item(),
            )
        mask_and_sampling_fns.append((mask_fn, sampling_fn))

    # Always condition for inference (no stochasticity); condition only some tokens during training
    p_condition = 1.0 if is_inference else p_condition_per_token

    return FeaturizeNoisedGroundTruthAsTemplateDistogram(
        noise_scale_distribution=TokenGroupNoiseScaleSampler(
            mask_and_sampling_fns=tuple(mask_and_sampling_fns),
        ),
        allowed_chain_types=allowed_chain_types_for_conditioning,
        p_condition_per_token=p_condition,
        p_provide_inter_molecule_distances=p_provide_inter_molecule_distances,
    )
|
||||
|
||||
@@ -98,7 +98,7 @@ from modelhub.data.pipeline_utils import (
|
||||
annotate_pre_crop_hash,
|
||||
set_to_occupancy_0_where_crop_hashes_differ,
|
||||
)
|
||||
from projects.rfscore.pipelines.composed import build_ground_truth_distogram_transform
|
||||
from modelhub.data.pipeline_utils import build_ground_truth_distogram_transform
|
||||
|
||||
|
||||
def TrainingRoute(transform):
|
||||
|
||||
@@ -14,7 +14,7 @@ from lightning.fabric import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from modelhub.inference_engines.base import InferenceEngine
|
||||
from modelhub.model.AF3 import ShouldEarlyStopFn
|
||||
from modelhub.model.RF3 import ShouldEarlyStopFn
|
||||
from modelhub.utils.datasets import (
|
||||
assemble_distributed_inference_loader_from_list_of_paths,
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from modelhub.diffusion_samplers.inference_sampler import (
|
||||
SampleDiffusion,
|
||||
SamplePartialDiffusion,
|
||||
)
|
||||
from modelhub.model.AF3_structure import DiffusionModule, DistogramHead, Recycler
|
||||
from modelhub.model.RF3_structure import DiffusionModule, DistogramHead, Recycler
|
||||
from modelhub.model.layers.pairformer_layers import (
|
||||
FeatureInitializer,
|
||||
)
|
||||
@@ -52,8 +52,8 @@ class ShouldEarlyStopFn(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class AF3(nn.Module):
|
||||
"""AF3 Network module.
|
||||
class RF3(nn.Module):
|
||||
"""RF3 Network module.
|
||||
|
||||
We adhere to the PyTorch Lightning Style Guide; see (1).
|
||||
|
||||
@@ -89,14 +89,11 @@ class AF3(nn.Module):
|
||||
c_z: Token-level pair representation channel dimension
|
||||
c_atom: Atom-level single representation channel dimension
|
||||
c_atompair: Atom-level pair representation channel dimension
|
||||
c_s_inputs: TBD what the heck this is
|
||||
loss: Arguments for the loss function
|
||||
partial_optimizer: Optimizer (partially initialized) to be used for training. The "configure_optimizers" method will finish instantiating the optimizer.
|
||||
partial_lr_scheduler: Learning rate scheduler (partially initialized) to be used for training. The "configure_optimizers" method will finish instantiating the scheduler.
|
||||
c_s_inputs: Output dimension of the InputFeatureEmbedder
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
# ... initialize the FeatureInitializer, which creates the initial token- and atom-level representations and conditioning
|
||||
# ... initialize the FeatureInitializer, which creates the initial token-level representations and conditioning
|
||||
self.feature_initializer = FeatureInitializer(
|
||||
c_s=c_s,
|
||||
c_z=c_z,
|
||||
@@ -286,6 +283,7 @@ class AF3(nn.Module):
|
||||
|
||||
def recycle(
|
||||
self,
|
||||
# TODO: Jax typing
|
||||
S_inputs_I,
|
||||
S_init_I,
|
||||
Z_init_II,
|
||||
@@ -311,7 +309,7 @@ class AF3(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class AF3WithConfidence(AF3):
|
||||
class RF3WithConfidence(RF3):
|
||||
"""Model for training and inference with confidence metric computation"""
|
||||
|
||||
def __init__(
|
||||
@@ -444,6 +442,8 @@ class AF3WithConfidence(AF3):
|
||||
)
|
||||
|
||||
# ... run non-batched confidence head
|
||||
# TODO: Write a version of the confidence head that splits into batches based on memory available
|
||||
# (Currently, we OOM with the full batch size, so we loop, which is slow)
|
||||
D = sample_diffusion_outs["X_L"].shape[0]
|
||||
confidence_stack = {}
|
||||
for i in range(D):
|
||||
@@ -13,7 +13,7 @@ from modelhub.model.layers.pairformer_layers import (
|
||||
MSAModule,
|
||||
PairformerBlock,
|
||||
RelativePositionEncoding,
|
||||
TemplateEmbedder,
|
||||
RF3TemplateEmbedder,
|
||||
)
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
|
||||
@@ -292,7 +292,7 @@ class Recycler(nn.Module):
|
||||
nn.LayerNorm(c_z),
|
||||
linearNoBias(c_z, c_z),
|
||||
)
|
||||
self.template_embedder = TemplateEmbedder(c_z=c_z, **template_embedder)
|
||||
self.template_embedder = RF3TemplateEmbedder(c_z=c_z, **template_embedder)
|
||||
self.msa_module = MSAModule(**msa_module)
|
||||
self.process_sh = nn.Sequential(
|
||||
nn.LayerNorm(c_s),
|
||||
@@ -6,7 +6,7 @@ from opt_einsum import contract as einsum
|
||||
|
||||
from modelhub.chemical import ChemicalData as ChemData
|
||||
from modelhub.flow_matching import data_utils as du
|
||||
from modelhub.model.AF3_structure import FourierEmbedding
|
||||
from modelhub.model.RF3_structure import FourierEmbedding
|
||||
from modelhub.model.layers.Attention_module import FeedForwardLayer
|
||||
from modelhub.model.layers.SE3_network import (
|
||||
SE3TransformerWrapper,
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import modelhub
|
||||
from modelhub.model.AF3_structure import PairformerBlock, linearNoBias
|
||||
from modelhub.model.RF3_structure import PairformerBlock, linearNoBias
|
||||
|
||||
# TODO: Get from RF2AA encoding instead
|
||||
CHEM_DATA_LEGACY = {"NHEAVY": 23, "aa2num": {"UNK": 20, "GLY": 7, "MAS": 21}}
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn.functional import one_hot, relu
|
||||
|
||||
from modelhub.model.AF3_blocks import MsaPairWeightedAverage, MsaSubsampleEmbedder
|
||||
from modelhub.model.RF3_blocks import MsaPairWeightedAverage, MsaSubsampleEmbedder
|
||||
from modelhub.model.layers.af3_diffusion_transformer import AtomTransformer
|
||||
from modelhub.model.layers.Attention_module import (
|
||||
TriangleAttention,
|
||||
@@ -23,6 +23,9 @@ from modelhub.model.layers.outer_product import (
|
||||
)
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
from modelhub.util_module import Dropout
|
||||
from modelhub.data.ground_truth_template import (
|
||||
af3_noise_scale_to_noise_level,
|
||||
)
|
||||
|
||||
|
||||
class AtomAttentionEncoderPairformer(nn.Module):
|
||||
@@ -646,7 +649,11 @@ class MSAModule(nn.Module):
|
||||
return Z_II
|
||||
|
||||
|
||||
class TemplateEmbedder(nn.Module):
|
||||
class AF3TemplateEmbedder(nn.Module):
|
||||
"""
|
||||
AF3-like TemplateEmbedding (e.g., protein-only, etc.)
|
||||
Unused in RF3.
|
||||
"""
|
||||
def __init__(self, n_block, raw_template_dim, c_z, c, p_drop):
|
||||
super().__init__()
|
||||
self.c = c
|
||||
@@ -751,3 +758,110 @@ class TemplateEmbedder(nn.Module):
|
||||
template_restype,
|
||||
asym_id,
|
||||
)
|
||||
|
||||
class RF3TemplateEmbedder(nn.Module):
    """
    Template track conditioned on noisy ground-truth distograms at the token level.

    Unlike a protein-only template embedder, this track is intended to support all chain types.
    The forward pass reads three entries from the feature dict: ``has_distogram_condition``,
    ``distogram_condition_noise_scale`` and ``distogram_condition``.
    """

    def __init__(
        self,
        n_block,
        raw_template_dim,
        c_z,
        c,
        p_drop,
    ):
        """
        Args:
            n_block: Number of PairformerBlocks in the template pairformer stack.
            raw_template_dim: Channel dimension of the concatenated raw template features.
            c_z: Channel dimension of the pair representation (input and output of this module).
            c: Internal template channel dimension.
            p_drop: Dropout probability forwarded to each PairformerBlock.
        """
        super().__init__()
        self.c = c

        # NOTE: submodule names and creation order are kept fixed (checkpoint keys depend on the names)
        self.emb_pair = nn.Linear(c_z, c, bias=False)
        self.norm_pair_before_pairformer = nn.LayerNorm(c_z)
        self.norm_after_pairformer = nn.LayerNorm(c)
        self.emb_templ = nn.Linear(raw_template_dim, c, bias=False)

        # ... the template pairformer has no sequence (single) representation, hence c_s=0
        pairformer_blocks = []
        for _ in range(n_block):
            pairformer_blocks.append(
                PairformerBlock(
                    c_s=0,
                    c_z=c,
                    p_drop=p_drop,
                    triangle_multiplication=dict(d_hidden=c),
                    triangle_attention=dict(d_hidden=c),
                    attention_pair_bias={},
                    n_transition=4,
                )
            )
        self.pairformer = nn.ModuleList(pairformer_blocks)

        # NOTE: this deviates from the AF3 paper, which outputs this tensor in the template channel dimension.
        # In Algorithm 1, line 9, the output of this function is added to the Z_II tensor, which has
        # dimensions [B, I, I, C_z], so we project the output of this module to those dimensions as well.
        self.agg_emb = nn.Linear(c, c_z, bias=False)

    def forward(
        self,
        f,
        Z_II,
    ):
        """
        Embed the noised ground-truth distogram condition into the pair-representation channel dimension.

        Args:
            f: Feature dict providing "has_distogram_condition", "distogram_condition_noise_scale"
                and "distogram_condition".
            Z_II: Pair representation; its first dimension is taken as the number of tokens.

        Returns:
            The output of ``agg_emb`` applied to the aggregated template embedding.
        """

        @activation_checkpointing
        def embed_noised_ground_truth(
            has_distogram_condition,  # [I, I]
            distogram_condition_noise_scale,  # [I]
            distogram_condition,  # [I, I, 64], where 64 is the number of distogram bins
        ):
            n_tokens = Z_II.shape[0]

            # ... combine the per-token noise scales into a pairwise scale, then map it to a noise level
            scale_along_rows = distogram_condition_noise_scale[None, :]
            scale_along_cols = distogram_condition_noise_scale[:, None]
            pairwise_noise_scale = torch.sqrt(
                scale_along_rows**2 + scale_along_cols**2
            )
            pairwise_noise_level = af3_noise_scale_to_noise_level(pairwise_noise_scale)

            # ... stack distogram, condition mask and noise level along the channel dimension
            condition_mask = has_distogram_condition.unsqueeze(-1)  # [I, I, 1]
            raw_features = torch.cat(
                [
                    distogram_condition,  # [I, I, 64]
                    condition_mask,  # [I, I, 1]
                    pairwise_noise_level.unsqueeze(-1),  # [I, I, 1]
                ],
                dim=-1,
            )  # [I, I, 66], where 66 = 64 + 1 + 1

            # ... zero out every pair that carries no condition
            raw_features = raw_features * condition_mask

            # ... project the raw template features to the template channel dimension
            embedded_features = self.emb_templ(raw_features)  # [I, I, c]

            # ... run the template pairformer on the projected pair representation plus the template embedding
            pair_input = self.emb_pair(self.norm_pair_before_pairformer(Z_II))
            v_II = pair_input + embedded_features  # [I, I, c]
            for pairformer_block in self.pairformer:
                _, v_II = pairformer_block(None, v_II)

            # ... accumulate onto a zero tensor allocated on the device of Z_II
            accumulated = torch.zeros(n_tokens, n_tokens, self.c, device=Z_II.device)
            accumulated = accumulated + self.norm_after_pairformer(v_II)

            return self.agg_emb(relu(accumulated))

        # ... rfscore-style template embedding (noisy ground-truth template as input)
        return embed_noised_ground_truth(
            has_distogram_condition=f["has_distogram_condition"],  # [I, I]
            distogram_condition_noise_scale=f["distogram_condition_noise_scale"],  # [I]
            distogram_condition=f["distogram_condition"],  # [I, I, 64]
        )
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from modelhub.loss.af3_losses import (
|
||||
SubunitSymmetryResolution,
|
||||
)
|
||||
from modelhub.metrics.base import MetricManager
|
||||
from modelhub.model.AF3 import ShouldEarlyStopFn
|
||||
from modelhub.model.RF3 import ShouldEarlyStopFn
|
||||
from modelhub.trainers.fabric import FabricTrainer
|
||||
from modelhub.training.EMA import EMA
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
Reference in New Issue
Block a user