mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Add log statement to weight conversion script
This commit is contained in:
@@ -670,7 +670,6 @@ def import_jax_weights_(model, npz_path, version="model_1"):
|
||||
|
||||
def convert_deprecated_v1_keys(state_dict):
|
||||
"""Update older OpenFold model weight names to match the current model code."""
|
||||
logging.warning('converting keys...')
|
||||
|
||||
replacements = {
|
||||
'template_angle_embedder': 'template_single_embedder',
|
||||
@@ -683,27 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
|
||||
}
|
||||
|
||||
convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
|
||||
template_emb_re = re.compile("((module\\.)?(model\\.))?(template(?!_embedder).*)")
|
||||
template_emb_re = re.compile(r"^((module\.)?(model\.)?)(template(?!_embedder).*)")
|
||||
|
||||
converted_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
# For each match, look-up replacement value in the dictionary
|
||||
new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
|
||||
### DEBUG: remove before final commit
|
||||
if key == 'template_angle_embedder.linear_1.weight':
|
||||
logging.warning(f'old key: {key}, new_key: {new_key}')
|
||||
### DEBUG: remove before final commit
|
||||
|
||||
# Add prefix for template layers
|
||||
template_match = re.match(template_emb_re, new_key)
|
||||
if template_match:
|
||||
prefix = template_match.group(1)
|
||||
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
|
||||
# DEBUG: remove before final commit
|
||||
if key == 'template_angle_embedder.linear_1.weight':
|
||||
breakpoint()
|
||||
logging.warning(f'old key: {key}, new_key: {new_key}')
|
||||
### DEBUG: remove before final commit
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ def convert_v1_to_v2_weights(args):
|
||||
if is_dir:
|
||||
# A DeepSpeed checkpoint
|
||||
logging.info(
|
||||
'Converting checkpoint found at {args.input_checkpoint_path}')
|
||||
'Converting deepspeed checkpoint found at {args.input_checkpoint_path}')
|
||||
state_dict_key = 'module'
|
||||
latest_path = os.path.join(checkpoint_path, 'latest')
|
||||
if os.path.isfile(latest_path):
|
||||
@@ -47,6 +47,8 @@ def convert_v1_to_v2_weights(args):
|
||||
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
|
||||
else:
|
||||
# A Pytorch Lightning checkpoint
|
||||
logging.info(
|
||||
'Converting pytorch lightning checkpoint found at {args.input_checkpoint_path}')
|
||||
state_dict_key = 'state_dict'
|
||||
model_output_path = args.output_ckpt_path
|
||||
model_file = checkpoint_path
|
||||
|
||||
@@ -289,16 +289,6 @@ def main(args):
|
||||
sd = torch.load(args.resume_from_ckpt)
|
||||
last_global_step = int(sd['global_step'])
|
||||
model_module.resume_last_lr_step(last_global_step)
|
||||
|
||||
### DEBUG:
|
||||
ds_checkpoint_dir = os.path.join(args.resume_from_ckpt, 'global_step210')
|
||||
optim_files = get_optim_files(ds_checkpoint_dir)
|
||||
zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
|
||||
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
|
||||
|
||||
model_dict = torch.load(model_file, map_location=torch.device('cpu'))
|
||||
###
|
||||
|
||||
logging.info("Successfully loaded last lr step...")
|
||||
if(args.resume_from_ckpt and args.resume_model_weights_only):
|
||||
if(os.path.isdir(args.resume_from_ckpt)):
|
||||
|
||||
Reference in New Issue
Block a user