mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
clean: make pip installable, remove unused files, ruff, add license
This commit is contained in:
28
LICENSE.md
Normal file
28
LICENSE.md
Normal file
@@ -0,0 +1,28 @@
|
||||
BSD 3-Clause License
|
||||
|
||||
Copyright (c) 2025, Institute for Protein Design, University of Washington
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
@@ -1,13 +0,0 @@
|
||||
name: modelhub
|
||||
channels:
|
||||
- nvidia/label/cuda-12.6.0
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- pip
|
||||
- python=3.12
|
||||
- cuda
|
||||
- pytorch=2.7
|
||||
- openbabel=3.1.1
|
||||
- pip:
|
||||
- -r file:requirements.txt
|
||||
Submodule lib/atomworks-dev updated: 4f020cf7f4...cd9e8b76d1
@@ -1,11 +1,51 @@
|
||||
[project]
|
||||
name = "modelhub"
|
||||
name = "rf3"
|
||||
dynamic = ["version"]
|
||||
description = "Base repository for models at the University of Washington's Institute for Protein Design"
|
||||
description = "Open-source biomolecular structure prediction for all molecules of life."
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.10"
|
||||
requires-python = ">= 3.12"
|
||||
authors = [
|
||||
{ name = "Bakerlab", email = "" },
|
||||
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
|
||||
]
|
||||
license = { file = "LICENSE.md" }
|
||||
|
||||
dependencies = [
|
||||
# ...ml tools
|
||||
"torch>=2.2.0,<3",
|
||||
"lightning>=2.4.0,<2.5",
|
||||
"einops>=0.8.0,<1",
|
||||
"einx>=0.1.0,<1",
|
||||
"opt_einsum>=3.4.0,<4",
|
||||
"dm-tree>=0.1.6,<1",
|
||||
# ... kernels
|
||||
"cuequivariance_ops_cu12>=0.5.0",
|
||||
"cuequivariance_ops_torch_cu12>=0.5.0",
|
||||
"cuequivariance_torch>=0.5.0",
|
||||
# ... configuration & CLI
|
||||
"hydra-core>=1.3.0,<1.4",
|
||||
"environs>=11.0.0,<12",
|
||||
# ... logging
|
||||
"wandb>=0.15.10,<1",
|
||||
"rich>=13.9.4,<14",
|
||||
# ... typing & documentation
|
||||
"jaxtyping>=0.2.17,<1",
|
||||
"beartype>=0.18.0,<1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
# Linters & formatters
|
||||
"ruff==0.8.3",
|
||||
# Debugger/interactive
|
||||
"debugpy>=1.8.5,<2",
|
||||
"ipykernel>=6.29.4,<7",
|
||||
# Testing tools
|
||||
"pytest>=8.2.0,<9", # testing framework
|
||||
"pytest-testmon>=2.1.1,<3", # run only tests related to changed code
|
||||
"pytest-xdist>=3.6.1,<4", # run tests in parallel
|
||||
"pytest-dotenv>=0.5.2,<1", # load environment variables from .env file
|
||||
"pytest-cov>=4.1.0,<5", # generate coverage report
|
||||
"pytest-benchmark>=5.0.0,<6", # benchmark tests for speed
|
||||
]
|
||||
|
||||
# Build settings ----------------------------------------------------------------------
|
||||
@@ -65,7 +105,6 @@ exclude = [
|
||||
"*.ipynb",
|
||||
"dev.py",
|
||||
"archive",
|
||||
"src/modelhub/utils/predicted_error.py", # TEMPORARY
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
# Core ML dependencies
|
||||
# NOTE: torch, torchvision, torchaudio are already included in NGC container
|
||||
lightning>=2.4.0,<2.5
|
||||
|
||||
# Small molecule libraries
|
||||
rdkit>=2024.3.5
|
||||
# TODO: Remove OpenBabel dependency
|
||||
# openbabel will be installed via apt-get (pip installation fails due to C++ build dependencies)
|
||||
|
||||
# Project-related dependencies
|
||||
# ... generic tools
|
||||
GitPython>=3.0.0,<4
|
||||
cython>=3.0.0,<4
|
||||
cytoolz>=0.12.3,<1
|
||||
assertpy>=1.1.0,<2 # TODO: remove this dependency
|
||||
tqdm>=4.65.0,<5
|
||||
rootutils>=1.0.7,<1.1
|
||||
dm-tree>=0.1.6,<1 # TODO: remove this dependency
|
||||
deepdiff>=8.0.0,<9 # TODO: remove this dependency
|
||||
|
||||
# ... configuration & CLI
|
||||
fire>=0.6.0,<1
|
||||
hydra-core>=1.3.0,<1.4
|
||||
environs>=11.0.0,<12
|
||||
|
||||
# ... linear algebra, maths & ml
|
||||
numpy>=1.25.0,<2
|
||||
scipy>=1.13.1,<2
|
||||
einops>=0.8.0,<1
|
||||
einx>=0.1.0,<1
|
||||
opt_einsum>=3.4.0,<4
|
||||
scikit-learn>=1.6.1,<2
|
||||
|
||||
# ... kernels
|
||||
cuequivariance_ops_cu12>=0.5.0
|
||||
cuequivariance_ops_torch_cu12>=0.5.0
|
||||
cuequivariance_torch>=0.5.0
|
||||
|
||||
# ... data tools
|
||||
pandas>=2.2,<2.3
|
||||
pyarrow>=17.0.0
|
||||
fastparquet>=2024.5.0
|
||||
seaborn>=0.13.0,<1
|
||||
|
||||
# ... bioinformatics
|
||||
biopython>=1.83,<2
|
||||
py3Dmol>=2.2.1,<3
|
||||
pymol-remote>=0.0.5
|
||||
biotite==1.3.0 # Fixed version - updating may involve breaking changes
|
||||
hydride==1.2.3 # Fixed version - updating may involve breaking changes
|
||||
|
||||
# ... logging
|
||||
wandb>=0.15.10,<1
|
||||
rich>=13.9.4,<14
|
||||
|
||||
# Formatting & linting (only needed for development)
|
||||
ruff==0.8.3
|
||||
pre-commit==3.7.1
|
||||
|
||||
# Debugger & interactive tools (only needed for development)
|
||||
debugpy>=1.8.5,<2
|
||||
ipykernel>=6.29.4,<7
|
||||
icecream>=2.0.0,<3
|
||||
pymol-remote>=0.1.0
|
||||
ipdb>=0.13.9
|
||||
|
||||
# Pytest plugins (only needed for development)
|
||||
pytest>=8.2.0,<9
|
||||
pytest-testmon>=2.1.1,<3
|
||||
pytest-xdist>=3.6.1,<4
|
||||
pytest-dotenv>=0.5.2,<1
|
||||
pytest-cov>=4.1.0,<5
|
||||
pytest-benchmark>=5.0.0,<6
|
||||
|
||||
# Typing & documentation (only needed for development)
|
||||
jaxtyping>=0.2.17,<1
|
||||
beartype>=0.18.0,<1
|
||||
@@ -1,345 +0,0 @@
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from beartype.typing import Any
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
confusion_matrix,
|
||||
roc_auc_score,
|
||||
)
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
|
||||
from modelhub.callbacks.base import BaseCallback
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import print_df_as_table
|
||||
|
||||
# Suppress warnings for cleaner output
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
def clean_and_encode(
    X: pd.DataFrame, y: pd.Series
) -> tuple[pd.DataFrame, np.ndarray, LabelEncoder]:
    """Coerce features to numeric, drop incomplete rows, and label-encode the target.

    The input frame is never modified: all conversions happen on a copy, so a
    caller that passes a slice of a larger DataFrame keeps its original data.

    Args:
        X: Feature matrix. Values that cannot be parsed as numbers become NaN.
        y: Target vector, indexed consistently with ``X``.

    Returns:
        A tuple ``(X_clean, y_encoded, label_encoder)`` where ``X_clean`` holds
        only the rows without missing values, ``y_encoded`` is the integer
        encoding of the matching targets, and ``label_encoder`` is the fitted
        encoder used to produce it.
    """
    # Work on a copy so the column assignments below do not write through to
    # the caller's DataFrame.
    X = X.copy()

    # Convert all columns to numeric; unparseable entries are coerced to NaN
    for col in X.columns:
        X[col] = pd.to_numeric(X[col], errors="coerce")

    # Drop rows with missing values and keep the target aligned with them
    missing_count = X.isnull().sum().sum()
    if missing_count > 0:
        ranked_logger.warning(
            f"Found {missing_count} missing values in feature matrix. Dropping rows."
        )
        X = X.dropna()
        y = y.loc[X.index]

    # Encode target variable
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    return X, y_encoded, label_encoder
|
||||
|
||||
|
||||
def train_model(
    X_train: pd.DataFrame, y_train: np.ndarray, model_params: dict | None = None
) -> RandomForestClassifier:
    """Fit a Random Forest classifier on the training split.

    Args:
        X_train: Training features
        y_train: Training target
        model_params: Keyword arguments for RandomForestClassifier; any key
            given here overrides the built-in default of the same name

    Returns:
        The fitted Random Forest classifier
    """
    # Start from the built-in defaults ...
    params = dict(
        n_estimators=20,
        max_depth=2,
        random_state=42,
        class_weight="balanced",
    )

    # ... and let caller-supplied values take precedence
    if model_params:
        params.update(model_params)

    classifier = RandomForestClassifier(**params)
    classifier.fit(X_train, y_train)
    return classifier
|
||||
|
||||
|
||||
def evaluate(
    model: RandomForestClassifier,
    X: pd.DataFrame,
    y: np.ndarray,
    label_encoder: LabelEncoder,
    split_name: str,
) -> dict[str, Any]:
    """Evaluate model on a dataset split and generate metrics from predictions

    Args:
        model: Trained classifier
        X: Feature matrix
        y: Target vector, encoded with ``label_encoder``
        label_encoder: Fitted label encoder
        split_name: Dataset split (train, validation, test)

    Returns:
        Dictionary with keys ``"accuracy"``, ``"auc"`` and ``"confusion_matrix"``,
        or an empty dictionary when the class ``"ACTIVE"`` is not known to
        ``label_encoder``.
    """
    classes = list(label_encoder.classes_)

    # The AUC is defined relative to the "ACTIVE" class; without it there is
    # nothing to compute, so bail out before doing any prediction work.
    if "ACTIVE" not in classes:
        ranked_logger.warning(
            "Class 'ACTIVE' not found in label encoder. Cannot compute experimental metrics."
        )
        return {}

    active_class_index = classes.index("ACTIVE")

    # Generate predictions and probabilities
    y_pred = model.predict(X)
    y_proba = model.predict_proba(X)

    # Calculate metrics (binary classification)
    accuracy = accuracy_score(y, y_pred)
    auc = roc_auc_score(y == active_class_index, y_proba[:, active_class_index])

    # Pin the label set explicitly: without `labels`, a split in which one class
    # never occurs (neither in y nor in y_pred) produces a smaller matrix that
    # no longer matches the class names used as index/columns below.
    conf_matrix = confusion_matrix(y, y_pred, labels=np.arange(len(classes)))

    # Print metrics summary
    metrics_summary = pd.DataFrame(
        {"Metric": ["Accuracy", "ROC AUC"], "Value": [accuracy, auc]}
    )
    print_df_as_table(metrics_summary, title=f"{split_name.upper()} SET METRICS")

    # Print confusion matrix
    conf_df = pd.DataFrame(
        conf_matrix, index=label_encoder.classes_, columns=label_encoder.classes_
    )
    conf_df.index.name = "True"
    conf_df.columns.name = "Predicted"
    print_df_as_table(conf_df, title=f"{split_name.upper()} CONFUSION MATRIX")

    # Return metrics
    return {
        "accuracy": accuracy,
        "auc": auc,
        "confusion_matrix": conf_matrix,
    }
|
||||
|
||||
|
||||
def evaluate_model_on_all_splits(
    model: RandomForestClassifier,
    X_train: pd.DataFrame,
    y_train: np.ndarray,
    X_valid: pd.DataFrame,
    y_valid: np.ndarray,
    label_encoder: LabelEncoder,
    X_test: pd.DataFrame | None = None,
    y_test: np.ndarray | None = None,
) -> dict[str, dict[str, Any]]:
    """Run `evaluate` on the train and validation splits, plus test when given.

    Args:
        model: Trained classifier
        X_train: Training features
        y_train: Training target
        X_valid: Validation features
        y_valid: Validation target
        label_encoder: Original label encoder
        X_test: Test features (optional)
        y_test: Test target (optional)

    Returns:
        Mapping from split name ("train", "validation", and "test" if both test
        inputs were provided) to that split's metrics dictionary
    """
    # Train and validation are always evaluated
    splits = [
        ("train", X_train, y_train),
        ("validation", X_valid, y_valid),
    ]

    # The test split is only evaluated when both of its parts are present
    has_test_split = not (X_test is None or y_test is None)
    if has_test_split:
        splits.append(("test", X_test, y_test))

    return {
        split_name: evaluate(model, features, target, label_encoder, split_name)
        for split_name, features, target in splits
    }
|
||||
|
||||
|
||||
def fit_and_evaluate(
    df: pd.DataFrame,
    feature_metrics: list[str],
    model_params: dict | None = None,
    labels_to_include: list[str] | tuple[str, ...] | None = ("ACTIVE", "INACTIVE"),
    datasets: list[str] | None = None,
) -> dict:
    """Train models for each dataset and evaluate classification performance across train, validation, and test splits

    Args:
        df: DataFrame containing features and targets, grouped by dataset (distinct from "split").
            Must provide the columns "dataset", "extra_info.activity_bin" and "extra_info.set".
        feature_metrics: List of metric prefixes to use as features
        model_params: Parameters for RandomForestClassifier
        labels_to_include: Labels to keep in the target variable; a falsy value disables filtering
        datasets: List of datasets to process (if None, process all datasets in df)

    Returns:
        Dictionary keyed by dataset name; each value holds the fitted "model",
        its "label_encoder" and the per-split "metrics". Datasets without a
        target column or without any feature column are skipped.
    """
    # Normalise to a fresh list so the (immutable) default is never shared or
    # mutated, and so the log message below renders the same as for list input.
    labels = list(labels_to_include) if labels_to_include else None

    # Fit a model for each dataset
    results_by_dataset = {}
    datasets_to_fit = datasets or df["dataset"].unique()
    for dataset in datasets_to_fit:
        ranked_logger.info(f"Processing dataset: {dataset}")

        # ... subset data for the current dataset; the index is reset so that it
        # is unique and rows can be re-aligned by label after cleaning
        dataset_df = df[df["dataset"] == dataset].reset_index(drop=True)
        assert len(dataset_df) > 0, f"No data found for dataset {dataset}!"

        # ... the target column must exist before it is used for filtering
        if "extra_info.activity_bin" not in dataset_df.columns:
            ranked_logger.warning(
                f"Target column 'extra_info.activity_bin' not found in dataset {dataset}. Skipping."
            )
            continue

        # ... filter out labels not in "labels_to_include"
        if labels:
            before_count = len(dataset_df)
            dataset_df = dataset_df[
                dataset_df["extra_info.activity_bin"].isin(labels)
            ]
            after_count = len(dataset_df)
            if before_count > after_count:
                ranked_logger.info(
                    f"Filtered out {before_count - after_count} samples not in {labels}. Remaining: {after_count}"
                )

        # Extract target
        y = dataset_df["extra_info.activity_bin"]

        # Extract features
        feature_cols = dataset_df.columns[
            dataset_df.columns.str.startswith(tuple(feature_metrics))
        ]
        if len(feature_cols) == 0:
            ranked_logger.warning(
                f"No feature columns found for dataset {dataset}. Skipping."
            )
            continue

        X = dataset_df[feature_cols].copy()

        # Preprocess data (may drop rows that contain missing values)
        X, y, label_encoder = clean_and_encode(X, y)

        # Look up the split of exactly the rows that survived preprocessing;
        # masks built from the unfiltered frame would have the wrong length
        # whenever rows were dropped.
        split_labels = dataset_df.loc[X.index, "extra_info.set"].to_numpy()
        train_mask = split_labels == "train"
        valid_mask = split_labels == "valid"
        test_mask = split_labels == "test"

        X_train, y_train = X[train_mask], y[train_mask]
        X_valid, y_valid = X[valid_mask], y[valid_mask]
        X_test, y_test = (
            (X[test_mask], y[test_mask]) if test_mask.any() else (None, None)
        )

        # Train model
        model = train_model(X_train, y_train, model_params)

        # Evaluate model
        metrics = evaluate_model_on_all_splits(
            model=model,
            X_train=X_train,
            y_train=y_train,
            X_valid=X_valid,
            y_valid=y_valid,
            X_test=X_test,
            y_test=y_test,
            label_encoder=label_encoder,
        )

        # Store results
        results_by_dataset[dataset] = {
            "model": model,
            "label_encoder": label_encoder,
            "metrics": metrics,
        }

    return results_by_dataset
|
||||
|
||||
|
||||
class FitAndEvaluateOnExperimentalDataCallback(BaseCallback):
    """Fit per-dataset classifiers on stored validation metrics and log their scores."""

    def __init__(
        self,
        feature_metrics: list[str],
        model_params: dict | None = None,
        datasets: list[str] | None = None,
    ):
        """Callback to fit and evaluate models on experimental data

        Args:
            feature_metrics: List of metric prefixes to use as input features
            model_params: Parameters for RandomForestClassifier
            datasets: List of datasets to process (if None, process all datasets in df)
        """
        super().__init__()
        self.feature_metrics = feature_metrics
        self.model_params = model_params
        self.datasets = datasets

    def on_validation_epoch_end(self, trainer: Any):
        """Fit and evaluate on the current epoch's validation results and log accuracy/AUC.

        Reads the CSV at ``trainer.validation_results_path``, keeps the rows of
        the current epoch, and logs one value per dataset, split and metric
        under ``val/exp/<dataset>/<split>/<metric>``. A ``ValueError`` raised
        while fitting, evaluating or logging is reported and swallowed.
        """
        # Only fit and evaluate on experimental data for the global zero rank
        if not trainer.fabric.is_global_zero:
            return

        # Check if validation results are available
        assert hasattr(
            trainer, "validation_results_path"
        ), "Results path not found! Ensure that StoreValidationMetricsInDFCallback is called first."

        # Load validation results
        df = pd.read_csv(trainer.validation_results_path)

        # Subset to current epoch
        current_epoch = trainer.state["current_epoch"]
        df = df[df["epoch"] == current_epoch]

        # Fit and evaluate models
        try:
            results = fit_and_evaluate(
                df=df,
                feature_metrics=self.feature_metrics,
                model_params=self.model_params,
                datasets=self.datasets,
            )

            # Log to Fabric
            for dataset, result in results.items():
                for split, metrics in result["metrics"].items():
                    for metric in ["accuracy", "auc"]:
                        # A split's metrics are empty when the "ACTIVE" class was
                        # not available; skip instead of raising a KeyError.
                        if metric not in metrics:
                            continue
                        trainer.fabric.log_dict(
                            {f"val/exp/{dataset}/{split}/{metric}": metrics[metric]},
                            step=trainer.state["current_epoch"],
                        )
        except ValueError as e:
            ranked_logger.error(
                f"Error during experimental model fitting/evaluation: {e}"
            )
|
||||
@@ -1,6 +1,7 @@
|
||||
import itertools
|
||||
from typing import List
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -18,7 +19,6 @@ from modelhub.metrics.metric_utils import (
|
||||
spread_batch_into_dictionary,
|
||||
unbin_logits,
|
||||
)
|
||||
import einops
|
||||
|
||||
|
||||
def get_mean_atomwise_plddt(
|
||||
@@ -36,21 +36,25 @@ def get_mean_atomwise_plddt(
|
||||
Returns:
|
||||
plddt: Tensor of shape [B,] with the mean atom-wise pLDDT for each batch
|
||||
"""
|
||||
assert plddt_logits.ndim == 3, "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)"
|
||||
assert (
|
||||
plddt_logits.ndim == 3
|
||||
), "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)"
|
||||
|
||||
# TODO: Replace with the last dimension of is_real_atom; right now that number is too large (36) because it includes hydrogens
|
||||
max_atoms_in_a_token = NHEAVY
|
||||
|
||||
# Since the pLDDT logits have the last dimension (max_atoms_in_a_token * n_bins), we can calculate n_bins directly
|
||||
assert plddt_logits.shape[-1] % max_atoms_in_a_token == 0, "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!"
|
||||
assert (
|
||||
plddt_logits.shape[-1] % max_atoms_in_a_token == 0
|
||||
), "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!"
|
||||
n_bins = plddt_logits.shape[-1] // max_atoms_in_a_token
|
||||
|
||||
# ... reshape to match what unbin_logits expects
|
||||
reshaped_plddt_logits = einops.rearrange(
|
||||
plddt_logits,
|
||||
'... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token',
|
||||
"... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token",
|
||||
max_atoms_in_a_token=max_atoms_in_a_token,
|
||||
n_bins=n_bins
|
||||
n_bins=n_bins,
|
||||
).float() # [..., n_token, n_bins * max_atoms_in_a_token] -> [ ..., n_bins, n_token, max_atoms_in_a_token]
|
||||
|
||||
plddt = unbin_logits(
|
||||
@@ -144,7 +148,7 @@ def compile_af3_confidence_outputs(
|
||||
plddt_chainwise = {
|
||||
k: spread_batch_into_dictionary(
|
||||
compute_mean_over_subsampled_pairs(
|
||||
plddt, is_real_atom[..., : NHEAVY] * v[:, None]
|
||||
plddt, is_real_atom[..., :NHEAVY] * v[:, None]
|
||||
)
|
||||
)
|
||||
for k, v in chain_masks_1d.items()
|
||||
@@ -153,7 +157,9 @@ def compile_af3_confidence_outputs(
|
||||
# Aggregate confidence data
|
||||
confidence_data = {
|
||||
"example_id": example_id,
|
||||
"mean_plddt": spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(plddt, is_real_atom[..., : NHEAVY])),
|
||||
"mean_plddt": spread_batch_into_dictionary(
|
||||
compute_mean_over_subsampled_pairs(plddt, is_real_atom[..., :NHEAVY])
|
||||
),
|
||||
"mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))),
|
||||
"mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))),
|
||||
"chain_wise_mean_plddt": plddt_chainwise,
|
||||
@@ -300,9 +306,9 @@ def compute_batch_indices_with_lowest_predicted_error(
|
||||
|
||||
complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
|
||||
complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
|
||||
complex_plddt = (
|
||||
plddt_logits_unbinned * is_real_atom[..., : NHEAVY]
|
||||
).sum(dim=(1, 2)) / is_real_atom[..., : NHEAVY].sum()
|
||||
complex_plddt = (plddt_logits_unbinned * is_real_atom[..., :NHEAVY]).sum(
|
||||
dim=(1, 2)
|
||||
) / is_real_atom[..., :NHEAVY].sum()
|
||||
|
||||
return_dict["pae_idx"] = torch.argmin(complex_pae)
|
||||
return_dict["pde_idx"] = torch.argmin(complex_pde)
|
||||
@@ -353,7 +359,7 @@ def annotate_atom_array_b_factor_with_plddt(
|
||||
because the AtomArray class does not support setting different values as annotations
|
||||
other than the coordinate feature.
|
||||
"""
|
||||
atom_wise_plddt = plddt[:, is_real_atom[..., : NHEAVY]]
|
||||
atom_wise_plddt = plddt[:, is_real_atom[..., :NHEAVY]]
|
||||
assert atom_wise_plddt.shape[1] == atom_array.array_length()
|
||||
atom_array_list = []
|
||||
# biotite's AtomArray does not support setting different values as annotations other than
|
||||
|
||||
Reference in New Issue
Block a user