mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
changes required for pytorch2
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
name: openfold-venv
|
||||
name: pytorch1-plupgrade
|
||||
channels:
|
||||
- conda-forge
|
||||
- bioconda
|
||||
@@ -11,7 +11,7 @@ dependencies:
|
||||
- openmm=7.7
|
||||
- pdbfixer
|
||||
- cudatoolkit==11.3.*
|
||||
- pytorch-lightning==1.5.10
|
||||
- pytorch-lightning==2.0.9
|
||||
- biopython==1.79
|
||||
- numpy==1.21
|
||||
- pandas==2.0
|
||||
@@ -19,7 +19,7 @@ dependencies:
|
||||
- requests
|
||||
- scipy==1.7
|
||||
- tqdm==4.62.2
|
||||
- typing-extensions==3.10
|
||||
- typing-extensions==4.0
|
||||
- wandb==0.12.21
|
||||
- modelcif==0.7
|
||||
- awscli
|
||||
@@ -31,6 +31,7 @@ dependencies:
|
||||
- bioconda::kalign2==2.04
|
||||
- pytorch::pytorch=1.12.*
|
||||
- pip:
|
||||
- mpi4py==3.1.5
|
||||
- deepspeed==0.12.4
|
||||
- dm-tree==0.1.6
|
||||
- git+https://github.com/NVIDIA/dllogger.git
|
||||
|
||||
@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
|
||||
with open(distillation_alignment_index_path, "r") as fp:
|
||||
self.distillation_alignment_index = json.load(fp)
|
||||
|
||||
def setup(self):
|
||||
def setup(self, stage=None):
|
||||
# Most of the arguments are the same for the three datasets
|
||||
dataset_gen = partial(OpenFoldSingleDataset,
|
||||
template_mmcif_dir=self.template_mmcif_dir,
|
||||
@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
|
||||
mode="predict",
|
||||
)
|
||||
|
||||
def _gen_dataloader(self, stage):
|
||||
def _gen_dataloader(self, stage=None):
|
||||
generator = None
|
||||
if self.batch_seed is not None:
|
||||
generator = torch.Generator()
|
||||
@@ -1053,7 +1053,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
|
||||
def val_dataloader(self):
|
||||
if self.eval_dataset is not None:
|
||||
return self._gen_dataloader("eval")
|
||||
return None
|
||||
return []
|
||||
|
||||
def predict_dataloader(self):
|
||||
return self._gen_dataloader("predict")
|
||||
@@ -1085,7 +1085,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
|
||||
self.training_mode = self.train_data_dir is not None
|
||||
self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
|
||||
|
||||
def setup(self):
|
||||
def setup(self, setup=None):
|
||||
# Most of the arguments are the same for the three datasets
|
||||
dataset_gen = partial(OpenFoldSingleMultimerDataset,
|
||||
template_mmcif_dir=self.template_mmcif_dir,
|
||||
|
||||
@@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict:
|
||||
features["num_alignments"] = np.array(
|
||||
[num_alignments] * num_res, dtype=np.int32
|
||||
)
|
||||
features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_)
|
||||
features["msa_species_identifiers"] = np.array(species_ids, dtype=object)
|
||||
return features
|
||||
|
||||
|
||||
@@ -590,7 +590,7 @@ def convert_monomer_features(
|
||||
) -> FeatureDict:
|
||||
"""Reshapes and modifies monomer features for multimer models."""
|
||||
converted = {}
|
||||
converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_)
|
||||
converted['auth_chain_id'] = np.asarray(chain_id, dtype=object)
|
||||
unnecessary_leading_dim_feats = {
|
||||
'sequence', 'domain_name', 'num_alignments', 'seq_length'
|
||||
}
|
||||
@@ -1296,7 +1296,7 @@ class DataPipelineMultimer:
|
||||
)
|
||||
|
||||
mmcif_feats["release_date"] = np.array(
|
||||
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_
|
||||
[mmcif_object.header["release_date"].encode("utf-8")], dtype=object
|
||||
)
|
||||
|
||||
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
|
||||
|
||||
@@ -35,8 +35,8 @@ def _superimpose_np(reference, coords):
|
||||
|
||||
|
||||
def _superimpose_single(reference, coords):
|
||||
reference_np = reference.detach().cpu().numpy()
|
||||
coords_np = coords.detach().cpu().numpy()
|
||||
reference_np = reference.detach().to(torch.float).cpu().numpy()
|
||||
coords_np = coords.detach().to(torch.float).cpu().numpy()
|
||||
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
|
||||
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
|
||||
|
||||
|
||||
@@ -8,8 +8,11 @@ import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
|
||||
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
|
||||
from pytorch_lightning.plugins.environments import MPIEnvironment
|
||||
from pytorch_lightning import seed_everything
|
||||
import torch
|
||||
import wandb
|
||||
|
||||
from openfold.config import model_config
|
||||
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
|
||||
@@ -24,7 +27,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
|
||||
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
|
||||
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
|
||||
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
|
||||
from openfold.utils.seed import seed_everything
|
||||
from openfold.utils.superimposition import superimpose
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
from openfold.utils.validation_metrics import (
|
||||
@@ -59,7 +61,7 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
|
||||
self.cached_weights = None
|
||||
self.last_lr_step = -1
|
||||
self.save_hyperparameters
|
||||
self.save_hyperparameters()
|
||||
|
||||
def forward(self, batch):
|
||||
return self.model(batch)
|
||||
@@ -70,14 +72,15 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
self.log(
|
||||
f"{phase}/{loss_name}",
|
||||
indiv_loss,
|
||||
on_step=train, on_epoch=(not train), logger=True,
|
||||
prog_bar=(loss_name == 'loss'),
|
||||
on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
|
||||
)
|
||||
|
||||
if(train):
|
||||
self.log(
|
||||
f"{phase}/{loss_name}_epoch",
|
||||
indiv_loss,
|
||||
on_step=False, on_epoch=True, logger=True,
|
||||
on_step=False, on_epoch=True, logger=True, sync_dist=False,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -91,7 +94,8 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
self.log(
|
||||
f"{phase}/{k}",
|
||||
torch.mean(v),
|
||||
on_step=False, on_epoch=True, logger=True
|
||||
prog_bar = (k == 'loss'),
|
||||
on_step=False, on_epoch=True, logger=True, sync_dist=False,
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
@@ -154,7 +158,7 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
|
||||
self._log(loss_breakdown, batch, outputs, train=False)
|
||||
|
||||
def validation_epoch_end(self, _):
|
||||
def on_validation_epoch_end(self):
|
||||
# Restore the model weights to normal
|
||||
self.model.load_state_dict(self.cached_weights)
|
||||
self.cached_weights = None
|
||||
@@ -377,40 +381,59 @@ def main(args):
|
||||
callbacks.append(lr_monitor)
|
||||
|
||||
loggers = []
|
||||
is_rank_zero = int(os.environ.get("PMI_RANK")) == 0
|
||||
if(args.wandb):
|
||||
if args.mpi_plugin and is_rank_zero:
|
||||
wandb_init_dict = dict(
|
||||
name=args.experiment_name,
|
||||
project=args.wandb_project,
|
||||
id=args.wandb_id,
|
||||
dir=args.output_dir,
|
||||
resume="allow",
|
||||
anonymous=None,
|
||||
entity=args.wandb_entity
|
||||
)
|
||||
wandb.run = wandb.init(**wandb_init_dict)
|
||||
|
||||
wdb_logger = WandbLogger(
|
||||
name=args.experiment_name,
|
||||
save_dir=args.output_dir,
|
||||
id=args.wandb_id,
|
||||
project=args.wandb_project,
|
||||
config=config.to_dict(),
|
||||
**{"entity": args.wandb_entity}
|
||||
)
|
||||
loggers.append(wdb_logger)
|
||||
|
||||
cluster_environment = MPIEnvironment() if args.mpi_plugin else None
|
||||
if(args.deepspeed_config_path is not None):
|
||||
strategy = DeepSpeedPlugin(
|
||||
strategy = DeepSpeedStrategy(
|
||||
config=args.deepspeed_config_path,
|
||||
cluster_environment=cluster_environment,
|
||||
)
|
||||
if(args.wandb):
|
||||
if(args.wandb and is_rank_zero):
|
||||
wdb_logger.experiment.save(args.deepspeed_config_path)
|
||||
wdb_logger.experiment.save("openfold/config.py")
|
||||
elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
|
||||
strategy = DDPPlugin(find_unused_parameters=False)
|
||||
strategy = DDPStrategy(find_unused_parameters=False,
|
||||
cluster_environment=cluster_environment)
|
||||
else:
|
||||
strategy = None
|
||||
|
||||
if(args.wandb):
|
||||
if(args.wandb and is_rank_zero):
|
||||
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
|
||||
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
|
||||
wdb_logger.experiment.save(f"{freeze_path}")
|
||||
|
||||
trainer = pl.Trainer.from_argparse_args(
|
||||
args,
|
||||
trainer = pl.Trainer(
|
||||
num_nodes=args.num_nodes,
|
||||
devices=args.gpus,
|
||||
precision=args.precision,
|
||||
max_epochs=args.max_epochs,
|
||||
default_root_dir=args.output_dir,
|
||||
strategy=strategy,
|
||||
callbacks=callbacks,
|
||||
logger=loggers,
|
||||
profiler='simple',
|
||||
)
|
||||
|
||||
if(args.resume_model_weights_only):
|
||||
@@ -621,7 +644,16 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
|
||||
)
|
||||
parser = pl.Trainer.add_argparse_args(parser)
|
||||
parser.add_argument("--num_nodes", type=int, default=1)
|
||||
parser.add_argument("--gpus", type=int, default=None)
|
||||
parser.add_argument("--max_epochs", type=int, default=None)
|
||||
parser.add_argument("--precision", type=str, default="32")
|
||||
parser.add_argument("--log_every_n_steps", type=int, default=50)
|
||||
parser.add_argument("--accumulate_grad_batches", type=int, default=1)
|
||||
parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
|
||||
parser.add_argument("--num_sanity_val_steps", type=int, default=0)
|
||||
parser.add_argument("--mpi_plugin", action="store_true", default=False)
|
||||
# parser = pl.Trainer.add_argparse_args(parser)
|
||||
|
||||
# Disable the initial validation pass
|
||||
parser.set_defaults(
|
||||
|
||||
Reference in New Issue
Block a user