cfg for bp_partners feature

This commit is contained in:
Raktim Mitra
2026-02-27 15:17:09 -08:00
committed by Raktim Mitra
parent bfe513ab17
commit 42680bdf1e
2 changed files with 12 additions and 6 deletions

View File

@@ -25,9 +25,10 @@ inference_sampler:
kind: "default" # "default" or "symmetry" to choose the sampler
# Classifier-free guidance args:
cfg_features: # set to 0 in the reference CFG step
- active_donor
- active_acceptor
- ref_atomwise_rasa
#- active_donor
#- active_acceptor
#- ref_atomwise_rasa
- bp_partners
use_classifier_free_guidance: False
cfg_t_max: null # max t to apply cfg guidance

View File

@@ -57,9 +57,14 @@ def strip_f(
# set the feature to default value if it is in the cfg_features
if k in cfg_features:
v_cropped = torch.zeros_like(v_cropped).to(
v_cropped.device, dtype=v_cropped.dtype
)
if k not in ["bp_partners"]:
v_cropped = torch.zeros_like(v_cropped).to(
v_cropped.device, dtype=v_cropped.dtype
)
else:
## for bp_partners default is a mask feature
v_cropped[:,:,0] = 1
v_cropped[:,:,1:] = 0
# update the feature in the dictionary
f_stripped[k] = v_cropped