mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
cfg for bp_partners feature
This commit is contained in:
committed by
Raktim Mitra
parent
bfe513ab17
commit
42680bdf1e
@@ -25,9 +25,10 @@ inference_sampler:
|
||||
kind: "default" # "default" or "symmetry" to choose the sampler
|
||||
# Classifier-free guidance args:
|
||||
cfg_features: # set to 0 in the reference CFG step
|
||||
- active_donor
|
||||
- active_acceptor
|
||||
- ref_atomwise_rasa
|
||||
#- active_donor
|
||||
#- active_acceptor
|
||||
#- ref_atomwise_rasa
|
||||
- bp_partners
|
||||
|
||||
use_classifier_free_guidance: False
|
||||
cfg_t_max: null # max t to apply cfg guidance
|
||||
|
||||
@@ -57,9 +57,14 @@ def strip_f(
|
||||
|
||||
# set the feature to default value if it is in the cfg_features
|
||||
if k in cfg_features:
|
||||
v_cropped = torch.zeros_like(v_cropped).to(
|
||||
v_cropped.device, dtype=v_cropped.dtype
|
||||
)
|
||||
if k not in ["bp_partners"]:
|
||||
v_cropped = torch.zeros_like(v_cropped).to(
|
||||
v_cropped.device, dtype=v_cropped.dtype
|
||||
)
|
||||
else:
|
||||
## for bp_partners default is a mask feature
|
||||
v_cropped[:,:,0] = 1
|
||||
v_cropped[:,:,1:] = 0
|
||||
|
||||
# update the feature in the dictionary
|
||||
f_stripped[k] = v_cropped
|
||||
|
||||
Reference in New Issue
Block a user