mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Merge pull request #418 from aqlaboratory/seeding-fix
Fix distributed seeding behavior
This commit is contained in:
@@ -1,19 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import random
|
||||
import numpy as np
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
|
||||
from openfold.utils.suppress_output import SuppressLogging
|
||||
|
||||
|
||||
def seed_globally(seed=None):
|
||||
if("PL_GLOBAL_SEED" not in os.environ):
|
||||
if(seed is None):
|
||||
seed = random.randint(0, np.iinfo(np.uint32).max)
|
||||
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
||||
logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}')
|
||||
|
||||
# seed_everything is a bit log-happy
|
||||
with SuppressLogging(logging.INFO):
|
||||
seed_everything(seed=None)
|
||||
@@ -1,26 +0,0 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
class SuppressStdout:
|
||||
def __enter__(self):
|
||||
self.stdout = sys.stdout
|
||||
dev_null = open("/dev/null", "w")
|
||||
sys.stdout = dev_null
|
||||
|
||||
def __exit__(self, typ, value, traceback):
|
||||
fp = sys.stdout
|
||||
sys.stdout = self.stdout
|
||||
fp.close()
|
||||
|
||||
|
||||
class SuppressLogging:
|
||||
def __init__(self, level):
|
||||
self.level = level
|
||||
|
||||
def __enter__(self):
|
||||
logging.disable(self.level)
|
||||
|
||||
def __exit__(self, typ, value, traceback):
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@@ -8,6 +8,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
import torch
|
||||
|
||||
from openfold.config import model_config
|
||||
@@ -23,7 +24,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
|
||||
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
|
||||
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
|
||||
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
|
||||
from openfold.utils.seed import seed_everything
|
||||
from openfold.utils.superimposition import superimpose
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
from openfold.utils.validation_metrics import (
|
||||
@@ -272,7 +272,7 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
|
||||
def main(args):
|
||||
if(args.seed is not None):
|
||||
seed_everything(args.seed)
|
||||
seed_everything(args.seed, workers=True)
|
||||
|
||||
config = model_config(
|
||||
args.config_preset,
|
||||
|
||||
Reference in New Issue
Block a user