mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
clean: make pip installable, remove unused files, ruff, add license
This commit is contained in:
28
LICENSE.md
Normal file
28
LICENSE.md
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
BSD 3-Clause License
|
||||||
|
|
||||||
|
Copyright (c) 2025, Institute for Protein Design, University of Washington
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright notice, this
|
||||||
|
list of conditions and the following disclaimer.
|
||||||
|
|
||||||
|
* Redistributions in binary form must reproduce the above copyright notice,
|
||||||
|
this list of conditions and the following disclaimer in the documentation
|
||||||
|
and/or other materials provided with the distribution.
|
||||||
|
|
||||||
|
* Neither the name of the copyright holder nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||||
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||||
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||||
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||||
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||||
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||||
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||||
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||||
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||||
@@ -1,13 +0,0 @@
|
|||||||
name: modelhub
|
|
||||||
channels:
|
|
||||||
- nvidia/label/cuda-12.6.0
|
|
||||||
- conda-forge
|
|
||||||
- defaults
|
|
||||||
dependencies:
|
|
||||||
- pip
|
|
||||||
- python=3.12
|
|
||||||
- cuda
|
|
||||||
- pytorch=2.7
|
|
||||||
- openbabel=3.1.1
|
|
||||||
- pip:
|
|
||||||
- -r file:requirements.txt
|
|
||||||
Submodule lib/atomworks-dev updated: 4f020cf7f4...cd9e8b76d1
@@ -1,11 +1,51 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "modelhub"
|
name = "rf3"
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
description = "Base repository for models at the University of Washington's Institute for Protein Design"
|
description = "Open-source biomolecular structure prediction for all molecules of life."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">= 3.10"
|
requires-python = ">= 3.12"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Bakerlab", email = "" },
|
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
|
||||||
|
]
|
||||||
|
license = { file = "LICENSE.md" }
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
# ... ml tools
|
||||||
|
"torch>=2.2.0,<3",
|
||||||
|
"lightning>=2.4.0,<2.5",
|
||||||
|
"einops>=0.8.0,<1",
|
||||||
|
"einx>=0.1.0,<1",
|
||||||
|
"opt_einsum>=3.4.0,<4",
|
||||||
|
"dm-tree>=0.1.6,<1",
|
||||||
|
# ... kernels
|
||||||
|
"cuequivariance_ops_cu12>=0.5.0",
|
||||||
|
"cuequivariance_ops_torch_cu12>=0.5.0",
|
||||||
|
"cuequivariance_torch>=0.5.0",
|
||||||
|
# ... configuration & CLI
|
||||||
|
"hydra-core>=1.3.0,<1.4",
|
||||||
|
"environs>=11.0.0,<12",
|
||||||
|
# ... logging
|
||||||
|
"wandb>=0.15.10,<1",
|
||||||
|
"rich>=13.9.4,<14",
|
||||||
|
# ... typing & documentation
|
||||||
|
"jaxtyping>=0.2.17,<1",
|
||||||
|
"beartype>=0.18.0,<1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
# Linters & formatters
|
||||||
|
"ruff==0.8.3",
|
||||||
|
# Debugger/interactive
|
||||||
|
"debugpy>=1.8.5,<2",
|
||||||
|
"ipykernel>=6.29.4,<7",
|
||||||
|
# Testing tools
|
||||||
|
"pytest>=8.2.0,<9", # testing framework
|
||||||
|
"pytest-testmon>=2.1.1,<3", # run only tests related to changed code
|
||||||
|
"pytest-xdist>=3.6.1,<4", # run tests in parallel
|
||||||
|
"pytest-dotenv>=0.5.2,<1", # load environment variables from .env file
|
||||||
|
"pytest-cov>=4.1.0,<5", # generate coverage report
|
||||||
|
"pytest-benchmark>=5.0.0,<6", # benchmark tests for speed
|
||||||
]
|
]
|
||||||
|
|
||||||
# Build settings ----------------------------------------------------------------------
|
# Build settings ----------------------------------------------------------------------
|
||||||
@@ -65,7 +105,6 @@ exclude = [
|
|||||||
"*.ipynb",
|
"*.ipynb",
|
||||||
"dev.py",
|
"dev.py",
|
||||||
"archive",
|
"archive",
|
||||||
"src/modelhub/utils/predicted_error.py", # TEMPORARY
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
|
|||||||
@@ -1,77 +0,0 @@
|
|||||||
# Core ML dependencies
|
|
||||||
# NOTE: torch, torchvision, torchaudio are already included in NGC container
|
|
||||||
lightning>=2.4.0,<2.5
|
|
||||||
|
|
||||||
# Small molecule libraries
|
|
||||||
rdkit>=2024.3.5
|
|
||||||
# TODO: Remove OpenBabel dependency
|
|
||||||
# openbabel will be installed via apt-get (pip installation fails due to C++ build dependencies)
|
|
||||||
|
|
||||||
# Project-related dependencies
|
|
||||||
# ... generic tools
|
|
||||||
GitPython>=3.0.0,<4
|
|
||||||
cython>=3.0.0,<4
|
|
||||||
cytoolz>=0.12.3,<1
|
|
||||||
assertpy>=1.1.0,<2 # TODO: remove this dependency
|
|
||||||
tqdm>=4.65.0,<5
|
|
||||||
rootutils>=1.0.7,<1.1
|
|
||||||
dm-tree>=0.1.6,<1 # TODO: remove this dependency
|
|
||||||
deepdiff>=8.0.0,<9 # TODO: remove this dependency
|
|
||||||
|
|
||||||
# ... configuration & CLI
|
|
||||||
fire>=0.6.0,<1
|
|
||||||
hydra-core>=1.3.0,<1.4
|
|
||||||
environs>=11.0.0,<12
|
|
||||||
|
|
||||||
# ... linear algebra, maths & ml
|
|
||||||
numpy>=1.25.0,<2
|
|
||||||
scipy>=1.13.1,<2
|
|
||||||
einops>=0.8.0,<1
|
|
||||||
einx>=0.1.0,<1
|
|
||||||
opt_einsum>=3.4.0,<4
|
|
||||||
scikit-learn>=1.6.1,<2
|
|
||||||
|
|
||||||
# ... kernels
|
|
||||||
cuequivariance_ops_cu12>=0.5.0
|
|
||||||
cuequivariance_ops_torch_cu12>=0.5.0
|
|
||||||
cuequivariance_torch>=0.5.0
|
|
||||||
|
|
||||||
# ... data tools
|
|
||||||
pandas>=2.2,<2.3
|
|
||||||
pyarrow>=17.0.0
|
|
||||||
fastparquet>=2024.5.0
|
|
||||||
seaborn>=0.13.0,<1
|
|
||||||
|
|
||||||
# ... bioinformatics
|
|
||||||
biopython>=1.83,<2
|
|
||||||
py3Dmol>=2.2.1,<3
|
|
||||||
pymol-remote>=0.0.5
|
|
||||||
biotite==1.3.0 # Fixed version - updating may involve breaking changes
|
|
||||||
hydride==1.2.3 # Fixed version - updating may involve breaking changes
|
|
||||||
|
|
||||||
# ... logging
|
|
||||||
wandb>=0.15.10,<1
|
|
||||||
rich>=13.9.4,<14
|
|
||||||
|
|
||||||
# Formatting & linting (only needed for development)
|
|
||||||
ruff==0.8.3
|
|
||||||
pre-commit==3.7.1
|
|
||||||
|
|
||||||
# Debugger & interactive tools (only needed for development)
|
|
||||||
debugpy>=1.8.5,<2
|
|
||||||
ipykernel>=6.29.4,<7
|
|
||||||
icecream>=2.0.0,<3
|
|
||||||
pymol-remote>=0.1.0
|
|
||||||
ipdb>=0.13.9
|
|
||||||
|
|
||||||
# Pytest plugins (only needed for development)
|
|
||||||
pytest>=8.2.0,<9
|
|
||||||
pytest-testmon>=2.1.1,<3
|
|
||||||
pytest-xdist>=3.6.1,<4
|
|
||||||
pytest-dotenv>=0.5.2,<1
|
|
||||||
pytest-cov>=4.1.0,<5
|
|
||||||
pytest-benchmark>=5.0.0,<6
|
|
||||||
|
|
||||||
# Typing & documentation (only needed for development)
|
|
||||||
jaxtyping>=0.2.17,<1
|
|
||||||
beartype>=0.18.0,<1
|
|
||||||
@@ -1,345 +0,0 @@
|
|||||||
import warnings
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from beartype.typing import Any
|
|
||||||
from sklearn.ensemble import RandomForestClassifier
|
|
||||||
from sklearn.metrics import (
|
|
||||||
accuracy_score,
|
|
||||||
confusion_matrix,
|
|
||||||
roc_auc_score,
|
|
||||||
)
|
|
||||||
from sklearn.preprocessing import LabelEncoder
|
|
||||||
|
|
||||||
from modelhub.callbacks.base import BaseCallback
|
|
||||||
from modelhub.utils.ddp import RankedLogger
|
|
||||||
from modelhub.utils.logging import print_df_as_table
|
|
||||||
|
|
||||||
# Suppress warnings for cleaner output
|
|
||||||
warnings.filterwarnings("ignore")
|
|
||||||
|
|
||||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
|
||||||
|
|
||||||
|
|
||||||
def clean_and_encode(
    X: pd.DataFrame, y: pd.Series
) -> tuple[pd.DataFrame, np.ndarray, LabelEncoder]:
    """Coerce features to numeric, drop incomplete rows, and label-encode the target.

    The input objects are left untouched; a new feature frame is returned.

    Args:
        X: Feature matrix. Values that cannot be parsed as numbers become NaN.
        y: Target vector, indexed with the same labels as ``X``.

    Returns:
        Processed features (rows containing missing values removed), the encoded
        target for the remaining rows, and the fitted label encoder.
    """
    # Build a new frame rather than assigning column-by-column: fit_and_evaluate
    # passes a column subset of a larger DataFrame, and writing into it would
    # modify (or warn about modifying) the caller's data.
    X = X.apply(pd.to_numeric, errors="coerce")

    # Rows with any missing feature are removed, and the target is re-aligned
    # to the surviving rows via their index labels.
    missing_count = X.isnull().sum().sum()
    if missing_count > 0:
        ranked_logger.warning(
            f"Found {missing_count} missing values in feature matrix. Dropping rows."
        )
        X = X.dropna()
        y = y.loc[X.index]

    # Encode target variable
    label_encoder = LabelEncoder()
    y_encoded = label_encoder.fit_transform(y)

    return X, y_encoded, label_encoder
|
|
||||||
|
|
||||||
|
|
||||||
def train_model(
    X_train: pd.DataFrame, y_train: np.ndarray, model_params: dict | None = None
) -> RandomForestClassifier:
    """Fit a Random Forest classifier on the training split.

    Args:
        X_train: Training features
        y_train: Training target
        model_params: Keyword arguments for RandomForestClassifier; entries given
            here take precedence over the built-in defaults.

    Returns:
        The fitted Random Forest classifier
    """
    # Start from the defaults, then let caller-supplied settings override them
    hyperparameters: dict = dict(
        n_estimators=20,
        max_depth=2,
        random_state=42,
        class_weight="balanced",
    )
    if model_params:
        hyperparameters.update(model_params)

    classifier = RandomForestClassifier(**hyperparameters)
    classifier.fit(X_train, y_train)
    return classifier
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
    model: RandomForestClassifier,
    X: pd.DataFrame,
    y: np.ndarray,
    label_encoder: LabelEncoder,
    split_name: str,
) -> dict[str, Any]:
    """Evaluate model on a dataset split and generate metrics from predictions.

    Prints a metrics summary table and a confusion matrix for the split.

    Args:
        model: Trained classifier
        X: Feature matrix
        y: Encoded target vector (as produced by ``label_encoder``)
        label_encoder: Fitted label encoder
        split_name: Dataset split (train, validation, test)

    Returns:
        Dictionary with keys ``"accuracy"``, ``"auc"`` and ``"confusion_matrix"``,
        or an empty dictionary when the "ACTIVE" class is unknown to the label
        encoder or was not seen by the model during fitting.
    """
    # Resolve the "ACTIVE" class first so nothing is predicted when metrics
    # cannot be computed anyway.
    class_names = list(label_encoder.classes_)
    if "ACTIVE" not in class_names:
        ranked_logger.warning(
            "Class 'ACTIVE' not found in label encoder. Cannot compute experimental metrics."
        )
        return {}

    active_class_index = class_names.index("ACTIVE")

    # predict_proba columns follow model.classes_, which only holds the encoded
    # labels present in the training data -- not necessarily every class the
    # label encoder knows. Look the column up instead of assuming the positions match.
    active_columns = np.flatnonzero(np.asarray(model.classes_) == active_class_index)
    if active_columns.size == 0:
        ranked_logger.warning(
            "Class 'ACTIVE' was not seen by the model during training. Cannot compute experimental metrics."
        )
        return {}

    # Generate predictions and probabilities
    y_pred = model.predict(X)
    y_proba = model.predict_proba(X)

    # Calculate metrics ("ACTIVE" vs. rest for the AUC)
    accuracy = accuracy_score(y, y_pred)
    auc = roc_auc_score(y == active_class_index, y_proba[:, active_columns[0]])
    # Pin the label set so the matrix always has one row/column per encoder
    # class, even when this split is missing some of them.
    conf_matrix = confusion_matrix(y, y_pred, labels=np.arange(len(class_names)))

    # Create and print metrics summary
    metrics_summary = pd.DataFrame(
        {"Metric": ["Accuracy", "ROC AUC"], "Value": [accuracy, auc]}
    )
    print_df_as_table(metrics_summary, title=f"{split_name.upper()} SET METRICS")

    # Print confusion matrix (rows: true class, columns: predicted class)
    conf_df = pd.DataFrame(
        conf_matrix, index=label_encoder.classes_, columns=label_encoder.classes_
    )
    conf_df.index.name = "True"
    conf_df.columns.name = "Predicted"
    print_df_as_table(conf_df, title=f"{split_name.upper()} CONFUSION MATRIX")

    return {
        "accuracy": accuracy,
        "auc": auc,
        "confusion_matrix": conf_matrix,
    }
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_model_on_all_splits(
    model: RandomForestClassifier,
    X_train: pd.DataFrame,
    y_train: np.ndarray,
    X_valid: pd.DataFrame,
    y_valid: np.ndarray,
    label_encoder: LabelEncoder,
    X_test: pd.DataFrame | None = None,
    y_test: np.ndarray | None = None,
) -> dict[str, dict[str, Any]]:
    """Run ``evaluate`` on the train and validation splits, plus test when supplied.

    Args:
        model: Trained classifier
        X_train: Training features
        y_train: Training target
        X_valid: Validation features
        y_valid: Validation target
        label_encoder: Original label encoder
        X_test: Test features (optional)
        y_test: Test target (optional)

    Returns:
        Mapping from split name ("train", "validation", and optionally "test")
        to that split's evaluation metrics
    """
    # Train and validation are always evaluated; test only when both parts exist
    splits = {"train": (X_train, y_train), "validation": (X_valid, y_valid)}
    if not (X_test is None or y_test is None):
        splits["test"] = (X_test, y_test)

    return {
        split_name: evaluate(model, features, targets, label_encoder, split_name)
        for split_name, (features, targets) in splits.items()
    }
|
|
||||||
|
|
||||||
|
|
||||||
def fit_and_evaluate(
    df: pd.DataFrame,
    feature_metrics: list[str],
    model_params: dict | None = None,
    labels_to_include: list[str] | tuple[str, ...] | None = ("ACTIVE", "INACTIVE"),
    datasets: list[str] | None = None,
) -> dict:
    """Train models for each dataset and evaluate classification performance across train, validation, and test splits.

    Reads the columns ``dataset``, ``extra_info.activity_bin`` (target) and
    ``extra_info.set`` (split: "train", "valid" or "test") from ``df``.

    Args:
        df: DataFrame containing features and targets, grouped by dataset (distinct from "split")
        feature_metrics: List of metric prefixes to use as features
        model_params: Parameters for RandomForestClassifier
        labels_to_include: Labels to keep in the target variable; ``None`` or an
            empty sequence disables the filtering
        datasets: List of datasets to process (if None, process all datasets in df)

    Returns:
        Dictionary keyed by dataset name; each value holds the fitted ``"model"``,
        its ``"label_encoder"`` and the per-split ``"metrics"``. Datasets without
        a target column or without matching feature columns are skipped.

    Raises:
        AssertionError: If a requested dataset has no rows in ``df``.
    """
    target_col = "extra_info.activity_bin"
    split_col = "extra_info.set"
    # Normalise to a list once (the default is an immutable tuple so that no
    # mutable object is shared between calls).
    labels = list(labels_to_include) if labels_to_include else None

    # Fit a model for each dataset
    results_by_dataset = {}
    datasets_to_fit = datasets or df["dataset"].unique()
    for dataset in datasets_to_fit:
        ranked_logger.info(f"Processing dataset: {dataset}")

        # ... subset data for the current dataset
        dataset_df = df[df["dataset"] == dataset].copy()
        assert len(dataset_df) > 0, f"No data found for dataset {dataset}!"

        # ... make sure the target exists before anything below touches it
        if target_col not in dataset_df.columns:
            ranked_logger.warning(
                f"Target column 'extra_info.activity_bin' not found in dataset {dataset}. Skipping."
            )
            continue

        # ... filter out labels not in "labels_to_include"
        if labels:
            before_count = len(dataset_df)
            dataset_df = dataset_df[dataset_df[target_col].isin(labels)]
            after_count = len(dataset_df)
            if before_count > after_count:
                ranked_logger.info(
                    f"Filtered out {before_count - after_count} samples not in {labels}. Remaining: {after_count}"
                )

        # Extract target
        y = dataset_df[target_col]

        # Extract features
        feature_cols = dataset_df.columns[
            dataset_df.columns.str.startswith(tuple(feature_metrics))
        ]
        if len(feature_cols) == 0:
            ranked_logger.warning(
                f"No feature columns found for dataset {dataset}. Skipping."
            )
            continue

        X = dataset_df[feature_cols]

        # Preprocess data
        X, y, label_encoder = clean_and_encode(X, y)

        # clean_and_encode may drop rows with missing features, so look up the
        # split assignment for the surviving rows only; masks built from the
        # full dataset_df would no longer line up with X and y.
        split_labels = dataset_df.loc[X.index, split_col]
        train_mask = (split_labels == "train").to_numpy()
        valid_mask = (split_labels == "valid").to_numpy()
        test_mask = (split_labels == "test").to_numpy()

        X_train, y_train = X[train_mask], y[train_mask]
        X_valid, y_valid = X[valid_mask], y[valid_mask]
        X_test, y_test = (
            (X[test_mask], y[test_mask]) if test_mask.any() else (None, None)
        )

        # Train model
        model = train_model(X_train, y_train, model_params)

        # Evaluate model
        metrics = evaluate_model_on_all_splits(
            model=model,
            X_train=X_train,
            y_train=y_train,
            X_valid=X_valid,
            y_valid=y_valid,
            X_test=X_test,
            y_test=y_test,
            label_encoder=label_encoder,
        )

        # Store results
        results_by_dataset[dataset] = {
            "model": model,
            "label_encoder": label_encoder,
            "metrics": metrics,
        }

    return results_by_dataset
|
|
||||||
|
|
||||||
|
|
||||||
class FitAndEvaluateOnExperimentalDataCallback(BaseCallback):
    """Fit per-dataset classifiers on the stored validation results and log their metrics."""

    def __init__(
        self,
        feature_metrics: list[str],
        model_params: dict | None = None,
        datasets: list[str] | None = None,
    ):
        """Callback to fit and evaluate models on experimental data

        Args:
            feature_metrics: List of metric prefixes to use as input features
            model_params: Parameters for RandomForestClassifier
            datasets: List of datasets to process (if None, process all datasets in df)
        """
        super().__init__()
        self.feature_metrics = feature_metrics
        self.model_params = model_params
        self.datasets = datasets

    def on_validation_epoch_end(self, trainer: Any):
        """Fit and evaluate on the current epoch's validation results and log accuracy/AUC.

        Reads the CSV at ``trainer.validation_results_path``, keeps the rows of
        the current epoch, and logs ``val/exp/{dataset}/{split}/{metric}`` through
        ``trainer.fabric``. A ``ValueError`` raised while fitting or evaluating is
        logged as an error instead of being propagated.

        Args:
            trainer: Object exposing ``fabric``, ``state`` and ``validation_results_path``
        """
        # Only fit and evaluate on experimental data for the global zero rank
        if not trainer.fabric.is_global_zero:
            return

        # Check if validation results are available
        assert hasattr(
            trainer, "validation_results_path"
        ), "Results path not found! Ensure that StoreValidationMetricsInDFCallback is called first."

        # Load validation results
        df = pd.read_csv(trainer.validation_results_path)

        # Subset to current epoch
        current_epoch = trainer.state["current_epoch"]
        df = df[df["epoch"] == current_epoch]

        # Fit and evaluate models
        try:
            results = fit_and_evaluate(
                df=df,
                feature_metrics=self.feature_metrics,
                model_params=self.model_params,
                datasets=self.datasets,
            )

            # Log to Fabric
            for dataset, result in results.items():
                for split, metrics in result["metrics"].items():
                    # evaluate() returns an empty dict when metrics could not
                    # be computed for a split; there is nothing to log then.
                    if not metrics:
                        continue
                    for metric in ["accuracy", "auc"]:
                        trainer.fabric.log_dict(
                            {f"val/exp/{dataset}/{split}/{metric}": metrics[metric]},
                            step=trainer.state["current_epoch"],
                        )
        except ValueError as e:
            ranked_logger.error(
                f"Error during experimental model fitting/evaluation: {e}"
            )
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@@ -18,7 +19,6 @@ from modelhub.metrics.metric_utils import (
|
|||||||
spread_batch_into_dictionary,
|
spread_batch_into_dictionary,
|
||||||
unbin_logits,
|
unbin_logits,
|
||||||
)
|
)
|
||||||
import einops
|
|
||||||
|
|
||||||
|
|
||||||
def get_mean_atomwise_plddt(
|
def get_mean_atomwise_plddt(
|
||||||
@@ -36,22 +36,26 @@ def get_mean_atomwise_plddt(
|
|||||||
Returns:
|
Returns:
|
||||||
plddt: Tensor of shape [B,] with the mean atom-wise pLDDT for each batch
|
plddt: Tensor of shape [B,] with the mean atom-wise pLDDT for each batch
|
||||||
"""
|
"""
|
||||||
assert plddt_logits.ndim == 3, "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)"
|
assert (
|
||||||
|
plddt_logits.ndim == 3
|
||||||
|
), "plddt_logits must be a 3D tensor (B, n_token, max_atoms_in_a_token * n_bins)"
|
||||||
|
|
||||||
# TODO: Replace with the last dimension of is_real_atom; right now that number is too large (36) because it includes hydrogens
|
# TODO: Replace with the last dimension of is_real_atom; right now that number is too large (36) because it includes hydrogens
|
||||||
max_atoms_in_a_token = NHEAVY
|
max_atoms_in_a_token = NHEAVY
|
||||||
|
|
||||||
# Since the pLDDT logits have the last dimension (max_atoms_in_a_token * n_bins), we can calculate n_bins directly
|
# Since the pLDDT logits have the last dimension (max_atoms_in_a_token * n_bins), we can calculate n_bins directly
|
||||||
assert plddt_logits.shape[-1] % max_atoms_in_a_token == 0, "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!"
|
assert (
|
||||||
|
plddt_logits.shape[-1] % max_atoms_in_a_token == 0
|
||||||
|
), "The last dimension of plddt_logits must be divisible by max_atoms_in_a_token!"
|
||||||
n_bins = plddt_logits.shape[-1] // max_atoms_in_a_token
|
n_bins = plddt_logits.shape[-1] // max_atoms_in_a_token
|
||||||
|
|
||||||
# ... reshape to match what unbin_logits expects
|
# ... reshape to match what unbin_logits expects
|
||||||
reshaped_plddt_logits = einops.rearrange(
|
reshaped_plddt_logits = einops.rearrange(
|
||||||
plddt_logits,
|
plddt_logits,
|
||||||
'... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token',
|
"... n_token (max_atoms_in_a_token n_bins) -> ... n_bins n_token max_atoms_in_a_token",
|
||||||
max_atoms_in_a_token=max_atoms_in_a_token,
|
max_atoms_in_a_token=max_atoms_in_a_token,
|
||||||
n_bins=n_bins
|
n_bins=n_bins,
|
||||||
).float() # [..., n_token, n_bins * max_atoms_in_a_token] -> [ ..., n_bins, n_token, max_atoms_in_a_token]
|
).float() # [..., n_token, n_bins * max_atoms_in_a_token] -> [ ..., n_bins, n_token, max_atoms_in_a_token]
|
||||||
|
|
||||||
plddt = unbin_logits(
|
plddt = unbin_logits(
|
||||||
reshaped_plddt_logits,
|
reshaped_plddt_logits,
|
||||||
@@ -144,7 +148,7 @@ def compile_af3_confidence_outputs(
|
|||||||
plddt_chainwise = {
|
plddt_chainwise = {
|
||||||
k: spread_batch_into_dictionary(
|
k: spread_batch_into_dictionary(
|
||||||
compute_mean_over_subsampled_pairs(
|
compute_mean_over_subsampled_pairs(
|
||||||
plddt, is_real_atom[..., : NHEAVY] * v[:, None]
|
plddt, is_real_atom[..., :NHEAVY] * v[:, None]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for k, v in chain_masks_1d.items()
|
for k, v in chain_masks_1d.items()
|
||||||
@@ -153,7 +157,9 @@ def compile_af3_confidence_outputs(
|
|||||||
# Aggregate confidence data
|
# Aggregate confidence data
|
||||||
confidence_data = {
|
confidence_data = {
|
||||||
"example_id": example_id,
|
"example_id": example_id,
|
||||||
"mean_plddt": spread_batch_into_dictionary(compute_mean_over_subsampled_pairs(plddt, is_real_atom[..., : NHEAVY])),
|
"mean_plddt": spread_batch_into_dictionary(
|
||||||
|
compute_mean_over_subsampled_pairs(plddt, is_real_atom[..., :NHEAVY])
|
||||||
|
),
|
||||||
"mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))),
|
"mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))),
|
||||||
"mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))),
|
"mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))),
|
||||||
"chain_wise_mean_plddt": plddt_chainwise,
|
"chain_wise_mean_plddt": plddt_chainwise,
|
||||||
@@ -300,9 +306,9 @@ def compute_batch_indices_with_lowest_predicted_error(
|
|||||||
|
|
||||||
complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
|
complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
|
||||||
complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
|
complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
|
||||||
complex_plddt = (
|
complex_plddt = (plddt_logits_unbinned * is_real_atom[..., :NHEAVY]).sum(
|
||||||
plddt_logits_unbinned * is_real_atom[..., : NHEAVY]
|
dim=(1, 2)
|
||||||
).sum(dim=(1, 2)) / is_real_atom[..., : NHEAVY].sum()
|
) / is_real_atom[..., :NHEAVY].sum()
|
||||||
|
|
||||||
return_dict["pae_idx"] = torch.argmin(complex_pae)
|
return_dict["pae_idx"] = torch.argmin(complex_pae)
|
||||||
return_dict["pde_idx"] = torch.argmin(complex_pde)
|
return_dict["pde_idx"] = torch.argmin(complex_pde)
|
||||||
@@ -353,7 +359,7 @@ def annotate_atom_array_b_factor_with_plddt(
|
|||||||
because the AtomArray class does not support setting different values as annotations
|
because the AtomArray class does not support setting different values as annotations
|
||||||
other than the coordinate feature.
|
other than the coordinate feature.
|
||||||
"""
|
"""
|
||||||
atom_wise_plddt = plddt[:, is_real_atom[..., : NHEAVY]]
|
atom_wise_plddt = plddt[:, is_real_atom[..., :NHEAVY]]
|
||||||
assert atom_wise_plddt.shape[1] == atom_array.array_length()
|
assert atom_wise_plddt.shape[1] == atom_array.array_length()
|
||||||
atom_array_list = []
|
atom_array_list = []
|
||||||
# biotite's AtomArray does not support setting different values as annotations other than
|
# biotite's AtomArray does not support setting different values as annotations other than
|
||||||
|
|||||||
Reference in New Issue
Block a user