From 2ed8115194490b63ffc2f254e3ade11b97ff94b3 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Fri, 12 Jan 2024 04:27:30 -0500 Subject: [PATCH 01/34] initial compatibility changes for upgrading multimer --- .gitignore | 2 +- environment.yml | 19 ++++++++++--------- openfold/data/data_pipeline.py | 12 ++++++------ openfold/data/templates.py | 4 ++-- openfold/model/primitives.py | 4 ++-- setup.py | 2 +- 6 files changed, 22 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 25fa357..f65a3e5 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ dist data openfold/resources/ tests/test_data/ - +cutlass diff --git a/environment.yml b/environment.yml index 7b73b25..d6ccb46 100644 --- a/environment.yml +++ b/environment.yml @@ -3,6 +3,7 @@ channels: - conda-forge - bioconda - pytorch + - nvidia dependencies: - python=3.9 - libgcc=7.2 @@ -10,17 +11,16 @@ dependencies: - pip - openmm=7.7 - pdbfixer - - cudatoolkit==11.3.* - - pytorch-lightning==1.5.10 + - pytorch-lightning - biopython==1.79 - - numpy==1.21 - - pandas==2.0 + - numpy + - pandas - PyYAML==5.4.1 - requests - - scipy==1.7 + - scipy - tqdm==4.62.2 - - typing-extensions==3.10 - - wandb==0.12.21 + - typing-extensions + - wandb - modelcif==0.7 - awscli - ml-collections @@ -29,9 +29,10 @@ dependencies: - bioconda::hmmer==3.3.2 - bioconda::hhsuite==3.3.0 - bioconda::kalign2==2.04 - - pytorch::pytorch=1.12.* + - pytorch::pytorch=2.1 + - pytorch::pytorch-cuda=12.1 - pip: - deepspeed==0.12.4 - dm-tree==0.1.6 - git+https://github.com/NVIDIA/dllogger.git - - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 + - flash-attn diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index 6474d28..6eb3d8e 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -110,12 +110,12 @@ def make_sequence_features( ) features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32) features["domain_name"] = np.array( - [description.encode("utf-8")], dtype=np.object_ + [description.encode("utf-8")], dtype=object ) features["residue_index"] = np.array(range(num_res), dtype=np.int32) features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32) features["sequence"] = np.array( - [sequence.encode("utf-8")], dtype=np.object_ + [sequence.encode("utf-8")], dtype=object ) return features @@ -148,7 +148,7 @@ def make_mmcif_features( ) mmcif_feats["release_date"] = np.array( - [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ + [mmcif_object.header["release_date"].encode("utf-8")], dtype=object ) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) @@ -247,7 +247,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: features["num_alignments"] = np.array( [num_alignments] * num_res, dtype=np.int32 ) - features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_) + features["msa_species_identifiers"] = np.array(species_ids, dtype=object) return features @@ -593,7 +593,7 @@ def convert_monomer_features( ) -> FeatureDict: """Reshapes and modifies monomer features for multimer models.""" converted = {} - converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + converted['auth_chain_id'] = np.asarray(chain_id, dtype=object) unnecessary_leading_dim_feats = { 'sequence', 'domain_name', 'num_alignments', 'seq_length' } @@ -1290,7 +1290,7 @@ class DataPipelineMultimer: ) mmcif_feats["release_date"] = np.array( - [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ + [mmcif_object.header["release_date"].encode("utf-8")], dtype=object ) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) diff --git a/openfold/data/templates.py b/openfold/data/templates.py index af6d37a..4c13e7f 100644 --- a/openfold/data/templates.py +++ b/openfold/data/templates.py @@ -83,8 +83,8 @@ TEMPLATE_FEATURES = { "template_aatype": np.int64, "template_all_atom_mask": np.float32, "template_all_atom_positions": np.float32, - "template_domain_names": np.object, - "template_sequence": np.object, + "template_domain_names": object, + "template_sequence": object, "template_sum_probs": np.float32, } diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index 8fb749a..00d9124 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -28,7 +28,7 @@ if ds4s_is_installed: fa_is_installed = importlib.util.find_spec("flash_attn") is not None if fa_is_installed: from flash_attn.bert_padding import unpad_input - from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func import torch import torch.nn as nn @@ -809,7 +809,7 @@ def _flash_attn(q, k, v, kv_mask): kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) - out = flash_attn_unpadded_kvpacked_func( + out = flash_attn_varlen_kvpacked_func( q, kv_unpad, q_cu_seqlens, diff --git a/setup.py b/setup.py index dbee0a5..5876271 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ version_dependent_macros = [ ] extra_cuda_flags = [ - '-std=c++14', + '-std=c++17', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', From 201eafdf5485ccfafff51a5ea59a5644001ab9ed Mon Sep 17 00:00:00 2001 From: Jennifer Date: Tue, 16 Jan 2024 03:52:57 -0500 Subject: [PATCH 02/34] np type update in openfold.np.relax --- openfold/np/relax/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfold/np/relax/utils.py b/openfold/np/relax/utils.py index fc19a91..422d21b 100644 --- a/openfold/np/relax/utils.py +++ b/openfold/np/relax/utils.py @@ -79,7 +79,7 @@ def assert_equal_nonterminal_atom_types( """Checks that pre- and post-minimized proteins have same atom set.""" # Ignore any terminal OXT atoms which may have been added by minimization. oxt = residue_constants.atom_order["OXT"] - no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool) + no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool) no_oxt_mask[..., oxt] = False np.testing.assert_almost_equal( ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask] From e71c1b1450407c758a09a68e3151a32f5bace6b6 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Fri, 12 Jan 2024 04:27:30 -0500 Subject: [PATCH 03/34] initial compatibility changes for upgrading multimer --- .gitignore | 2 +- environment.yml | 19 ++++++++++--------- openfold/data/data_pipeline.py | 12 ++++++------ openfold/data/templates.py | 4 ++-- openfold/model/primitives.py | 4 ++-- setup.py | 2 +- 6 files changed, 22 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 25fa357..f65a3e5 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ dist data openfold/resources/ tests/test_data/ - +cutlass diff --git a/environment.yml b/environment.yml index 7b73b25..d6ccb46 100644 --- a/environment.yml +++ b/environment.yml @@ -3,6 +3,7 @@ channels: - conda-forge - bioconda - pytorch + - nvidia dependencies: - python=3.9 - libgcc=7.2 @@ -10,17 +11,16 @@ dependencies: - pip - openmm=7.7 - pdbfixer - - cudatoolkit==11.3.* - - pytorch-lightning==1.5.10 + - pytorch-lightning - biopython==1.79 - - numpy==1.21 - - pandas==2.0 + - numpy + - pandas - PyYAML==5.4.1 - requests - - scipy==1.7 + - scipy - tqdm==4.62.2 - - typing-extensions==3.10 - - wandb==0.12.21 + - typing-extensions + - wandb - modelcif==0.7 - awscli - ml-collections @@ -29,9 +29,10 @@ dependencies: - bioconda::hmmer==3.3.2 - bioconda::hhsuite==3.3.0 - bioconda::kalign2==2.04 - - pytorch::pytorch=1.12.* + - pytorch::pytorch=2.1 + - pytorch::pytorch-cuda=12.1 - pip: - deepspeed==0.12.4 - dm-tree==0.1.6 - git+https://github.com/NVIDIA/dllogger.git - - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 + - flash-attn diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index 6474d28..6eb3d8e 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -110,12 +110,12 @@ def make_sequence_features( ) features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32) features["domain_name"] = np.array( - [description.encode("utf-8")], dtype=np.object_ + [description.encode("utf-8")], dtype=object ) features["residue_index"] = np.array(range(num_res), dtype=np.int32) features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32) features["sequence"] = np.array( - [sequence.encode("utf-8")], dtype=np.object_ + [sequence.encode("utf-8")], dtype=object ) return features @@ -148,7 +148,7 @@ def make_mmcif_features( ) mmcif_feats["release_date"] = np.array( - [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ + [mmcif_object.header["release_date"].encode("utf-8")], dtype=object ) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) @@ -247,7 +247,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: features["num_alignments"] = np.array( [num_alignments] * num_res, dtype=np.int32 ) - features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_) + features["msa_species_identifiers"] = np.array(species_ids, dtype=object) return features @@ -593,7 +593,7 @@ def convert_monomer_features( ) -> FeatureDict: """Reshapes and modifies monomer features for multimer models.""" converted = {} - converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + converted['auth_chain_id'] = np.asarray(chain_id, dtype=object) unnecessary_leading_dim_feats = { 'sequence', 'domain_name', 'num_alignments', 'seq_length' } @@ -1290,7 +1290,7 @@ class DataPipelineMultimer: ) mmcif_feats["release_date"] = np.array( - [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ + [mmcif_object.header["release_date"].encode("utf-8")], dtype=object ) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) diff --git a/openfold/data/templates.py b/openfold/data/templates.py index af6d37a..4c13e7f 100644 --- a/openfold/data/templates.py +++ b/openfold/data/templates.py @@ -83,8 +83,8 @@ TEMPLATE_FEATURES = { "template_aatype": np.int64, "template_all_atom_mask": np.float32, "template_all_atom_positions": np.float32, - "template_domain_names": np.object, - "template_sequence": np.object, + "template_domain_names": object, + "template_sequence": object, "template_sum_probs": np.float32, } diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index ea38cb3..e5735d1 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -28,7 +28,7 @@ if ds4s_is_installed: fa_is_installed = importlib.util.find_spec("flash_attn") is not None if fa_is_installed: from flash_attn.bert_padding import unpad_input - from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func import torch import torch.nn as nn @@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask): kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) - out = flash_attn_unpadded_kvpacked_func( + out = flash_attn_varlen_kvpacked_func( q, kv_unpad, q_cu_seqlens, diff --git a/setup.py b/setup.py index dbee0a5..5876271 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ version_dependent_macros = [ ] extra_cuda_flags = [ - '-std=c++14', + '-std=c++17', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', From 91776cdfb988c61b9720664eed15a5e6d956b15b Mon Sep 17 00:00:00 2001 From: Jennifer Date: Tue, 16 Jan 2024 03:52:57 -0500 Subject: [PATCH 04/34] np type update in openfold.np.relax --- openfold/np/relax/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openfold/np/relax/utils.py b/openfold/np/relax/utils.py index fc19a91..422d21b 100644 --- a/openfold/np/relax/utils.py +++ b/openfold/np/relax/utils.py @@ -79,7 +79,7 @@ def assert_equal_nonterminal_atom_types( """Checks that pre- and post-minimized proteins have same atom set.""" # Ignore any terminal OXT atoms which may have been added by minimization. oxt = residue_constants.atom_order["OXT"] - no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool) + no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool) no_oxt_mask[..., oxt] = False np.testing.assert_almost_equal( ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask] From 427a6ee754ccc53fa6dd30042b6a9a8d64ed3a5f Mon Sep 17 00:00:00 2001 From: Jennifer Date: Tue, 23 Jan 2024 04:28:31 -0500 Subject: [PATCH 05/34] update deprecated jax.numpy.DeviceArray to jax.Array --- tests/test_evoformer.py | 4 ++-- tests/test_msa.py | 6 +++--- tests/test_outer_product_mean.py | 2 +- tests/test_pair_transition.py | 2 +- tests/test_template.py | 4 +++- tests/test_triangular_attention.py | 2 +- tests/test_triangular_multiplicative_update.py | 2 +- 7 files changed, 12 insertions(+), 10 deletions(-) diff --git a/tests/test_evoformer.py b/tests/test_evoformer.py index 86ca98b..dd06ce7 100644 --- a/tests/test_evoformer.py +++ b/tests/test_evoformer.py @@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase): params = compare_utils.fetch_alphafold_module_weights( "alphafold/alphafold_iteration/evoformer/evoformer_iteration" ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) key = jax.random.PRNGKey(42) out_gt = f.apply(params, key, activations, masks) @@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" + "msa_transition" ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = torch.as_tensor(np.array(out_gt)) diff --git a/tests/test_msa.py b/tests/test_msa.py index ad968f9..b5b3f67 100644 --- a/tests/test_msa.py +++ b/tests/test_msa.py @@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" + "msa_row_attention" ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply( params, None, msa_act, msa_mask, pair_act @@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" + "msa_column_attention" ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = torch.as_tensor(np.array(out_gt)) @@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/extra_msa_stack/" + "msa_column_global_attention" ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) diff --git a/tests/test_outer_product_mean.py b/tests/test_outer_product_mean.py index aa9e4fb..8335aa1 100644 --- a/tests/test_outer_product_mean.py +++ b/tests/test_outer_product_mean.py @@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/" + "evoformer_iteration/outer_product_mean" ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = torch.as_tensor(np.array(out_gt)) diff --git a/tests/test_pair_transition.py b/tests/test_pair_transition.py index 31a9815..c8e8d7f 100644 --- a/tests/test_pair_transition.py +++ b/tests/test_pair_transition.py @@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" + "pair_transition" ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) diff --git a/tests/test_template.py b/tests/test_template.py index 47cf630..262e08b 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -191,7 +191,9 @@ class TestTemplatePairStack(unittest.TestCase): _mask_trans=False, ).cpu() - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) + diff = torch.max(torch.abs(out_gt - out_repro)) + self.assertTrue(diff < consts.eps, + msg=f"Found difference between ground truth and reproduction of {diff}") class Template(unittest.TestCase): diff --git a/tests/test_triangular_attention.py b/tests/test_triangular_attention.py index 3f14b55..2435a3c 100644 --- a/tests/test_triangular_attention.py +++ b/tests/test_triangular_attention.py @@ -79,7 +79,7 @@ class TestTriangularAttention(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" + name ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = torch.as_tensor(np.array(out_gt)) diff --git a/tests/test_triangular_multiplicative_update.py b/tests/test_triangular_multiplicative_update.py index 39122af..56a2d9c 100644 --- a/tests/test_triangular_multiplicative_update.py +++ b/tests/test_triangular_multiplicative_update.py @@ -85,7 +85,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): "alphafold/alphafold_iteration/evoformer/evoformer_iteration/" + name ) - params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) + params = tree_map(lambda n: n[0], params, jax.Array) out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = torch.as_tensor(np.array(out_gt)) From e813bb5375cb8af6c667df5fca0da782341c8ca9 Mon Sep 17 00:00:00 2001 From: Christina Floristean Date: Tue, 23 Jan 2024 12:17:40 -0500 Subject: [PATCH 06/34] Additional fix for multimer deepspeed test --- tests/test_deepspeed_evo_attention.py | 11 ++++++++++- tests/test_model.py | 7 ++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/test_deepspeed_evo_attention.py b/tests/test_deepspeed_evo_attention.py index 034c60c..52411a7 100644 --- a/tests/test_deepspeed_evo_attention.py +++ b/tests/test_deepspeed_evo_attention.py @@ -293,6 +293,15 @@ class TestDeepSpeedKernel(unittest.TestCase): batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0] batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ]) + + if consts.is_multimer: + n_res = batch['aatype'].shape[1] + n_extra_seq = batch['extra_msa'].shape[1] + batch["asym_id"] = np.ones((4, n_res)) + batch["entity_id"] = np.ones((4, n_res)) + batch["sym_id"] = np.ones((4, n_res)) + batch["extra_deletion_matrix"] = np.random.randint(0, 2, size=(4, n_extra_seq, n_res)) + batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} batch["aatype"] = batch["aatype"].long() @@ -301,6 +310,7 @@ class TestDeepSpeedKernel(unittest.TestCase): batch["residx_atom37_to_atom14"] = batch[ "residx_atom37_to_atom14" ].long() + batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32) batch["template_all_atom_mask"] = batch["template_all_atom_masks"] batch.update( data_transforms.atom37_to_torsion_angles("template_")(batch) @@ -309,7 +319,6 @@ class TestDeepSpeedKernel(unittest.TestCase): # Move the recycling dimension to the end move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0) batch = tensor_tree_map(move_dim, batch) - with torch.no_grad(): with torch.cuda.amp.autocast(dtype=torch.bfloat16): model = compare_utils.get_global_pretrained_openfold() diff --git a/tests/test_model.py b/tests/test_model.py index 19ab87f..3d19f14 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -27,6 +27,7 @@ from tests.config import consts from tests.data_utils import ( random_template_feats, random_extra_msa_feats, + random_asym_ids ) if compare_utils.alphafold_is_installed(): @@ -85,9 +86,9 @@ class TestModel(unittest.TestCase): batch["no_recycling_iters"] = torch.tensor(2.) if consts.is_multimer: - batch["asym_id"] = torch.randint(0, 1, size=(n_res,)) - batch["entity_id"] = torch.randint(0, 1, size=(n_res,)) - batch["sym_id"] = torch.randint(0, 1, size=(n_res,)) + batch["asym_id"] = torch.as_tensor(random_asym_ids(n_res)) + batch["entity_id"] = batch["asym_id"].clone() + batch["sym_id"] = torch.ones(n_res) batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res)) add_recycling_dims = lambda t: ( From df4dfacb3aade38f1d222c0902485c35a358f64a Mon Sep 17 00:00:00 2001 From: Jennifer Date: Wed, 24 Jan 2024 01:34:33 -0500 Subject: [PATCH 07/34] first pass changes to run with pl 2.1 --- openfold/data/data_modules.py | 9 +- openfold/utils/seed.py | 2 +- tests/compare_utils.py | 18 ++ tests/test_deepspeed_evo_attention.py | 6 +- tests/test_evoformer.py | 11 +- tests/test_feats.py | 2 +- tests/test_msa.py | 6 +- tests/test_outer_product_mean.py | 2 +- tests/test_structure_module.py | 4 +- tests/test_template.py | 6 +- tests/test_triangular_attention.py | 2 +- .../test_triangular_multiplicative_update.py | 2 +- train_openfold.py | 249 ++++++++++-------- 13 files changed, 183 insertions(+), 136 deletions(-) diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py index 62f48e4..de9f111 100644 --- a/openfold/data/data_modules.py +++ b/openfold/data/data_modules.py @@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule): with open(distillation_alignment_index_path, "r") as fp: self.distillation_alignment_index = json.load(fp) - def setup(self): + def setup(self, stage=None): # Most of the arguments are the same for the three datasets dataset_gen = partial(OpenFoldSingleDataset, template_mmcif_dir=self.template_mmcif_dir, @@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule): mode="predict", ) - def _gen_dataloader(self, stage): + def _gen_dataloader(self, stage=None): generator = None if self.batch_seed is not None: generator = torch.Generator() @@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule): def val_dataloader(self): if self.eval_dataset is not None: return self._gen_dataloader("eval") - return None + # Temp fix to pass the validation step + return [] def predict_dataloader(self): return self._gen_dataloader("predict") @@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): self.training_mode = self.train_data_dir is not None self.val_mmcif_data_cache_path = val_mmcif_data_cache_path - def setup(self): + def setup(self, setup=None): # Most of the arguments are the same for the three datasets dataset_gen = partial(OpenFoldSingleMultimerDataset, template_mmcif_dir=self.template_mmcif_dir, diff --git a/openfold/utils/seed.py b/openfold/utils/seed.py index b45b813..a305bfb 100644 --- a/openfold/utils/seed.py +++ b/openfold/utils/seed.py @@ -2,7 +2,7 @@ import os import logging import random import numpy as np -from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning import seed_everything from openfold.utils.suppress_output import SuppressLogging diff --git a/tests/compare_utils.py b/tests/compare_utils.py index ae41658..326f5e2 100644 --- a/tests/compare_utils.py +++ b/tests/compare_utils.py @@ -6,6 +6,7 @@ import sys import unittest import numpy as np +import torch from openfold.config import model_config from openfold.model.model import AlphaFold @@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path): "Make sure to call import_alphafold before running this function" ) return params + + +def _assert_abs_diff_small_base(compare_func, expected, actual, eps): + # Helper function for comparing absolute differences of two torch tensors. + abs_diff = torch.abs(expected - actual) + err = compare_func(abs_diff) + zero_tensor = torch.tensor(0, dtype=err.dtype) + rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6 + torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol) + + +def assert_max_abs_diff_small(expected, actual, eps): + _assert_abs_diff_small_base(torch.max, expected, actual, eps) + + +def assert_mean_abs_diff_small(expected, actual, eps): + _assert_abs_diff_small_base(torch.mean, expected, actual, eps) diff --git a/tests/test_deepspeed_evo_attention.py b/tests/test_deepspeed_evo_attention.py index 52411a7..9a2a774 100644 --- a/tests/test_deepspeed_evo_attention.py +++ b/tests/test_deepspeed_evo_attention.py @@ -276,8 +276,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ) out_repro_ds = out_repro_ds["template_pair_embedding"].cpu() - err = torch.max(torch.abs(out_repro - out_repro_ds)) - self.assertTrue(err < eps, f'Error {err}') + compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps) def test_compare_model(self): """ @@ -335,8 +334,7 @@ class TestDeepSpeedKernel(unittest.TestCase): out_repro = out_repro["sm"]["positions"][-1].squeeze(0) out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0) - err = torch.mean(torch.abs(out_repro - out_repro_ds)) - self.assertTrue(err < eps, f'Error: {err}') + compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps) if __name__ == "__main__": diff --git a/tests/test_evoformer.py b/tests/test_evoformer.py index dd06ce7..66162b4 100644 --- a/tests/test_evoformer.py +++ b/tests/test_evoformer.py @@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase): out_repro_msa = out_repro_msa.cpu() out_repro_pair = out_repro_pair.cpu() - self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps) - self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps) + compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps) + compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps) # Inplace version out_repro_msa, out_repro_pair = model.evoformer.blocks[0]( @@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase): out_repro_msa = out_repro_msa.cpu() out_repro_pair = out_repro_pair.cpu() - self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps) - self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps) + compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps) + compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps) class TestExtraMSAStack(unittest.TestCase): @@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase): .cpu() ) - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) - + compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps) if __name__ == "__main__": unittest.main() diff --git a/tests/test_feats.py b/tests/test_feats.py index 6419328..7a1783b 100644 --- a/tests/test_feats.py +++ b/tests/test_feats.py @@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase): torch.tensor(restype_atom14_rigid_group_positions).cuda(), ).cpu() - self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) + compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps) if __name__ == "__main__": diff --git a/tests/test_msa.py b/tests/test_msa.py index b5b3f67..241bead 100644 --- a/tests/test_msa.py +++ b/tests/test_msa.py @@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ) ).cpu() - self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps) class TestMSAColumnAttention(unittest.TestCase): @@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase): ) ).cpu() - self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps) class TestMSAColumnGlobalAttention(unittest.TestCase): @@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): .cpu() ) - self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) + compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps) if __name__ == "__main__": diff --git a/tests/test_outer_product_mean.py b/tests/test_outer_product_mean.py index 8335aa1..a9665c0 100644 --- a/tests/test_outer_product_mean.py +++ b/tests/test_outer_product_mean.py @@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase): # Even when correct, OPM has large, precision-related errors. It gets # a special pass from consts.eps. - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4) + compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4) if __name__ == "__main__": diff --git a/tests/test_structure_module.py b/tests/test_structure_module.py index 410e090..7858223 100644 --- a/tests/test_structure_module.py +++ b/tests/test_structure_module.py @@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase): # The structure module, thanks to angle normalization, is very volatile # We only assess the mean here. Heuristically speaking, it seems to # have lower error in general on real rather than synthetic data. - self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05) class TestInvariantPointAttention(unittest.TestCase): @@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase): torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(), ).cpu() - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) + compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps) class TestAngleResnet(unittest.TestCase): diff --git a/tests/test_template.py b/tests/test_template.py index 262e08b..ae65b7c 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -191,9 +191,7 @@ class TestTemplatePairStack(unittest.TestCase): _mask_trans=False, ).cpu() - diff = torch.max(torch.abs(out_gt - out_repro)) - self.assertTrue(diff < consts.eps, - msg=f"Found difference between ground truth and reproduction of {diff}") + compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps) class Template(unittest.TestCase): @@ -286,7 +284,7 @@ class Template(unittest.TestCase): out_repro = out_repro_all["template_pair_embedding"] out_repro = out_repro.cpu() - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps) if __name__ == "__main__": diff --git a/tests/test_triangular_attention.py b/tests/test_triangular_attention.py index 2435a3c..6c3099d 100644 --- a/tests/test_triangular_attention.py +++ b/tests/test_triangular_attention.py @@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase): chunk_size=None, ).cpu() - self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps) @compare_utils.skip_unless_alphafold_installed() def test_tri_att_end_compare(self): diff --git a/tests/test_triangular_multiplicative_update.py b/tests/test_triangular_multiplicative_update.py index 56a2d9c..b99f8e1 100644 --- a/tests/test_triangular_multiplicative_update.py +++ b/tests/test_triangular_multiplicative_update.py @@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): inplace_safe=True, _inplace_chunk_size=4, ).cpu() - self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps) @compare_utils.skip_unless_alphafold_installed() def test_tri_mul_out_compare(self): diff --git a/train_openfold.py b/train_openfold.py index 9ec26ee..baa3d3d 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -7,7 +7,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin +from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy import torch from openfold.config import model_config @@ -55,7 +55,7 @@ class OpenFoldWrapper(pl.LightningModule): self.ema = ExponentialMovingAverage( model=self.model, decay=config.ema.decay ) - + self.cached_weights = None self.last_lr_step = -1 @@ -66,12 +66,12 @@ class OpenFoldWrapper(pl.LightningModule): phase = "train" if train else "val" for loss_name, indiv_loss in loss_breakdown.items(): self.log( - f"{phase}/{loss_name}", - indiv_loss, + f"{phase}/{loss_name}", + indiv_loss, on_step=train, on_epoch=(not train), logger=True, ) - if(train): + if (train): self.log( f"{phase}/{loss_name}_epoch", indiv_loss, @@ -80,12 +80,12 @@ class OpenFoldWrapper(pl.LightningModule): with torch.no_grad(): other_metrics = self._compute_validation_metrics( - batch, + batch, outputs, superimposition_metrics=(not train) ) - for k,v in other_metrics.items(): + for k, v in other_metrics.items(): self.log( f"{phase}/{k}", torch.mean(v), @@ -93,7 +93,7 @@ class OpenFoldWrapper(pl.LightningModule): ) def training_step(self, batch, batch_idx): - if(self.ema.device != batch["aatype"].device): + if (self.ema.device != batch["aatype"].device): self.ema.to(batch["aatype"].device) ground_truth = batch.pop('gt_features', None) @@ -124,12 +124,13 @@ class OpenFoldWrapper(pl.LightningModule): def validation_step(self, batch, batch_idx): # At the start of validation, load the EMA weights - if(self.cached_weights is None): + if (self.cached_weights is None): # model.state_dict() contains references to model weights rather - # than copies. Therefore, we need to clone them before calling + # than copies. Therefore, we need to clone them before calling # load_state_dict(). - clone_param = lambda t: t.detach().clone() - self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) + def clone_param(t): return t.detach().clone() + self.cached_weights = tensor_tree_map( + clone_param, self.model.state_dict()) self.model.load_state_dict(self.ema.state_dict()["params"]) ground_truth = batch.pop('gt_features', None) @@ -151,23 +152,23 @@ class OpenFoldWrapper(pl.LightningModule): ) self._log(loss_breakdown, batch, outputs, train=False) - - def validation_epoch_end(self, _): + + def on_validation_epoch_end(self, _): # Restore the model weights to normal self.model.load_state_dict(self.cached_weights) self.cached_weights = None - def _compute_validation_metrics(self, - batch, - outputs, - superimposition_metrics=False - ): + def _compute_validation_metrics(self, + batch, + outputs, + superimposition_metrics=False + ): metrics = {} - + gt_coords = batch["all_atom_positions"] pred_coords = outputs["final_atom_positions"] all_atom_mask = batch["all_atom_mask"] - + # This is super janky for superimposition. Fix later gt_coords_masked = gt_coords * all_atom_mask[..., None] pred_coords_masked = pred_coords * all_atom_mask[..., None] @@ -175,7 +176,7 @@ class OpenFoldWrapper(pl.LightningModule): gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :] pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :] all_atom_mask_ca = all_atom_mask[..., ca_pos] - + lddt_ca_score = lddt_ca( pred_coords, gt_coords, @@ -183,18 +184,18 @@ class OpenFoldWrapper(pl.LightningModule): eps=self.config.globals.eps, per_residue=False, ) - + metrics["lddt_ca"] = lddt_ca_score - + drmsd_ca_score = drmsd( pred_coords_masked_ca, gt_coords_masked_ca, - mask=all_atom_mask_ca, # still required here to compute n + mask=all_atom_mask_ca, # still required here to compute n ) - + metrics["drmsd_ca"] = drmsd_ca_score - - if(superimposition_metrics): + + if (superimposition_metrics): superimposed_pred, alignment_rmsd = superimpose( gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca, ) @@ -208,22 +209,22 @@ class OpenFoldWrapper(pl.LightningModule): metrics["alignment_rmsd"] = alignment_rmsd metrics["gdt_ts"] = gdt_ts_score metrics["gdt_ha"] = gdt_ha_score - + return metrics - def configure_optimizers(self, - learning_rate: float = 1e-3, - eps: float = 1e-5, - ) -> torch.optim.Adam: -# return torch.optim.Adam( -# self.model.parameters(), -# lr=learning_rate, -# eps=eps -# ) + def configure_optimizers(self, + learning_rate: float = 1e-3, + eps: float = 1e-5, + ) -> torch.optim.Adam: + # return torch.optim.Adam( + # self.model.parameters(), + # lr=learning_rate, + # eps=eps + # ) # Ignored as long as a DeepSpeed optimizer is configured optimizer = torch.optim.Adam( - self.model.parameters(), - lr=learning_rate, + self.model.parameters(), + lr=learning_rate, eps=eps ) @@ -247,8 +248,9 @@ class OpenFoldWrapper(pl.LightningModule): def on_load_checkpoint(self, checkpoint): ema = checkpoint["ema"] - if(not self.model.template_config.enabled): - ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k} + if (not self.model.template_config.enabled): + ema["params"] = {k: v for k, + v in ema["params"].items() if not "template" in k} self.ema.load_state_dict(ema) def on_save_checkpoint(self, checkpoint): @@ -259,69 +261,72 @@ class OpenFoldWrapper(pl.LightningModule): def load_from_jax(self, jax_path): model_basename = os.path.splitext( - os.path.basename( - os.path.normpath(jax_path) - ) + os.path.basename( + os.path.normpath(jax_path) + ) )[0] model_version = "_".join(model_basename.split("_")[1:]) import_jax_weights_( - self.model, jax_path, version=model_version + self.model, jax_path, version=model_version ) def main(args): - if(args.seed is not None): - seed_everything(args.seed) + if (args.seed is not None): + seed_everything(args.seed) config = model_config( - args.config_preset, - train=True, + args.config_preset, + train=True, low_prec=(str(args.precision) == "16") - ) + ) model_module = OpenFoldWrapper(config) - if(args.resume_from_ckpt): - if(os.path.isdir(args.resume_from_ckpt)): - last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt) + if (args.resume_from_ckpt): + if (os.path.isdir(args.resume_from_ckpt)): + last_global_step = get_global_step_from_zero_checkpoint( + args.resume_from_ckpt) else: sd = torch.load(args.resume_from_ckpt) last_global_step = int(sd['global_step']) model_module.resume_last_lr_step(last_global_step) logging.info("Successfully loaded last lr step...") - if(args.resume_from_ckpt and args.resume_model_weights_only): - if(os.path.isdir(args.resume_from_ckpt)): - sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt) + if (args.resume_from_ckpt and args.resume_model_weights_only): + if (os.path.isdir(args.resume_from_ckpt)): + sd = get_fp32_state_dict_from_zero_checkpoint( + args.resume_from_ckpt) else: sd = torch.load(args.resume_from_ckpt) - sd = {k[len("module."):]:v for k,v in sd.items()} + sd = {k[len("module."):]: v for k, v in sd.items()} import_openfold_weights_(model=model_module, state_dict=sd) logging.info("Successfully loaded model weights...") - if(args.resume_from_jax_params): + if (args.resume_from_jax_params): model_module.load_from_jax(args.resume_from_jax_params) - logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") - + logging.info( + f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") + # TorchScript components of the model - if(args.script_modules): + if (args.script_modules): script_preset_(model_module) if "multimer" in args.config_preset: data_module = OpenFoldMultimerDataModule( - config=config.data, - batch_seed=args.seed, - **vars(args) - ) + config=config.data, + batch_seed=args.seed, + **vars(args) + ) else: data_module = OpenFoldDataModule( - config=config.data, + config=config.data, batch_seed=args.seed, **vars(args) ) data_module.prepare_data() data_module.setup() - + callbacks = [] - if(args.checkpoint_every_epoch): + if (args.checkpoint_every_epoch): mc = ModelCheckpoint( every_n_epochs=1, auto_insert_metric_name=False, @@ -329,7 +334,7 @@ def main(args): ) callbacks.append(mc) - if(args.early_stopping): + if (args.early_stopping): es = EarlyStoppingVerbose( monitor="val/lddt_ca", min_delta=args.min_delta, @@ -341,7 +346,7 @@ def main(args): ) callbacks.append(es) - if(args.log_performance): + if (args.log_performance): global_batch_size = args.num_nodes * args.gpus perf = PerformanceLoggingCallback( log_file=os.path.join(args.output_dir, "performance_log.json"), @@ -349,12 +354,12 @@ def main(args): ) callbacks.append(perf) - if(args.log_lr): + if (args.log_lr): lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) loggers = [] - if(args.wandb): + if (args.wandb): wdb_logger = WandbLogger( name=args.experiment_name, save_dir=args.output_dir, @@ -364,38 +369,43 @@ def main(args): ) loggers.append(wdb_logger) - if(args.deepspeed_config_path is not None): - strategy = DeepSpeedPlugin( + if (args.deepspeed_config_path is not None): + strategy = DeepSpeedStrategy( config=args.deepspeed_config_path, ) - if(args.wandb): + if (args.wandb): wdb_logger.experiment.save(args.deepspeed_config_path) wdb_logger.experiment.save("openfold/config.py") elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1: - strategy = DDPPlugin(find_unused_parameters=False) + strategy = DDPStrategy(find_unused_parameters=False) else: strategy = None - - if(args.wandb): + + if (args.wandb): freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt" os.system(f"{sys.executable} -m pip freeze > {freeze_path}") wdb_logger.experiment.save(f"{freeze_path}") - trainer = pl.Trainer.from_argparse_args( - args, - default_root_dir=args.output_dir, - strategy=strategy, - callbacks=callbacks, - logger=loggers, - ) + # Raw dump of all args from pl.Trainer constructor + trainer_kws = set([ + 'accelerator', 'strategy', 'devices', 'num_nodes', 'precision', 'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs', 'max_steps', 'min_steps', 'max_tim', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches', 'limit_predict_batches', 'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch', 'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing', 'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches', 'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic', 'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler', 'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm', 'reload_dataloaders_every_n_epochs', 'default_root_dir', + ]) + trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws} + trainer_args.update({ + 'default_root_dir': args.output_dir, + 'strategy': strategy, + 'callbacks': callbacks, + 'logger': loggers, + }) + trainer = pl.Trainer(**trainer_args) - if(args.resume_model_weights_only): + if (args.resume_model_weights_only): ckpt_path = None else: ckpt_path = args.resume_from_ckpt trainer.fit( - model_module, + model_module, datamodule=data_module, ckpt_path=ckpt_path, ) @@ -594,36 +604,59 @@ if __name__ == "__main__": "--distillation_alignment_index_path", type=str, default=None, help="Distillation alignment index. See the README for instructions." ) - parser = pl.Trainer.add_argparse_args(parser) - - # Disable the initial validation pass - parser.set_defaults( - num_sanity_val_steps=0, + parser.add_argument( + "--num_nodes", type=int, default=1, + ) + parser.add_argument( + "--gpus", type=int, default=1, + ) + parser.add_argument( + "--precision", type=str, default=None, + ) + parser.add_argument( + "--replace_sampler_ddp", type=bool_type, default=True, + ) + parser.add_argument( + "--max_epochs", type=int, default=1, + ) + parser.add_argument( + "--log_every_n_steps", type=int, default=25, + ) + parser.add_argument( + "--num_sanity_val_steps", type=int, default=0, ) - # Remove some buggy/redundant arguments introduced by the Trainer - remove_arguments( - parser, - [ - "--accelerator", - "--resume_from_checkpoint", - "--reload_dataloaders_every_epoch", - "--reload_dataloaders_every_n_epochs", - ] - ) + # parser = pl.Trainer.add_argparse_args(parser) + # + # # Disable the initial validation pass + # parser.set_defaults( + # num_sanity_val_steps=0, + # ) + + # # Remove some buggy/redundant arguments introduced by the Trainer + # remove_arguments( + # parser, + # [ + # "--accelerator", + # "--resume_from_checkpoint", + # "--reload_dataloaders_every_epoch", + # "--reload_dataloaders_every_n_epochs", + # ] + # ) args = parser.parse_args() - if(args.seed is None and - ((args.gpus is not None and args.gpus > 1) or + if (args.seed is None and + ((args.gpus is not None and args.gpus > 1) or (args.num_nodes is not None and args.num_nodes > 1))): raise ValueError("For distributed training, --seed must be specified") - if(str(args.precision) == "16" and args.deepspeed_config_path is not None): + if (str(args.precision) == "16" and args.deepspeed_config_path is not None): raise ValueError("DeepSpeed and FP16 training are not compatible") - if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): - raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") + if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): + raise ValueError( + "Choose between loading pretrained Jax-weights and a checkpoint-path") # This re-applies the training-time filters at the beginning of every epoch args.reload_dataloaders_every_n_epochs = 1 From 456103da68aa8ac42d6a5ec40d388b5f0ab0477e Mon Sep 17 00:00:00 2001 From: Jennifer Date: Fri, 12 Jan 2024 04:27:30 -0500 Subject: [PATCH 08/34] initial compatibility changes for upgrading multimer --- environment.yml | 19 ++++++++++--------- openfold/data/data_pipeline.py | 6 +++--- openfold/model/primitives.py | 4 ++-- setup.py | 2 +- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/environment.yml b/environment.yml index 7b73b25..d6ccb46 100644 --- a/environment.yml +++ b/environment.yml @@ -3,6 +3,7 @@ channels: - conda-forge - bioconda - pytorch + - nvidia dependencies: - python=3.9 - libgcc=7.2 @@ -10,17 +11,16 @@ dependencies: - pip - openmm=7.7 - pdbfixer - - cudatoolkit==11.3.* - - pytorch-lightning==1.5.10 + - pytorch-lightning - biopython==1.79 - - numpy==1.21 - - pandas==2.0 + - numpy + - pandas - PyYAML==5.4.1 - requests - - scipy==1.7 + - scipy - tqdm==4.62.2 - - typing-extensions==3.10 - - wandb==0.12.21 + - typing-extensions + - wandb - modelcif==0.7 - awscli - ml-collections @@ -29,9 +29,10 @@ dependencies: - bioconda::hmmer==3.3.2 - bioconda::hhsuite==3.3.0 - bioconda::kalign2==2.04 - - pytorch::pytorch=1.12.* + - pytorch::pytorch=2.1 + - pytorch::pytorch-cuda=12.1 - pip: - deepspeed==0.12.4 - dm-tree==0.1.6 - git+https://github.com/NVIDIA/dllogger.git - - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 + - flash-attn diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index ce8494d..adde0b7 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: features["num_alignments"] = np.array( [num_alignments] * num_res, dtype=np.int32 ) - features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_) + features["msa_species_identifiers"] = np.array(species_ids, dtype=object) return features @@ -590,7 +590,7 @@ def convert_monomer_features( ) -> FeatureDict: """Reshapes and modifies monomer features for multimer models.""" converted = {} - converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + converted['auth_chain_id'] = np.asarray(chain_id, dtype=object) unnecessary_leading_dim_feats = { 'sequence', 'domain_name', 'num_alignments', 'seq_length' } @@ -1296,7 +1296,7 @@ class DataPipelineMultimer: ) mmcif_feats["release_date"] = np.array( - [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ + [mmcif_object.header["release_date"].encode("utf-8")], dtype=object ) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index ea38cb3..e5735d1 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -28,7 +28,7 @@ if ds4s_is_installed: fa_is_installed = importlib.util.find_spec("flash_attn") is not None if fa_is_installed: from flash_attn.bert_padding import unpad_input - from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func import torch import torch.nn as nn @@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask): kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) - out = flash_attn_unpadded_kvpacked_func( + out = flash_attn_varlen_kvpacked_func( q, kv_unpad, q_cu_seqlens, diff --git a/setup.py b/setup.py index bec9862..9179856 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ version_dependent_macros = [ ] extra_cuda_flags = [ - '-std=c++14', + '-std=c++17', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', From ff3680084903f43881bcca4e5825b6c663625be1 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Wed, 24 Jan 2024 01:34:33 -0500 Subject: [PATCH 09/34] first pass changes to run with pl 2.1 --- openfold/data/data_modules.py | 9 +- openfold/utils/seed.py | 2 +- train_openfold.py | 249 +++++++++++++++++++--------------- 3 files changed, 147 insertions(+), 113 deletions(-) diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py index 62f48e4..de9f111 100644 --- a/openfold/data/data_modules.py +++ b/openfold/data/data_modules.py @@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule): with open(distillation_alignment_index_path, "r") as fp: self.distillation_alignment_index = json.load(fp) - def setup(self): + def setup(self, stage=None): # Most of the arguments are the same for the three datasets dataset_gen = partial(OpenFoldSingleDataset, template_mmcif_dir=self.template_mmcif_dir, @@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule): mode="predict", ) - def _gen_dataloader(self, stage): + def _gen_dataloader(self, stage=None): generator = None if self.batch_seed is not None: generator = torch.Generator() @@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule): def val_dataloader(self): if self.eval_dataset is not None: return self._gen_dataloader("eval") - return None + # Temp fix to pass the validation step + return [] def predict_dataloader(self): return self._gen_dataloader("predict") @@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): self.training_mode = self.train_data_dir is not None self.val_mmcif_data_cache_path = val_mmcif_data_cache_path - def setup(self): + def setup(self, setup=None): # Most of the arguments are the same for the three datasets dataset_gen = partial(OpenFoldSingleMultimerDataset, template_mmcif_dir=self.template_mmcif_dir, diff --git a/openfold/utils/seed.py b/openfold/utils/seed.py index b45b813..a305bfb 100644 --- a/openfold/utils/seed.py +++ b/openfold/utils/seed.py @@ -2,7 +2,7 @@ import os import logging import random import numpy as np -from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning import seed_everything from openfold.utils.suppress_output import SuppressLogging diff --git a/train_openfold.py b/train_openfold.py index b41d738..9612aea 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -7,7 +7,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin +from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy import torch from openfold.config import model_config @@ -55,7 +55,7 @@ class OpenFoldWrapper(pl.LightningModule): self.ema = ExponentialMovingAverage( model=self.model, decay=config.ema.decay ) - + self.cached_weights = None self.last_lr_step = -1 @@ -66,12 +66,12 @@ class OpenFoldWrapper(pl.LightningModule): phase = "train" if train else "val" for loss_name, indiv_loss in loss_breakdown.items(): self.log( - f"{phase}/{loss_name}", - indiv_loss, + f"{phase}/{loss_name}", + indiv_loss, on_step=train, on_epoch=(not train), logger=True, ) - if(train): + if (train): self.log( f"{phase}/{loss_name}_epoch", indiv_loss, @@ -80,12 +80,12 @@ class OpenFoldWrapper(pl.LightningModule): with torch.no_grad(): other_metrics = self._compute_validation_metrics( - batch, + batch, outputs, superimposition_metrics=(not train) ) - for k,v in other_metrics.items(): + for k, v in other_metrics.items(): self.log( f"{phase}/{k}", torch.mean(v), @@ -93,7 +93,7 @@ class OpenFoldWrapper(pl.LightningModule): ) def training_step(self, batch, batch_idx): - if(self.ema.device != batch["aatype"].device): + if (self.ema.device != batch["aatype"].device): self.ema.to(batch["aatype"].device) ground_truth = batch.pop('gt_features', None) @@ -124,12 +124,13 @@ class OpenFoldWrapper(pl.LightningModule): def validation_step(self, batch, batch_idx): # At the start of validation, load the EMA weights - if(self.cached_weights is None): + if (self.cached_weights is None): # model.state_dict() contains references to model weights rather - # than copies. Therefore, we need to clone them before calling + # than copies. Therefore, we need to clone them before calling # load_state_dict(). - clone_param = lambda t: t.detach().clone() - self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) + def clone_param(t): return t.detach().clone() + self.cached_weights = tensor_tree_map( + clone_param, self.model.state_dict()) self.model.load_state_dict(self.ema.state_dict()["params"]) ground_truth = batch.pop('gt_features', None) @@ -151,23 +152,23 @@ class OpenFoldWrapper(pl.LightningModule): ) self._log(loss_breakdown, batch, outputs, train=False) - - def validation_epoch_end(self, _): + + def on_validation_epoch_end(self, _): # Restore the model weights to normal self.model.load_state_dict(self.cached_weights) self.cached_weights = None - def _compute_validation_metrics(self, - batch, - outputs, - superimposition_metrics=False - ): + def _compute_validation_metrics(self, + batch, + outputs, + superimposition_metrics=False + ): metrics = {} - + gt_coords = batch["all_atom_positions"] pred_coords = outputs["final_atom_positions"] all_atom_mask = batch["all_atom_mask"] - + # This is super janky for superimposition. Fix later gt_coords_masked = gt_coords * all_atom_mask[..., None] pred_coords_masked = pred_coords * all_atom_mask[..., None] @@ -175,7 +176,7 @@ class OpenFoldWrapper(pl.LightningModule): gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :] pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :] all_atom_mask_ca = all_atom_mask[..., ca_pos] - + lddt_ca_score = lddt_ca( pred_coords, gt_coords, @@ -183,18 +184,18 @@ class OpenFoldWrapper(pl.LightningModule): eps=self.config.globals.eps, per_residue=False, ) - + metrics["lddt_ca"] = lddt_ca_score - + drmsd_ca_score = drmsd( pred_coords_masked_ca, gt_coords_masked_ca, - mask=all_atom_mask_ca, # still required here to compute n + mask=all_atom_mask_ca, # still required here to compute n ) - + metrics["drmsd_ca"] = drmsd_ca_score - - if(superimposition_metrics): + + if (superimposition_metrics): superimposed_pred, alignment_rmsd = superimpose( gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca, ) @@ -208,22 +209,22 @@ class OpenFoldWrapper(pl.LightningModule): metrics["alignment_rmsd"] = alignment_rmsd metrics["gdt_ts"] = gdt_ts_score metrics["gdt_ha"] = gdt_ha_score - + return metrics - def configure_optimizers(self, - learning_rate: float = 1e-3, - eps: float = 1e-5, - ) -> torch.optim.Adam: -# return torch.optim.Adam( -# self.model.parameters(), -# lr=learning_rate, -# eps=eps -# ) + def configure_optimizers(self, + learning_rate: float = 1e-3, + eps: float = 1e-5, + ) -> torch.optim.Adam: + # return torch.optim.Adam( + # self.model.parameters(), + # lr=learning_rate, + # eps=eps + # ) # Ignored as long as a DeepSpeed optimizer is configured optimizer = torch.optim.Adam( - self.model.parameters(), - lr=learning_rate, + self.model.parameters(), + lr=learning_rate, eps=eps ) @@ -248,8 +249,9 @@ class OpenFoldWrapper(pl.LightningModule): def on_load_checkpoint(self, checkpoint): ema = checkpoint["ema"] - if(not self.model.template_config.enabled): - ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k} + if (not self.model.template_config.enabled): + ema["params"] = {k: v for k, + v in ema["params"].items() if not "template" in k} self.ema.load_state_dict(ema) def on_save_checkpoint(self, checkpoint): @@ -260,69 +262,72 @@ class OpenFoldWrapper(pl.LightningModule): def load_from_jax(self, jax_path): model_basename = os.path.splitext( - os.path.basename( - os.path.normpath(jax_path) - ) + os.path.basename( + os.path.normpath(jax_path) + ) )[0] model_version = "_".join(model_basename.split("_")[1:]) import_jax_weights_( - self.model, jax_path, version=model_version + self.model, jax_path, version=model_version ) def main(args): - if(args.seed is not None): - seed_everything(args.seed) + if (args.seed is not None): + seed_everything(args.seed) config = model_config( - args.config_preset, - train=True, + args.config_preset, + train=True, low_prec=(str(args.precision) == "16") - ) + ) model_module = OpenFoldWrapper(config) - if(args.resume_from_ckpt): - if(os.path.isdir(args.resume_from_ckpt)): - last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt) + if (args.resume_from_ckpt): + if (os.path.isdir(args.resume_from_ckpt)): + last_global_step = get_global_step_from_zero_checkpoint( + args.resume_from_ckpt) else: sd = torch.load(args.resume_from_ckpt) last_global_step = int(sd['global_step']) model_module.resume_last_lr_step(last_global_step) logging.info("Successfully loaded last lr step...") - if(args.resume_from_ckpt and args.resume_model_weights_only): - if(os.path.isdir(args.resume_from_ckpt)): - sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt) + if (args.resume_from_ckpt and args.resume_model_weights_only): + if (os.path.isdir(args.resume_from_ckpt)): + sd = get_fp32_state_dict_from_zero_checkpoint( + args.resume_from_ckpt) else: sd = torch.load(args.resume_from_ckpt) - sd = {k[len("module."):]:v for k,v in sd.items()} + sd = {k[len("module."):]: v for k, v in sd.items()} import_openfold_weights_(model=model_module, state_dict=sd) logging.info("Successfully loaded model weights...") - if(args.resume_from_jax_params): + if (args.resume_from_jax_params): model_module.load_from_jax(args.resume_from_jax_params) - logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") - + logging.info( + f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") + # TorchScript components of the model - if(args.script_modules): + if (args.script_modules): script_preset_(model_module) if "multimer" in args.config_preset: data_module = OpenFoldMultimerDataModule( - config=config.data, - batch_seed=args.seed, - **vars(args) - ) + config=config.data, + batch_seed=args.seed, + **vars(args) + ) else: data_module = OpenFoldDataModule( - config=config.data, + config=config.data, batch_seed=args.seed, **vars(args) ) data_module.prepare_data() data_module.setup() - + callbacks = [] - if(args.checkpoint_every_epoch): + if (args.checkpoint_every_epoch): mc = ModelCheckpoint( every_n_epochs=1, auto_insert_metric_name=False, @@ -330,7 +335,7 @@ def main(args): ) callbacks.append(mc) - if(args.early_stopping): + if (args.early_stopping): es = EarlyStoppingVerbose( monitor="val/lddt_ca", min_delta=args.min_delta, @@ -342,7 +347,7 @@ def main(args): ) callbacks.append(es) - if(args.log_performance): + if (args.log_performance): global_batch_size = args.num_nodes * args.gpus perf = PerformanceLoggingCallback( log_file=os.path.join(args.output_dir, "performance_log.json"), @@ -350,12 +355,12 @@ def main(args): ) callbacks.append(perf) - if(args.log_lr): + if (args.log_lr): lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) loggers = [] - if(args.wandb): + if (args.wandb): wdb_logger = WandbLogger( name=args.experiment_name, save_dir=args.output_dir, @@ -365,38 +370,43 @@ def main(args): ) loggers.append(wdb_logger) - if(args.deepspeed_config_path is not None): - strategy = DeepSpeedPlugin( + if (args.deepspeed_config_path is not None): + strategy = DeepSpeedStrategy( config=args.deepspeed_config_path, ) - if(args.wandb): + if (args.wandb): wdb_logger.experiment.save(args.deepspeed_config_path) wdb_logger.experiment.save("openfold/config.py") elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1: - strategy = DDPPlugin(find_unused_parameters=False) + strategy = DDPStrategy(find_unused_parameters=False) else: strategy = None - - if(args.wandb): + + if (args.wandb): freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt" os.system(f"{sys.executable} -m pip freeze > {freeze_path}") wdb_logger.experiment.save(f"{freeze_path}") - trainer = pl.Trainer.from_argparse_args( - args, - default_root_dir=args.output_dir, - strategy=strategy, - callbacks=callbacks, - logger=loggers, - ) + # Raw dump of all args from pl.Trainer constructor + trainer_kws = set([ + 'accelerator', 'strategy', 'devices', 'num_nodes', 'precision', 'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs', 'max_steps', 'min_steps', 'max_tim', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches', 'limit_predict_batches', 'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch', 'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing', 'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches', 'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic', 'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler', 'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm', 'reload_dataloaders_every_n_epochs', 'default_root_dir', + ]) + trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws} + trainer_args.update({ + 'default_root_dir': args.output_dir, + 'strategy': strategy, + 'callbacks': callbacks, + 'logger': loggers, + }) + trainer = pl.Trainer(**trainer_args) - if(args.resume_model_weights_only): + if (args.resume_model_weights_only): ckpt_path = None else: ckpt_path = args.resume_from_ckpt trainer.fit( - model_module, + model_module, datamodule=data_module, ckpt_path=ckpt_path, ) @@ -595,36 +605,59 @@ if __name__ == "__main__": "--distillation_alignment_index_path", type=str, default=None, help="Distillation alignment index. See the README for instructions." ) - parser = pl.Trainer.add_argparse_args(parser) - - # Disable the initial validation pass - parser.set_defaults( - num_sanity_val_steps=0, + parser.add_argument( + "--num_nodes", type=int, default=1, + ) + parser.add_argument( + "--gpus", type=int, default=1, + ) + parser.add_argument( + "--precision", type=str, default=None, + ) + parser.add_argument( + "--replace_sampler_ddp", type=bool_type, default=True, + ) + parser.add_argument( + "--max_epochs", type=int, default=1, + ) + parser.add_argument( + "--log_every_n_steps", type=int, default=25, + ) + parser.add_argument( + "--num_sanity_val_steps", type=int, default=0, ) - # Remove some buggy/redundant arguments introduced by the Trainer - remove_arguments( - parser, - [ - "--accelerator", - "--resume_from_checkpoint", - "--reload_dataloaders_every_epoch", - "--reload_dataloaders_every_n_epochs", - ] - ) + # parser = pl.Trainer.add_argparse_args(parser) + # + # # Disable the initial validation pass + # parser.set_defaults( + # num_sanity_val_steps=0, + # ) + + # # Remove some buggy/redundant arguments introduced by the Trainer + # remove_arguments( + # parser, + # [ + # "--accelerator", + # "--resume_from_checkpoint", + # "--reload_dataloaders_every_epoch", + # "--reload_dataloaders_every_n_epochs", + # ] + # ) args = parser.parse_args() - if(args.seed is None and - ((args.gpus is not None and args.gpus > 1) or + if (args.seed is None and + ((args.gpus is not None and args.gpus > 1) or (args.num_nodes is not None and args.num_nodes > 1))): raise ValueError("For distributed training, --seed must be specified") - if(str(args.precision) == "16" and args.deepspeed_config_path is not None): + if (str(args.precision) == "16" and args.deepspeed_config_path is not None): raise ValueError("DeepSpeed and FP16 training are not compatible") - if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): - raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") + if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): + raise ValueError( + "Choose between loading pretrained Jax-weights and a checkpoint-path") # This re-applies the training-time filters at the beginning of every epoch args.reload_dataloaders_every_n_epochs = 1 From 5f5a79a7d89b6cfa3c9c6f96f972f20dd7ba700e Mon Sep 17 00:00:00 2001 From: Jennifer Date: Fri, 12 Jan 2024 04:27:30 -0500 Subject: [PATCH 10/34] initial compatibility changes for upgrading multimer --- environment.yml | 19 ++++++++++--------- openfold/data/data_pipeline.py | 6 +++--- openfold/model/primitives.py | 4 ++-- setup.py | 2 +- 4 files changed, 16 insertions(+), 15 deletions(-) diff --git a/environment.yml b/environment.yml index 7b73b25..d6ccb46 100644 --- a/environment.yml +++ b/environment.yml @@ -3,6 +3,7 @@ channels: - conda-forge - bioconda - pytorch + - nvidia dependencies: - python=3.9 - libgcc=7.2 @@ -10,17 +11,16 @@ dependencies: - pip - openmm=7.7 - pdbfixer - - cudatoolkit==11.3.* - - pytorch-lightning==1.5.10 + - pytorch-lightning - biopython==1.79 - - numpy==1.21 - - pandas==2.0 + - numpy + - pandas - PyYAML==5.4.1 - requests - - scipy==1.7 + - scipy - tqdm==4.62.2 - - typing-extensions==3.10 - - wandb==0.12.21 + - typing-extensions + - wandb - modelcif==0.7 - awscli - ml-collections @@ -29,9 +29,10 @@ dependencies: - bioconda::hmmer==3.3.2 - bioconda::hhsuite==3.3.0 - bioconda::kalign2==2.04 - - pytorch::pytorch=1.12.* + - pytorch::pytorch=2.1 + - pytorch::pytorch-cuda=12.1 - pip: - deepspeed==0.12.4 - dm-tree==0.1.6 - git+https://github.com/NVIDIA/dllogger.git - - git+https://github.com/Dao-AILab/flash-attention.git@5b838a8 + - flash-attn diff --git a/openfold/data/data_pipeline.py b/openfold/data/data_pipeline.py index ce8494d..adde0b7 100644 --- a/openfold/data/data_pipeline.py +++ b/openfold/data/data_pipeline.py @@ -244,7 +244,7 @@ def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: features["num_alignments"] = np.array( [num_alignments] * num_res, dtype=np.int32 ) - features["msa_species_identifiers"] = np.array(species_ids, dtype=np.object_) + features["msa_species_identifiers"] = np.array(species_ids, dtype=object) return features @@ -590,7 +590,7 @@ def convert_monomer_features( ) -> FeatureDict: """Reshapes and modifies monomer features for multimer models.""" converted = {} - converted['auth_chain_id'] = np.asarray(chain_id, dtype=np.object_) + converted['auth_chain_id'] = np.asarray(chain_id, dtype=object) unnecessary_leading_dim_feats = { 'sequence', 'domain_name', 'num_alignments', 'seq_length' } @@ -1296,7 +1296,7 @@ class DataPipelineMultimer: ) mmcif_feats["release_date"] = np.array( - [mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ + [mmcif_object.header["release_date"].encode("utf-8")], dtype=object ) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) diff --git a/openfold/model/primitives.py b/openfold/model/primitives.py index ea38cb3..e5735d1 100644 --- a/openfold/model/primitives.py +++ b/openfold/model/primitives.py @@ -28,7 +28,7 @@ if ds4s_is_installed: fa_is_installed = importlib.util.find_spec("flash_attn") is not None if fa_is_installed: from flash_attn.bert_padding import unpad_input - from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func import torch import torch.nn as nn @@ -811,7 +811,7 @@ def _flash_attn(q, k, v, kv_mask): kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask) kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:]) - out = flash_attn_unpadded_kvpacked_func( + out = flash_attn_varlen_kvpacked_func( q, kv_unpad, q_cu_seqlens, diff --git a/setup.py b/setup.py index bec9862..9179856 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ version_dependent_macros = [ ] extra_cuda_flags = [ - '-std=c++14', + '-std=c++17', '-maxrregcount=50', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', From 6dc34d71f52d3ae06a3b16e56f3cc3d90e61f99c Mon Sep 17 00:00:00 2001 From: Jennifer Date: Wed, 24 Jan 2024 01:34:33 -0500 Subject: [PATCH 11/34] first pass changes to run with pl 2.1 --- openfold/data/data_modules.py | 9 +- openfold/utils/seed.py | 2 +- train_openfold.py | 229 +++++++++++++++++++--------------- 3 files changed, 136 insertions(+), 104 deletions(-) diff --git a/openfold/data/data_modules.py b/openfold/data/data_modules.py index 62f48e4..de9f111 100644 --- a/openfold/data/data_modules.py +++ b/openfold/data/data_modules.py @@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule): with open(distillation_alignment_index_path, "r") as fp: self.distillation_alignment_index = json.load(fp) - def setup(self): + def setup(self, stage=None): # Most of the arguments are the same for the three datasets dataset_gen = partial(OpenFoldSingleDataset, template_mmcif_dir=self.template_mmcif_dir, @@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule): mode="predict", ) - def _gen_dataloader(self, stage): + def _gen_dataloader(self, stage=None): generator = None if self.batch_seed is not None: generator = torch.Generator() @@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule): def val_dataloader(self): if self.eval_dataset is not None: return self._gen_dataloader("eval") - return None + # Temp fix to pass the validation step + return [] def predict_dataloader(self): return self._gen_dataloader("predict") @@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule): self.training_mode = self.train_data_dir is not None self.val_mmcif_data_cache_path = val_mmcif_data_cache_path - def setup(self): + def setup(self, setup=None): # Most of the arguments are the same for the three datasets dataset_gen = partial(OpenFoldSingleMultimerDataset, template_mmcif_dir=self.template_mmcif_dir, diff --git a/openfold/utils/seed.py b/openfold/utils/seed.py index b45b813..a305bfb 100644 --- a/openfold/utils/seed.py +++ b/openfold/utils/seed.py @@ -2,7 +2,7 @@ import os import logging import random import numpy as np -from pytorch_lightning.utilities.seed import seed_everything +from pytorch_lightning import seed_everything from openfold.utils.suppress_output import SuppressLogging diff --git a/train_openfold.py b/train_openfold.py index c4ba843..18396d7 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -8,7 +8,7 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger -from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin +from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy import torch from openfold.config import model_config @@ -56,7 +56,7 @@ class OpenFoldWrapper(pl.LightningModule): self.ema = ExponentialMovingAverage( model=self.model, decay=config.ema.decay ) - + self.cached_weights = None self.last_lr_step = -1 self.save_hyperparameters @@ -68,12 +68,12 @@ class OpenFoldWrapper(pl.LightningModule): phase = "train" if train else "val" for loss_name, indiv_loss in loss_breakdown.items(): self.log( - f"{phase}/{loss_name}", - indiv_loss, + f"{phase}/{loss_name}", + indiv_loss, on_step=train, on_epoch=(not train), logger=True, ) - if(train): + if (train): self.log( f"{phase}/{loss_name}_epoch", indiv_loss, @@ -82,12 +82,12 @@ class OpenFoldWrapper(pl.LightningModule): with torch.no_grad(): other_metrics = self._compute_validation_metrics( - batch, + batch, outputs, superimposition_metrics=(not train) ) - for k,v in other_metrics.items(): + for k, v in other_metrics.items(): self.log( f"{phase}/{k}", torch.mean(v), @@ -95,7 +95,7 @@ class OpenFoldWrapper(pl.LightningModule): ) def training_step(self, batch, batch_idx): - if(self.ema.device != batch["aatype"].device): + if (self.ema.device != batch["aatype"].device): self.ema.to(batch["aatype"].device) ground_truth = batch.pop('gt_features', None) @@ -126,12 +126,13 @@ class OpenFoldWrapper(pl.LightningModule): def validation_step(self, batch, batch_idx): # At the start of validation, load the EMA weights - if(self.cached_weights is None): + if (self.cached_weights is None): # model.state_dict() contains references to model weights rather - # than copies. Therefore, we need to clone them before calling + # than copies. Therefore, we need to clone them before calling # load_state_dict(). - clone_param = lambda t: t.detach().clone() - self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) + def clone_param(t): return t.detach().clone() + self.cached_weights = tensor_tree_map( + clone_param, self.model.state_dict()) self.model.load_state_dict(self.ema.state_dict()["params"]) ground_truth = batch.pop('gt_features', None) @@ -153,23 +154,23 @@ class OpenFoldWrapper(pl.LightningModule): ) self._log(loss_breakdown, batch, outputs, train=False) - - def validation_epoch_end(self, _): + + def on_validation_epoch_end(self, _): # Restore the model weights to normal self.model.load_state_dict(self.cached_weights) self.cached_weights = None - def _compute_validation_metrics(self, - batch, - outputs, - superimposition_metrics=False - ): + def _compute_validation_metrics(self, + batch, + outputs, + superimposition_metrics=False + ): metrics = {} - + gt_coords = batch["all_atom_positions"] pred_coords = outputs["final_atom_positions"] all_atom_mask = batch["all_atom_mask"] - + # This is super janky for superimposition. Fix later gt_coords_masked = gt_coords * all_atom_mask[..., None] pred_coords_masked = pred_coords * all_atom_mask[..., None] @@ -177,7 +178,7 @@ class OpenFoldWrapper(pl.LightningModule): gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :] pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :] all_atom_mask_ca = all_atom_mask[..., ca_pos] - + lddt_ca_score = lddt_ca( pred_coords, gt_coords, @@ -185,18 +186,18 @@ class OpenFoldWrapper(pl.LightningModule): eps=self.config.globals.eps, per_residue=False, ) - + metrics["lddt_ca"] = lddt_ca_score - + drmsd_ca_score = drmsd( pred_coords_masked_ca, gt_coords_masked_ca, - mask=all_atom_mask_ca, # still required here to compute n + mask=all_atom_mask_ca, # still required here to compute n ) - + metrics["drmsd_ca"] = drmsd_ca_score - - if(superimposition_metrics): + + if (superimposition_metrics): superimposed_pred, alignment_rmsd = superimpose( gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca, ) @@ -210,22 +211,22 @@ class OpenFoldWrapper(pl.LightningModule): metrics["alignment_rmsd"] = alignment_rmsd metrics["gdt_ts"] = gdt_ts_score metrics["gdt_ha"] = gdt_ha_score - + return metrics - def configure_optimizers(self, - learning_rate: float = 1e-3, - eps: float = 1e-5, - ) -> torch.optim.Adam: -# return torch.optim.Adam( -# self.model.parameters(), -# lr=learning_rate, -# eps=eps -# ) + def configure_optimizers(self, + learning_rate: float = 1e-3, + eps: float = 1e-5, + ) -> torch.optim.Adam: + # return torch.optim.Adam( + # self.model.parameters(), + # lr=learning_rate, + # eps=eps + # ) # Ignored as long as a DeepSpeed optimizer is configured optimizer = torch.optim.Adam( - self.model.parameters(), - lr=learning_rate, + self.model.parameters(), + lr=learning_rate, eps=eps ) @@ -250,8 +251,9 @@ class OpenFoldWrapper(pl.LightningModule): def on_load_checkpoint(self, checkpoint): ema = checkpoint["ema"] - if(not self.model.template_config.enabled): - ema["params"] = {k:v for k,v in ema["params"].items() if not "template" in k} + if (not self.model.template_config.enabled): + ema["params"] = {k: v for k, + v in ema["params"].items() if not "template" in k} self.ema.load_state_dict(ema) def on_save_checkpoint(self, checkpoint): @@ -262,23 +264,23 @@ class OpenFoldWrapper(pl.LightningModule): def load_from_jax(self, jax_path): model_basename = os.path.splitext( - os.path.basename( - os.path.normpath(jax_path) - ) + os.path.basename( + os.path.normpath(jax_path) + ) )[0] model_version = "_".join(model_basename.split("_")[1:]) import_jax_weights_( - self.model, jax_path, version=model_version + self.model, jax_path, version=model_version ) def main(args): - if(args.seed is not None): - seed_everything(args.seed) + if (args.seed is not None): + seed_everything(args.seed) config = model_config( - args.config_preset, - train=True, + args.config_preset, + train=True, low_prec=(str(args.precision) == "16") ) if args.experiment_config_json: @@ -321,30 +323,31 @@ def main(args): if args.resume_from_jax_params: model_module.load_from_jax(args.resume_from_jax_params) - logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") - + logging.info( + f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...") + # TorchScript components of the model - if(args.script_modules): + if (args.script_modules): script_preset_(model_module) if "multimer" in args.config_preset: data_module = OpenFoldMultimerDataModule( - config=config.data, - batch_seed=args.seed, - **vars(args) - ) + config=config.data, + batch_seed=args.seed, + **vars(args) + ) else: data_module = OpenFoldDataModule( - config=config.data, + config=config.data, batch_seed=args.seed, **vars(args) ) data_module.prepare_data() data_module.setup() - + callbacks = [] - if(args.checkpoint_every_epoch): + if (args.checkpoint_every_epoch): mc = ModelCheckpoint( every_n_epochs=1, auto_insert_metric_name=False, @@ -352,7 +355,7 @@ def main(args): ) callbacks.append(mc) - if(args.early_stopping): + if (args.early_stopping): es = EarlyStoppingVerbose( monitor="val/lddt_ca", min_delta=args.min_delta, @@ -364,7 +367,7 @@ def main(args): ) callbacks.append(es) - if(args.log_performance): + if (args.log_performance): global_batch_size = args.num_nodes * args.gpus perf = PerformanceLoggingCallback( log_file=os.path.join(args.output_dir, "performance_log.json"), @@ -372,12 +375,12 @@ def main(args): ) callbacks.append(perf) - if(args.log_lr): + if (args.log_lr): lr_monitor = LearningRateMonitor(logging_interval="step") callbacks.append(lr_monitor) loggers = [] - if(args.wandb): + if (args.wandb): wdb_logger = WandbLogger( name=args.experiment_name, save_dir=args.output_dir, @@ -388,38 +391,43 @@ def main(args): ) loggers.append(wdb_logger) - if(args.deepspeed_config_path is not None): - strategy = DeepSpeedPlugin( + if (args.deepspeed_config_path is not None): + strategy = DeepSpeedStrategy( config=args.deepspeed_config_path, ) - if(args.wandb): + if (args.wandb): wdb_logger.experiment.save(args.deepspeed_config_path) wdb_logger.experiment.save("openfold/config.py") elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1: - strategy = DDPPlugin(find_unused_parameters=False) + strategy = DDPStrategy(find_unused_parameters=False) else: strategy = None - - if(args.wandb): + + if (args.wandb): freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt" os.system(f"{sys.executable} -m pip freeze > {freeze_path}") wdb_logger.experiment.save(f"{freeze_path}") - trainer = pl.Trainer.from_argparse_args( - args, - default_root_dir=args.output_dir, - strategy=strategy, - callbacks=callbacks, - logger=loggers, - ) + # Raw dump of all args from pl.Trainer constructor + trainer_kws = set([ + 'accelerator', 'strategy', 'devices', 'num_nodes', 'precision', 'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs', 'max_steps', 'min_steps', 'max_tim', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches', 'limit_predict_batches', 'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch', 'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing', 'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches', 'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic', 'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler', 'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm', 'reload_dataloaders_every_n_epochs', 'default_root_dir', + ]) + trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws} + trainer_args.update({ + 'default_root_dir': args.output_dir, + 'strategy': strategy, + 'callbacks': callbacks, + 'logger': loggers, + }) + trainer = pl.Trainer(**trainer_args) - if(args.resume_model_weights_only): + if (args.resume_model_weights_only): ckpt_path = None else: ckpt_path = args.resume_from_ckpt trainer.fit( - model_module, + model_module, datamodule=data_module, ckpt_path=ckpt_path, ) @@ -621,36 +629,59 @@ if __name__ == "__main__": parser.add_argument( "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting", ) - parser = pl.Trainer.add_argparse_args(parser) - - # Disable the initial validation pass - parser.set_defaults( - num_sanity_val_steps=0, + parser.add_argument( + "--num_nodes", type=int, default=1, + ) + parser.add_argument( + "--gpus", type=int, default=1, + ) + parser.add_argument( + "--precision", type=str, default=None, + ) + parser.add_argument( + "--replace_sampler_ddp", type=bool_type, default=True, + ) + parser.add_argument( + "--max_epochs", type=int, default=1, + ) + parser.add_argument( + "--log_every_n_steps", type=int, default=25, + ) + parser.add_argument( + "--num_sanity_val_steps", type=int, default=0, ) - # Remove some buggy/redundant arguments introduced by the Trainer - remove_arguments( - parser, - [ - "--accelerator", - "--resume_from_checkpoint", - "--reload_dataloaders_every_epoch", - "--reload_dataloaders_every_n_epochs", - ] - ) + # parser = pl.Trainer.add_argparse_args(parser) + # + # # Disable the initial validation pass + # parser.set_defaults( + # num_sanity_val_steps=0, + # ) + + # # Remove some buggy/redundant arguments introduced by the Trainer + # remove_arguments( + # parser, + # [ + # "--accelerator", + # "--resume_from_checkpoint", + # "--reload_dataloaders_every_epoch", + # "--reload_dataloaders_every_n_epochs", + # ] + # ) args = parser.parse_args() - if(args.seed is None and - ((args.gpus is not None and args.gpus > 1) or + if (args.seed is None and + ((args.gpus is not None and args.gpus > 1) or (args.num_nodes is not None and args.num_nodes > 1))): raise ValueError("For distributed training, --seed must be specified") - if(str(args.precision) == "16" and args.deepspeed_config_path is not None): + if (str(args.precision) == "16" and args.deepspeed_config_path is not None): raise ValueError("DeepSpeed and FP16 training are not compatible") - if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): - raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path") + if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None): + raise ValueError( + "Choose between loading pretrained Jax-weights and a checkpoint-path") # This re-applies the training-time filters at the beginning of every epoch args.reload_dataloaders_every_n_epochs = 1 From a317ad270282b9cd6aa429e15c22d55139af5b27 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Thu, 21 Mar 2024 02:50:13 -0400 Subject: [PATCH 12/34] superimposition fix from Aymen --- openfold/utils/superimposition.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/openfold/utils/superimposition.py b/openfold/utils/superimposition.py index ad6b15d..d1dca27 100644 --- a/openfold/utils/superimposition.py +++ b/openfold/utils/superimposition.py @@ -35,10 +35,10 @@ def _superimpose_np(reference, coords): def _superimpose_single(reference, coords): - reference_np = reference.detach().cpu().numpy() - coords_np = coords.detach().cpu().numpy() - superimposed, rmsd = _superimpose_np(reference_np, coords_np) - return coords.new_tensor(superimposed), coords.new_tensor(rmsd) + reference_np = reference.detach().to(torch.float).cpu().numpy() + coords_np = coords.detach().to(torch.float).cpu().numpy() + superimposed, rmsd = _superimpose_np(reference_np, coords_np) + return coords.new_tensor(superimposed), coords.new_tensor(rmsd) def superimpose(reference, coords, mask): From cfd2e71981cce833d94bf401ab970527e5b61c13 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Mon, 25 Mar 2024 04:39:39 -0400 Subject: [PATCH 13/34] seed workers fix and validation_epoch_end extra argument --- train_openfold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_openfold.py b/train_openfold.py index 18396d7..28fd942 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -9,6 +9,7 @@ from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy +from pytorch_lightning import seed_everything import torch from openfold.config import model_config @@ -24,7 +25,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage from openfold.utils.loss import AlphaFoldLoss, lddt_ca from openfold.utils.lr_schedulers import AlphaFoldLRScheduler from openfold.utils.multi_chain_permutation import multi_chain_permutation_align -from openfold.utils.seed import seed_everything from openfold.utils.superimposition import superimpose from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.validation_metrics import ( @@ -155,7 +155,7 @@ class OpenFoldWrapper(pl.LightningModule): self._log(loss_breakdown, batch, outputs, train=False) - def on_validation_epoch_end(self, _): + def on_validation_epoch_end(self): # Restore the model weights to normal self.model.load_state_dict(self.cached_weights) self.cached_weights = None @@ -276,7 +276,7 @@ class OpenFoldWrapper(pl.LightningModule): def main(args): if (args.seed is not None): - seed_everything(args.seed) + seed_everything(args.seed, workers=True) config = model_config( args.config_preset, From 0c3435cc758861d89b11da8d864243e170b053cc Mon Sep 17 00:00:00 2001 From: Jennifer Date: Mon, 25 Mar 2024 06:26:00 -0400 Subject: [PATCH 14/34] add metric logging to progress bar. --- train_openfold.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/train_openfold.py b/train_openfold.py index 28fd942..cccd418 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -17,7 +17,6 @@ from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataM from openfold.model.model import AlphaFold from openfold.model.torchscript import script_preset_ from openfold.np import residue_constants -from openfold.utils.argparse_utils import remove_arguments from openfold.utils.callbacks import ( EarlyStoppingVerbose, ) @@ -70,13 +69,14 @@ class OpenFoldWrapper(pl.LightningModule): self.log( f"{phase}/{loss_name}", indiv_loss, + prog_bar=(loss_name == 'loss'), on_step=train, on_epoch=(not train), logger=True, ) if (train): self.log( f"{phase}/{loss_name}_epoch", - indiv_loss, + indiv_loss, on_step=False, on_epoch=True, logger=True, ) @@ -91,7 +91,8 @@ class OpenFoldWrapper(pl.LightningModule): self.log( f"{phase}/{k}", torch.mean(v), - on_step=False, on_epoch=True, logger=True + prog_bar = (k == 'loss'), + on_step=False, on_epoch=True, logger=True, ) def training_step(self, batch, batch_idx): @@ -629,12 +630,17 @@ if __name__ == "__main__": parser.add_argument( "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting", ) + # Trainer additional arguments + # Ideally we'd want something like config.add_trainer_args() parser.add_argument( "--num_nodes", type=int, default=1, ) parser.add_argument( "--gpus", type=int, default=1, ) + parser.add_argument( + "--num_workers", type=int, default=4, # interaction with num_data_workers? + ) parser.add_argument( "--precision", type=str, default=None, ) @@ -647,6 +653,9 @@ if __name__ == "__main__": parser.add_argument( "--log_every_n_steps", type=int, default=25, ) + parser.add_argument( + "--flush_logs_every_n_steps", type=int, default=5, + ) parser.add_argument( "--num_sanity_val_steps", type=int, default=0, ) From 5ff5177bc63acc3c367ac0b22552fc0f0339e3cb Mon Sep 17 00:00:00 2001 From: Jennifer Date: Wed, 27 Mar 2024 04:03:20 -0400 Subject: [PATCH 15/34] more logging changes --- train_openfold.py | 50 ++++++++++++----------------------------------- 1 file changed, 12 insertions(+), 38 deletions(-) diff --git a/train_openfold.py b/train_openfold.py index cccd418..b989f77 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -410,9 +410,7 @@ def main(args): wdb_logger.experiment.save(f"{freeze_path}") # Raw dump of all args from pl.Trainer constructor - trainer_kws = set([ - 'accelerator', 'strategy', 'devices', 'num_nodes', 'precision', 'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs', 'max_steps', 'min_steps', 'max_tim', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches', 'limit_predict_batches', 'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch', 'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing', 'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches', 'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic', 'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler', 'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm', 'reload_dataloaders_every_n_epochs', 'default_root_dir', - ]) + trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps', 'flush_logs_ever_n_steps', 'num_sanity_val_steps'] trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws} trainer_args.update({ 'default_root_dir': args.output_dir, @@ -630,54 +628,30 @@ if __name__ == "__main__": parser.add_argument( "--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting", ) - # Trainer additional arguments - # Ideally we'd want something like config.add_trainer_args() parser.add_argument( + "--gpus", type=int, default=1, help='For determining optimal strategy and effective batch size.' + ) + + trainer_group = parser.add_argument_group('PyTorch Lightning Trainer Args') + trainer_group.add_argument( "--num_nodes", type=int, default=1, ) - parser.add_argument( - "--gpus", type=int, default=1, + trainer_group.add_argument( + "--precision", type=str, default='bf16', help='Sets precision, lower precision improves runtime performance.' ) - parser.add_argument( - "--num_workers", type=int, default=4, # interaction with num_data_workers? - ) - parser.add_argument( - "--precision", type=str, default=None, - ) - parser.add_argument( - "--replace_sampler_ddp", type=bool_type, default=True, - ) - parser.add_argument( + trainer_group.add_argument( "--max_epochs", type=int, default=1, ) - parser.add_argument( + trainer_group.add_argument( "--log_every_n_steps", type=int, default=25, ) - parser.add_argument( + trainer_group.add_argument( "--flush_logs_every_n_steps", type=int, default=5, ) - parser.add_argument( + trainer_group.add_argument( "--num_sanity_val_steps", type=int, default=0, ) - # parser = pl.Trainer.add_argparse_args(parser) - # - # # Disable the initial validation pass - # parser.set_defaults( - # num_sanity_val_steps=0, - # ) - - # # Remove some buggy/redundant arguments introduced by the Trainer - # remove_arguments( - # parser, - # [ - # "--accelerator", - # "--resume_from_checkpoint", - # "--reload_dataloaders_every_epoch", - # "--reload_dataloaders_every_n_epochs", - # ] - # ) - args = parser.parse_args() if (args.seed is None and From 862635834a25ee926f31de74707696860a5455a9 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Tue, 2 Apr 2024 02:49:16 -0400 Subject: [PATCH 16/34] add paren to save_hyperparameters --- train_openfold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_openfold.py b/train_openfold.py index b989f77..9cce305 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -58,7 +58,7 @@ class OpenFoldWrapper(pl.LightningModule): self.cached_weights = None self.last_lr_step = -1 - self.save_hyperparameters + self.save_hyperparameters() def forward(self, batch): return self.model(batch) From 577219c11254cb6e6fb14fb53d8eef2bc67707da Mon Sep 17 00:00:00 2001 From: Jennifer Date: Tue, 2 Apr 2024 05:01:33 -0400 Subject: [PATCH 17/34] Removes OF copy of zero_to_fp32.py favoring deepspeed.util version --- scripts/convert_v1_to_v2_weights.py | 4 +- scripts/zero_to_fp32.py | 598 ---------------------------- train_openfold.py | 23 +- 3 files changed, 18 insertions(+), 607 deletions(-) delete mode 100755 scripts/zero_to_fp32.py diff --git a/scripts/convert_v1_to_v2_weights.py b/scripts/convert_v1_to_v2_weights.py index da693b2..e5834f4 100755 --- a/scripts/convert_v1_to_v2_weights.py +++ b/scripts/convert_v1_to_v2_weights.py @@ -22,7 +22,9 @@ import shutil import torch from openfold.utils.import_weights import convert_deprecated_v1_keys -from zero_to_fp32 import get_optim_files, parse_optim_states, get_model_state_file +from deepspeed.utils.zero_to_fp32 import ( + get_optim_files, parse_optim_states, get_model_state_file +) def convert_v1_to_v2_weights(args): diff --git a/scripts/zero_to_fp32.py b/scripts/zero_to_fp32.py deleted file mode 100755 index 48e3b95..0000000 --- a/scripts/zero_to_fp32.py +++ /dev/null @@ -1,598 +0,0 @@ -#!/usr/bin/env python - -# Copyright (c) Microsoft Corporation. -# SPDX-License-Identifier: Apache-2.0 - -# DeepSpeed Team - -# This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets -# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in -# the future. Once extracted, the weights don't require DeepSpeed and can be used in any -# application. -# -# example: python zero_to_fp32.py . pytorch_model.bin - -import argparse -import torch -import glob -import math -import os -import re -from collections import OrderedDict -from dataclasses import dataclass - -# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with -# DeepSpeed data structures it has to be available in the current python environment. -from deepspeed.utils import logger -from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, - FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, - FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) - - -@dataclass -class zero_model_state: - buffers: dict() - param_shapes: dict() - shared_params: list - ds_version: int - frozen_param_shapes: dict() - frozen_param_fragments: dict() - - -debug = 0 - -# load to cpu -device = torch.device('cpu') - - -def atoi(text): - return int(text) if text.isdigit() else text - - -def natural_keys(text): - ''' - alist.sort(key=natural_keys) sorts in human order - http://nedbatchelder.com/blog/200712/human_sorting.html - (See Toothy's implementation in the comments) - ''' - return [atoi(c) for c in re.split(r'(\d+)', text)] - - -def get_model_state_file(checkpoint_dir, zero_stage): - if not os.path.isdir(checkpoint_dir): - raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") - - # there should be only one file - if zero_stage <= 2: - file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") - elif zero_stage == 3: - file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt") - - if not os.path.exists(file): - raise FileNotFoundError(f"can't find model states file at '{file}'") - - return file - - -def get_checkpoint_files(checkpoint_dir, glob_pattern): - # XXX: need to test that this simple glob rule works for multi-node setup too - ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) - - if len(ckpt_files) == 0: - raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") - - return ckpt_files - - -def get_optim_files(checkpoint_dir): - return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") - - -def get_model_state_files(checkpoint_dir): - return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") - - -def parse_model_states(files): - zero_model_states = [] - for file in files: - state_dict = torch.load(file, map_location=device) - - if BUFFER_NAMES not in state_dict: - raise ValueError(f"{file} is not a model state checkpoint") - buffer_names = state_dict[BUFFER_NAMES] - if debug: - print("Found buffers:", buffer_names) - - # recover just the buffers while restoring them to fp32 if they were saved in fp16 - buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} - param_shapes = state_dict[PARAM_SHAPES] - - # collect parameters that are included in param_shapes - param_names = [] - for s in param_shapes: - for name in s.keys(): - param_names.append(name) - - # update with frozen parameters - frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) - if frozen_param_shapes is not None: - if debug: - print(f"Found frozen_param_shapes: {frozen_param_shapes}") - param_names += list(frozen_param_shapes.keys()) - - # handle shared params - shared_params = [[k, v] for k, v in state_dict["shared_params"].items()] - - ds_version = state_dict.get(DS_VERSION, None) - - frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) - - z_model_state = zero_model_state(buffers=buffers, - param_shapes=param_shapes, - shared_params=shared_params, - ds_version=ds_version, - frozen_param_shapes=frozen_param_shapes, - frozen_param_fragments=frozen_param_fragments) - zero_model_states.append(z_model_state) - - return zero_model_states - - -def parse_optim_states(files, ds_checkpoint_dir): - - total_files = len(files) - state_dicts = [] - for f in files: - state_dict = torch.load(f, map_location=device) - # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights - # and also handle the case where it was already removed by another helper script - state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None) - state_dicts.append(state_dict) - - if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: - raise ValueError(f"{files[0]} is not a zero checkpoint") - zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] - world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] - - # For ZeRO-2 each param group can have different partition_count as data parallelism for expert - # parameters can be different from data parallelism for non-expert parameters. So we can just - # use the max of the partition_count to get the dp world_size. - - if type(world_size) is list: - world_size = max(world_size) - - if world_size != total_files: - raise ValueError( - f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " - "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." - ) - - # the groups are named differently in each stage - if zero_stage <= 2: - fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS - elif zero_stage == 3: - fp32_groups_key = FP32_FLAT_GROUPS - else: - raise ValueError(f"unknown zero stage {zero_stage}") - - if zero_stage <= 2: - fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] - elif zero_stage == 3: - # if there is more than one param group, there will be multiple flattened tensors - one - # flattened tensor per group - for simplicity merge them into a single tensor - # - # XXX: could make the script more memory efficient for when there are multiple groups - it - # will require matching the sub-lists of param_shapes for each param group flattened tensor - - fp32_flat_groups = [ - torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) - ] - - return zero_stage, world_size, fp32_flat_groups - - -def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): - """ - Returns fp32 state_dict reconstructed from ds checkpoint - - Args: - - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) - - """ - print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") - - optim_files = get_optim_files(ds_checkpoint_dir) - zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) - print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") - - model_files = get_model_state_files(ds_checkpoint_dir) - - zero_model_states = parse_model_states(model_files) - print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') - - if zero_stage <= 2: - return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states) - elif zero_stage == 3: - return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states) - - -def _zero2_merge_frozen_params(state_dict, zero_model_states): - if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: - return - - frozen_param_shapes = zero_model_states[0].frozen_param_shapes - frozen_param_fragments = zero_model_states[0].frozen_param_fragments - - if debug: - num_elem = sum(s.numel() for s in frozen_param_shapes.values()) - print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') - - wanted_params = len(frozen_param_shapes) - wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) - avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) - print(f'Frozen params: Have {avail_numel} numels to process.') - print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') - - total_params = 0 - total_numel = 0 - for name, shape in frozen_param_shapes.items(): - total_params += 1 - unpartitioned_numel = shape.numel() - total_numel += unpartitioned_numel - - state_dict[name] = frozen_param_fragments[name] - - if debug: - print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") - - print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") - - -def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): - param_shapes = zero_model_states[0].param_shapes - - # Reconstruction protocol: - # - # XXX: document this - - if debug: - for i in range(world_size): - for j in range(len(fp32_flat_groups[0])): - print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") - - # XXX: memory usage doubles here (zero2) - num_param_groups = len(fp32_flat_groups[0]) - merged_single_partition_of_fp32_groups = [] - for i in range(num_param_groups): - merged_partitions = [sd[i] for sd in fp32_flat_groups] - full_single_fp32_vector = torch.cat(merged_partitions, 0) - merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) - avail_numel = sum( - [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) - - if debug: - wanted_params = sum([len(shapes) for shapes in param_shapes]) - wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) - # not asserting if there is a mismatch due to possible padding - print(f"Have {avail_numel} numels to process.") - print(f"Need {wanted_numel} numels in {wanted_params} params.") - - # params - # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support - # out-of-core computing solution - total_numel = 0 - total_params = 0 - for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): - offset = 0 - avail_numel = full_single_fp32_vector.numel() - for name, shape in shapes.items(): - - unpartitioned_numel = shape.numel() - total_numel += unpartitioned_numel - total_params += 1 - - if debug: - print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") - state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) - offset += unpartitioned_numel - - # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and - # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex - # paddings performed in the code it's almost impossible to predict the exact numbers w/o the - # live optimizer object, so we are checking that the numbers are within the right range - align_to = 2 * world_size - - def zero2_align(x): - return align_to * math.ceil(x / align_to) - - if debug: - print(f"original offset={offset}, avail_numel={avail_numel}") - - offset = zero2_align(offset) - avail_numel = zero2_align(avail_numel) - - if debug: - print(f"aligned offset={offset}, avail_numel={avail_numel}") - - # Sanity check - if offset != avail_numel: - raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") - - print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") - - -def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states): - state_dict = OrderedDict() - - # buffers - buffers = zero_model_states[0].buffers - state_dict.update(buffers) - if debug: - print(f"added {len(buffers)} buffers") - - _zero2_merge_frozen_params(state_dict, zero_model_states) - - _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) - - # recover shared parameters - for pair in zero_model_states[0].shared_params: - if pair[1] in state_dict: - state_dict[pair[0]] = state_dict[pair[1]] - - return state_dict - - -def zero3_partitioned_param_info(unpartitioned_numel, world_size): - remainder = unpartitioned_numel % world_size - padding_numel = (world_size - remainder) if remainder else 0 - partitioned_numel = math.ceil(unpartitioned_numel / world_size) - return partitioned_numel, padding_numel - - -def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): - if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: - return - - if debug: - for i in range(world_size): - num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) - print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') - - frozen_param_shapes = zero_model_states[0].frozen_param_shapes - wanted_params = len(frozen_param_shapes) - wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) - avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size - print(f'Frozen params: Have {avail_numel} numels to process.') - print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') - - total_params = 0 - total_numel = 0 - for name, shape in zero_model_states[0].frozen_param_shapes.items(): - total_params += 1 - unpartitioned_numel = shape.numel() - total_numel += unpartitioned_numel - - param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) - state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) - - partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) - - if debug: - print( - f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" - ) - - print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") - - -def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): - param_shapes = zero_model_states[0].param_shapes - avail_numel = fp32_flat_groups[0].numel() * world_size - # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each - # param, re-consolidating each param, while dealing with padding if any - - # merge list of dicts, preserving order - param_shapes = {k: v for d in param_shapes for k, v in d.items()} - - if debug: - for i in range(world_size): - print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") - - wanted_params = len(param_shapes) - wanted_numel = sum(shape.numel() for shape in param_shapes.values()) - # not asserting if there is a mismatch due to possible padding - avail_numel = fp32_flat_groups[0].numel() * world_size - print(f"Trainable params: Have {avail_numel} numels to process.") - print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") - - # params - # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support - # out-of-core computing solution - offset = 0 - total_numel = 0 - total_params = 0 - for name, shape in param_shapes.items(): - - unpartitioned_numel = shape.numel() - total_numel += unpartitioned_numel - total_params += 1 - - partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) - - if debug: - print( - f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" - ) - - # XXX: memory usage doubles here - state_dict[name] = torch.cat( - tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), - 0).narrow(0, 0, unpartitioned_numel).view(shape) - offset += partitioned_numel - - offset *= world_size - - # Sanity check - if offset != avail_numel: - raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") - - print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") - - -def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states): - state_dict = OrderedDict() - - # buffers - buffers = zero_model_states[0].buffers - state_dict.update(buffers) - if debug: - print(f"added {len(buffers)} buffers") - - _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) - - _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) - - # recover shared parameters - for pair in zero_model_states[0].shared_params: - if pair[1] in state_dict: - state_dict[pair[0]] = state_dict[pair[1]] - - return state_dict - - -def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): - """ - Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with - ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example - via a model hub. - - Args: - - ``checkpoint_dir``: path to the desired checkpoint folder - - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` - - Returns: - - pytorch ``state_dict`` - - Note: this approach may not work if your application doesn't have sufficient free CPU memory and - you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with - the checkpoint. - - A typical usage might be :: - - from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint - # do the training and checkpoint saving - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu - model = model.cpu() # move to cpu - model.load_state_dict(state_dict) - # submit to model hub or save the model to share with others - - In this example the ``model`` will no longer be usable in the deepspeed context of the same - application. i.e. you will need to re-initialize the deepspeed engine, since - ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. - - If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. - - """ - if tag is None: - latest_path = os.path.join(checkpoint_dir, 'latest') - if os.path.isfile(latest_path): - with open(latest_path, 'r') as fd: - tag = fd.read().strip() - else: - raise ValueError(f"Unable to find 'latest' file at {latest_path}") - - ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) - - if not os.path.isdir(ds_checkpoint_dir): - raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") - - return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) - - -def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): - """ - Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be - loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. - - Args: - - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) - - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` - """ - - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) - print(f"Saving fp32 state dict to {output_file}") - torch.save(state_dict, output_file) - - -def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): - """ - 1. Put the provided model to cpu - 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` - 3. Load it into the provided model - - Args: - - ``model``: the model object to update - - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) - - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` - - Returns: - - ``model`: modified model - - Make sure you have plenty of CPU memory available before you call this function. If you don't - have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it - conveniently placed for you in the checkpoint folder. - - A typical usage might be :: - - from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint - model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) - # submit to model hub or save the model to share with others - - Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context - of the same application. i.e. you will need to re-initialize the deepspeed engine, since - ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. - - """ - logger.info(f"Extracting fp32 weights") - state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) - - logger.info(f"Overwriting model with fp32 weights") - model = model.cpu() - model.load_state_dict(state_dict, strict=False) - - return model - -def get_global_step_from_zero_checkpoint(checkpoint_dir): - global_step = -1 - latest_path = os.path.join(checkpoint_dir, 'latest') - if os.path.isfile(latest_path): - with open(latest_path, 'r') as fd: - tag = fd.read().strip() - match = re.match(r"global_step([0-9]+)", tag) - global_step = int(match.group(1)) - else: - raise ValueError(f"Unable to find 'latest' file at {latest_path}") - return global_step - -if __name__ == "__main__": - - parser = argparse.ArgumentParser() - parser.add_argument("checkpoint_dir", - type=str, - help="path to the desired checkpoint folder, e.g., path/checkpoint-12") - parser.add_argument( - "output_file", - type=str, - help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)") - parser.add_argument("-t", - "--tag", - type=str, - default=None, - help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1") - parser.add_argument("-d", "--debug", action='store_true', help="enable debug") - args = parser.parse_args() - - debug = args.debug - - convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file, tag=args.tag) diff --git a/train_openfold.py b/train_openfold.py index 9cce305..8c9c861 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -11,6 +11,7 @@ from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy from pytorch_lightning import seed_everything import torch +from deepspeed.utils import zero_to_fp32 from openfold.config import model_config from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule @@ -35,11 +36,6 @@ from openfold.utils.import_weights import ( import_jax_weights_, import_openfold_weights_ ) -from scripts.zero_to_fp32 import ( - get_fp32_state_dict_from_zero_checkpoint, - get_global_step_from_zero_checkpoint -) - from openfold.utils.logger import PerformanceLoggingCallback @@ -274,6 +270,18 @@ class OpenFoldWrapper(pl.LightningModule): self.model, jax_path, version=model_version ) +def get_model_state_dict_from_ds_checkpoint(checkpoint_dir): + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + _DS_CHECKPOINT_VERSION = 2 # based on manual parsing of checkpoint files + state_file = zero_to_fp32.get_model_state_file(ds_checkpoint_dir, _DS_CHECKPOINT_VERSION) + return torch.load(state_file) def main(args): if (args.seed is not None): @@ -314,11 +322,10 @@ def main(args): else: # Loads a checkpoint to start from a specific time step if os.path.isdir(args.resume_from_ckpt): - last_global_step = get_global_step_from_zero_checkpoint( - args.resume_from_ckpt) + sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt) else: sd = torch.load(args.resume_from_ckpt) - last_global_step = int(sd['global_step']) + last_global_step = int(sd['global_step']) model_module.resume_last_lr_step(last_global_step) logging.info("Successfully loaded last lr step...") From 523adaf448b316ec4221220a402b4ccbfbba3935 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Thu, 11 Apr 2024 02:57:49 -0400 Subject: [PATCH 18/34] adds reload_dataloaders_every_n_epochs flag --- train_openfold.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/train_openfold.py b/train_openfold.py index 8c9c861..d10f79b 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -416,8 +416,7 @@ def main(args): os.system(f"{sys.executable} -m pip freeze > {freeze_path}") wdb_logger.experiment.save(f"{freeze_path}") - # Raw dump of all args from pl.Trainer constructor - trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps', 'flush_logs_ever_n_steps', 'num_sanity_val_steps'] + trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps', 'flush_logs_ever_n_steps', 'num_sanity_val_steps', 'reload_dataloaders_every_n_epochs'] trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws} trainer_args.update({ 'default_root_dir': args.output_dir, @@ -658,6 +657,9 @@ if __name__ == "__main__": trainer_group.add_argument( "--num_sanity_val_steps", type=int, default=0, ) + trainer_group.add_argument( + "--reload_dataloaders_every_n_epochs", type=int, default=1, + ) args = parser.parse_args() @@ -673,7 +675,4 @@ if __name__ == "__main__": raise ValueError( "Choose between loading pretrained Jax-weights and a checkpoint-path") - # This re-applies the training-time filters at the beginning of every epoch - args.reload_dataloaders_every_n_epochs = 1 - main(args) From 80e6341022ef07849135a67a950fc618b8620770 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Wed, 17 Apr 2024 06:47:13 -0400 Subject: [PATCH 19/34] change message for test_model.py compare --- tests/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_model.py b/tests/test_model.py index 3d19f14..ecf5af1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -202,4 +202,4 @@ class TestModel(unittest.TestCase): out_repro = out_repro["sm"]["positions"][-1] out_repro = out_repro.squeeze(0) - self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3) + compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 1e-3) From 1ae833bfc8b6d22c256f2328f64c74cb5b6c215f Mon Sep 17 00:00:00 2001 From: Jennifer Date: Fri, 19 Apr 2024 06:12:32 -0400 Subject: [PATCH 20/34] Updates low_precision check to use current precision settings. --- train_openfold.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/train_openfold.py b/train_openfold.py index d10f79b..bfd3b18 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -287,10 +287,13 @@ def main(args): if (args.seed is not None): seed_everything(args.seed, workers=True) + is_low_precision = args.precision in [ + "bf16-mixed", "16", "bf16", "16-true", "16-mixed", "bf16-mixed"] + config = model_config( args.config_preset, train=True, - low_prec=(str(args.precision) == "16") + low_prec=is_low_precision, ) if args.experiment_config_json: with open(args.experiment_config_json, 'r') as f: @@ -643,7 +646,8 @@ if __name__ == "__main__": "--num_nodes", type=int, default=1, ) trainer_group.add_argument( - "--precision", type=str, default='bf16', help='Sets precision, lower precision improves runtime performance.' + "--precision", type=str, default='bf16', + help='Sets precision, lower precision improves runtime performance.', ) trainer_group.add_argument( "--max_epochs", type=int, default=1, From ea142a0a6831e140cd742f31f21b6a87479ca818 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Fri, 19 Apr 2024 06:40:37 -0400 Subject: [PATCH 21/34] fixes deepspeed function definition. --- train_openfold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_openfold.py b/train_openfold.py index bfd3b18..839154c 100644 --- a/train_openfold.py +++ b/train_openfold.py @@ -306,7 +306,7 @@ def main(args): if args.resume_model_weights_only: # Load the checkpoint if os.path.isdir(args.resume_from_ckpt): - sd = get_fp32_state_dict_from_zero_checkpoint( + sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint( args.resume_from_ckpt) else: sd = torch.load(args.resume_from_ckpt) From 5ccb7de370275fb426f09ffb24ecf74425fb1457 Mon Sep 17 00:00:00 2001 From: jnwei Date: Fri, 19 Apr 2024 17:51:10 +0700 Subject: [PATCH 22/34] updates Dockerfile --- Dockerfile | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8353003..192a00a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,17 +1,14 @@ -FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04 +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 # metainformation -LABEL org.opencontainers.image.version = "1.0.0" -LABEL org.opencontainers.image.authors = "Gustaf Ahdritz" +LABEL org.opencontainers.image.version = "2.0.0" +LABEL org.opencontainers.image.authors = "OpenFold Team" LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold" LABEL org.opencontainers.image.licenses = "Apache License 2.0" -LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04" +LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04" -RUN apt-key del 7fa2af80 -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub -RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub +RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-12-4 libcusparse-dev-12-4 libcublas-dev-12-4 libcusolver-dev-12-4 git -RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git RUN wget -P /tmp \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ From 3cab807e060bf9cda1106b26e9717ee3f211794e Mon Sep 17 00:00:00 2001 From: jnwei Date: Fri, 19 Apr 2024 18:13:33 +0700 Subject: [PATCH 23/34] fix mkl version to 2024.0.0 --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index d6ccb46..f15d63b 100644 --- a/environment.yml +++ b/environment.yml @@ -24,6 +24,7 @@ dependencies: - modelcif==0.7 - awscli - ml-collections + - mkl=2024.0.0 - aria2 - git - bioconda::hmmer==3.3.2 From 866477a26815800318a96540953c473698365339 Mon Sep 17 00:00:00 2001 From: jnwei Date: Sat, 20 Apr 2024 14:09:58 +0700 Subject: [PATCH 24/34] Update gpg keys for Docker build --- Dockerfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Dockerfile b/Dockerfile index 192a00a..f7a0938 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,9 @@ LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfol LABEL org.opencontainers.image.licenses = "Apache License 2.0" LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04" +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb +RUN sudo dpkg -i cuda-keyring_1.0-1_all.deb + RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-12-4 libcusparse-dev-12-4 libcublas-dev-12-4 libcusolver-dev-12-4 git RUN wget -P /tmp \ From addb80a849364891345fafd7cec8ec0b97dfb7e3 Mon Sep 17 00:00:00 2001 From: jnwei Date: Sat, 20 Apr 2024 16:36:54 +0700 Subject: [PATCH 25/34] changes to Dockerfile ane pin mkl to 2024 --- Dockerfile | 9 +++++---- environment.yml | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/Dockerfile b/Dockerfile index f7a0938..c048879 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 +FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 # metainformation LABEL org.opencontainers.image.version = "2.0.0" @@ -7,11 +7,12 @@ LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfol LABEL org.opencontainers.image.licenses = "Apache License 2.0" LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04" -RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb -RUN sudo dpkg -i cuda-keyring_1.0-1_all.deb - RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-12-4 libcusparse-dev-12-4 libcublas-dev-12-4 libcusolver-dev-12-4 git +RUN apt-key del 7fa2af80 +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb +RUN dpkg -i cuda-keyring_1.0-1_all.deb + RUN wget -P /tmp \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ diff --git a/environment.yml b/environment.yml index f15d63b..69c8326 100644 --- a/environment.yml +++ b/environment.yml @@ -5,7 +5,7 @@ channels: - pytorch - nvidia dependencies: - - python=3.9 + - python=3.10 - libgcc=7.2 - setuptools=59.5.0 - pip @@ -24,7 +24,7 @@ dependencies: - modelcif==0.7 - awscli - ml-collections - - mkl=2024.0.0 + - mkl=2024.0 - aria2 - git - bioconda::hmmer==3.3.2 From 793eb966b2606c9d09e72371457b271b7154e575 Mon Sep 17 00:00:00 2001 From: jnwei Date: Sat, 20 Apr 2024 16:52:06 +0700 Subject: [PATCH 26/34] adjust pytorch version number --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 69c8326..fb868df 100644 --- a/environment.yml +++ b/environment.yml @@ -30,7 +30,7 @@ dependencies: - bioconda::hmmer==3.3.2 - bioconda::hhsuite==3.3.0 - bioconda::kalign2==2.04 - - pytorch::pytorch=2.1 + - pytorch::pytorch=2.2 - pytorch::pytorch-cuda=12.1 - pip: - deepspeed==0.12.4 From ad34fc3c5d36a48f66c43b0858c52236d76afe58 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Mon, 22 Apr 2024 02:55:09 -0400 Subject: [PATCH 27/34] updates Bio.PDBData call and environment.yml --- environment.yml | 6 +++--- openfold/data/mmcif_parsing.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/environment.yml b/environment.yml index fb868df..f5933ce 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,7 @@ dependencies: - openmm=7.7 - pdbfixer - pytorch-lightning - - biopython==1.79 + - biopython - numpy - pandas - PyYAML==5.4.1 @@ -24,13 +24,13 @@ dependencies: - modelcif==0.7 - awscli - ml-collections - - mkl=2024.0 + - mkl=2022.1 - aria2 - git - bioconda::hmmer==3.3.2 - bioconda::hhsuite==3.3.0 - bioconda::kalign2==2.04 - - pytorch::pytorch=2.2 + - pytorch::pytorch=2.1 - pytorch::pytorch-cuda=12.1 - pip: - deepspeed==0.12.4 diff --git a/openfold/data/mmcif_parsing.py b/openfold/data/mmcif_parsing.py index 1bc5ef8..6ef17fb 100644 --- a/openfold/data/mmcif_parsing.py +++ b/openfold/data/mmcif_parsing.py @@ -24,7 +24,7 @@ import os from typing import Any, Mapping, Optional, Sequence, Tuple from Bio import PDB -from Bio.Data import SCOPData +from Bio.Data import PDBData import numpy as np from openfold.data.errors import MultipleChainsError @@ -283,7 +283,7 @@ def parse( author_chain = mmcif_to_author_chain_id[chain_id] seq = [] for monomer in seq_info: - code = SCOPData.protein_letters_3to1.get(monomer.id, "X") + code = PDBData.protein_letters_3to1.get(monomer.id, "X") seq.append(code if len(code) == 1 else "X") seq = "".join(seq) author_chain_to_sequence[author_chain] = seq From ed5261f880fa438bf73fa692944b8818e51d4f1f Mon Sep 17 00:00:00 2001 From: jnwei Date: Mon, 22 Apr 2024 14:32:22 +0700 Subject: [PATCH 28/34] Split cuda install commands in Dockerfile --- Dockerfile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index c048879..fcaeb56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,12 +7,14 @@ LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfol LABEL org.opencontainers.image.licenses = "Apache License 2.0" LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04" -RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-12-4 libcusparse-dev-12-4 libcublas-dev-12-4 libcusolver-dev-12-4 git +RUN apt-get update && apt-get install -y wget RUN apt-key del 7fa2af80 RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb RUN dpkg -i cuda-keyring_1.0-1_all.deb +RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git + RUN wget -P /tmp \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ From 0b11ced1d8264c5a17bfbe73c9d5adaa696b12f6 Mon Sep 17 00:00:00 2001 From: jnwei Date: Mon, 22 Apr 2024 14:55:43 +0700 Subject: [PATCH 29/34] change mamba version --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index fcaeb56..e0c4d87 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ RUN dpkg -i cuda-keyring_1.0-1_all.deb RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git RUN wget -P /tmp \ - "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \ + "https://github.com/conda-forge/miniforge/releases/download/24.1.2-0/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ && rm /tmp/Miniforge3-Linux-x86_64.sh ENV PATH /opt/conda/bin:$PATH From 1d2237347e0201cac8f77d46a1e96d339349ca8d Mon Sep 17 00:00:00 2001 From: Jennifer Date: Mon, 22 Apr 2024 04:48:05 -0400 Subject: [PATCH 30/34] upgrading hmmer hhsuite and kalign2 packages --- environment.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/environment.yml b/environment.yml index f5933ce..cbabdfb 100644 --- a/environment.yml +++ b/environment.yml @@ -27,9 +27,9 @@ dependencies: - mkl=2022.1 - aria2 - git - - bioconda::hmmer==3.3.2 - - bioconda::hhsuite==3.3.0 - - bioconda::kalign2==2.04 + - bioconda::hmmer + - bioconda::hhsuite + - bioconda::kalign2 - pytorch::pytorch=2.1 - pytorch::pytorch-cuda=12.1 - pip: From cf0cc8bcc532ad93d98dd3bbd0b7247baa022905 Mon Sep 17 00:00:00 2001 From: Jennifer Date: Mon, 22 Apr 2024 05:00:06 -0400 Subject: [PATCH 31/34] small edit to Dockerfile --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index e0c4d87..72c3f7f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ RUN dpkg -i cuda-keyring_1.0-1_all.deb RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git RUN wget -P /tmp \ - "https://github.com/conda-forge/miniforge/releases/download/24.1.2-0/Miniforge3-Linux-x86_64.sh" \ + "https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ && rm /tmp/Miniforge3-Linux-x86_64.sh ENV PATH /opt/conda/bin:$PATH From 12eb81bac5599394c3b1ed983b591ea0440446bb Mon Sep 17 00:00:00 2001 From: Jennifer Date: Mon, 22 Apr 2024 05:14:04 -0400 Subject: [PATCH 32/34] Reset miniforge version to 23.3.1-1 --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 72c3f7f..fcaeb56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,7 @@ RUN dpkg -i cuda-keyring_1.0-1_all.deb RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git RUN wget -P /tmp \ - "https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Miniforge3-Linux-x86_64.sh" \ + "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ && rm /tmp/Miniforge3-Linux-x86_64.sh ENV PATH /opt/conda/bin:$PATH From 4ee9943e167302715a3fbd1a292e1d30c79ccdcd Mon Sep 17 00:00:00 2001 From: jnwei Date: Tue, 23 Apr 2024 11:37:05 +0700 Subject: [PATCH 33/34] Remove nvcc compute capability 37 which caused kernel build issues --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 9179856..4873570 100644 --- a/setup.py +++ b/setup.py @@ -52,7 +52,6 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_major, bare_metal_minor compute_capabilities = set([ - (3, 7), # K80, e.g. (5, 2), # Titan X (6, 1), # GeForce 1000-series ]) @@ -130,7 +129,7 @@ setup( classifiers=[ 'License :: OSI Approved :: Apache Software License', 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python :: 3.9,' + 'Programming Language :: Python :: 3.10,' 'Topic :: Scientific/Engineering :: Artificial Intelligence', ], ) From 76fb7ce62bd5a4355c56d0f1e42147e7ffc4728c Mon Sep 17 00:00:00 2001 From: Jennifer Wei Date: Mon, 6 May 2024 08:08:10 +0000 Subject: [PATCH 34/34] remove test print statements --- tests/test_permutation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_permutation.py b/tests/test_permutation.py index d0db977..740de44 100644 --- a/tests/test_permutation.py +++ b/tests/test_permutation.py @@ -113,7 +113,6 @@ class TestPermutation(unittest.TestCase): aligns, _ = compute_permutation_alignment(out, batch, batch) - print(f"##### aligns is {aligns}") possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]] wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]] self.assertIn(aligns, possible_outcome) @@ -163,7 +162,6 @@ class TestPermutation(unittest.TestCase): aligns, per_asym_residue_index = compute_permutation_alignment(out, batch, batch) - print(f"##### aligns is {aligns}") labels = split_ground_truth_labels(batch) labels = merge_labels(per_asym_residue_index, labels, aligns,