Fix for loading old OF weights into refactored model

This commit is contained in:
Christina Floristean
2023-11-03 14:26:18 -04:00
parent 5fcd6ed221
commit f65b75fe48
6 changed files with 72 additions and 7 deletions

View File

@@ -26,6 +26,7 @@ from openfold.utils.import_weights import (
ParamType,
generate_translation_dict,
process_translation_dict,
import_openfold_weights_
)
from openfold.utils.tensor_utils import tree_map
@@ -63,7 +64,7 @@ def main(args):
config = model_config(args.config_preset)
model = AlphaFold(config)
model.load_state_dict(d)
import_openfold_weights_(model=model, state_dict=d)
translation = generate_translation_dict(model, args.config_preset)
translation = process_translation_dict(translation)